如何使用numpy.einsum添加冗余索引

rkkpypqq  于 2023-04-06  发布在  其他
关注(0)|答案(1)|浏览(177)

假设我有一个N x N x N维的numpy数组X,其中的条目为X[i,j,k]。我想使用X来定义一个N x N x N x N维的numpy数组Y,定义如下:

Y[i,j,k,k] = X[i,j,k]
Y[i,j,k,l] = 0 when k != l

我的想法是使用numpy.einsum通过以下代码完成此任务:

Y = np.einsum('ijk->ijkk', X).

但是,这不起作用,因为我得到了以下错误

ValueError: einstein sum subscripts string includes output subscript 'k' multiple times

有没有一种方法可以直接完成这个任务,而不必使用for循环?

pdsfdshx

pdsfdshx1#

你可以使用einsum来提取对角线。你会得到一个原始数组的视图,这意味着如果你修改视图,你也会修改原始数组。这可以让你覆盖原始数组的条目,就像这样:

# setup
import numpy as np
X = np.arange(0, 3*4*5).reshape((3, 4, 5))  # bogus data
Y = np.zeros((3, 4, 5, 5))                  # initialize Y with zeros

# extract diagonal, and write into it
Y_diag_view = np.einsum('ijkk->ijk', Y)
Y_diag_view[...] = X                        # modify Y_diag_view and therefore X in place
print(Y[0, 0, ...])                         # prints diagonal matrix

相关问题