numpy 仅沿着第二维使用np.diagflat?

bfhwhh0e  于 2023-05-17  发布在  其他
关注(0)|答案(2)|浏览(103)

我有一个形状为(5094, 512)的数组(x_train),我想把第二维的1d切片变成一个2d对角数组。我正在尝试下面的代码:

x_train_diag = np.zeros((5094, 512, 512))
x_train_diag = np.apply_along_axis(np.diagflat, x_train, 1)

举个例子,如果我有5094批[1,1,...,1]数据数组;相反,我想有5094批

[1 0 0
 0 1 0
 ...
 0 0 1]

我得到了这个模糊的TypeError:

TypeError: only integer scalar arrays can be converted to a scalar index
agyaoht7

agyaoht71#

你可以将numpy的高级索引与广播一起使用。下面是一个小例子:

import numpy as np

# Your inputs here
x_train = np.arange(1, 13).reshape(4, 3)

# Solution:
H, W = x_train.shape
x_train_diag = np.zeros((H, W, W), dtype=x_train.dtype)
rows, cols = np.indices((H, W), sparse=True)

x_train_diag[rows, cols, cols] = x_train

结果:

>>> x_train
array([[ 1,  2,  3],
       [ 4,  5,  6],
       [ 7,  8,  9],
       [10, 11, 12]])
>>> x_train_diag
array([[[ 1,  0,  0],
        [ 0,  2,  0],
        [ 0,  0,  3]],

       [[ 4,  0,  0],
        [ 0,  5,  0],
        [ 0,  0,  6]],

       [[ 7,  0,  0],
        [ 0,  8,  0],
        [ 0,  0,  9]],

       [[10,  0,  0],
        [ 0, 11,  0],
        [ 0,  0, 12]]])
qyyhg6bp

qyyhg6bp2#

这样的东西对你有用吗?我在对角矩阵的列表上使用了堆栈。

# generating the data as a 2D matrix
nb_batches = 5
nb_data = 7
a = np.random.randint(0, 9, (nb_batches, nb_data))
print(a)
# turning it into a 3D matrix (list of diag matrices)
new_a = np.stack([np.diag(a[i,:]) for i in range(nb_batches)])
print()
print(new_a.shape)
print(new_a[0])

结果:

[[1 7 2 5 1 2 8]
 [5 8 6 2 4 6 6]
 [5 2 8 5 2 8 5]
 [3 3 0 6 5 3 3]
 [1 0 8 8 4 0 5]]

(5, 7, 7)
[[1 0 0 0 0 0 0]
 [0 7 0 0 0 0 0]
 [0 0 2 0 0 0 0]
 [0 0 0 5 0 0 0]
 [0 0 0 0 1 0 0]
 [0 0 0 0 0 2 0]
 [0 0 0 0 0 0 8]]

相关问题