numpy 获取由另一个2D ndarray中的index'定义的2D ndarray的列

h6my8fg2  于 2023-03-23  发布在  其他
关注(0)|答案(1)|浏览(178)

我有一个ndarray arr = np.array([[1,2,3],[4,5,6],[7,8,9]])和一个index-array
arr_idx = np.array([[0,2],[1,2],[2,1]])其中arr_idx中的每一行对应于我想要的arr的索引,即结果应该是[[1,3],[5,6],[9,8]]
我可以使用例如listcomprehension来完成它,但是我有一些相当大的数据,因此如果我们可以将其向量化,那会更好。
我试过了

result = arr[arr_idx]

这导致了

array([[[1, 2, 3],
        [7, 8, 9]],
       [[4, 5, 6],
        [7, 8, 9]],
       [[7, 8, 9],
        [4, 5, 6]]])

应该是array([[1,3],[5,6],[9,8]])

deikduxw

deikduxw1#

您需要用途:

result = arr[np.arange(arr.shape[0])[:,None], arr_idx]

或者:

result = np.take_along_axis(arr, arr_idx, 1)

输出:

array([[1, 3],
       [5, 6],
       [9, 8]])

相关问题