用向量填充3D numpy数组中每个矩阵的对角线

hiz5n14c  于 2023-10-19  发布在  其他
关注(0)|答案(4)|浏览(95)

有一个3D numpy数组,其中每个2D切片代表一个单独的矩阵。我想用一组特定的值来替换每个矩阵的对角元素。
例如,如果我有一个3x3x3数组:

array([[[a1, a2, a3],
        [a4, a5, a6],
        [a7, a8, a9]],

       [[b1, b2, b3],
        [b4, b5, b6],
        [b7, b8, b9]],

       [[c1, c2, c3],
        [c4, c5, c6],
        [c7, c8, c9]]])

我想用每个矩阵的一组新值替换对角线[a1, a5, a9][b1, b5, b9][c1, c5, c9]。我如何才能做到这一点?

oaxa6hgo

oaxa6hgo1#

使用integer index:

import numpy as np

# Setup:
arr = np.zeros((3, 4, 6), dtype=int)
vectors = np.random.randint(1, 9, size=(3, 4))

# Should work for arbitrary `arr` with ndim >= 2
n = min(arr.shape[-2:])
idx = np.arange(n)

# Note that `vectors.shape` must broadcast with
# `(*arr.shape[:-2], n)` for this to work:

arr[..., idx, idx] = vectors
7gs2gvoe

7gs2gvoe2#

可以使用np.eye掩码。

zeros = np.zeros((3,4,6))
vectors = np.random.randint(1 , 9, (3, 4)) # Generating random 3x4 integer array between 1 and 9

mask = np.eye(zeros.shape[1], zeros.shape[2], dtype=bool)
zeros[:, mask] = vectors

下面是print(zeros)的输出。

[[[2. 0. 0. 0. 0. 0.]
  [0. 4. 0. 0. 0. 0.]
  [0. 0. 7. 0. 0. 0.]
  [0. 0. 0. 4. 0. 0.]]

 [[7. 0. 0. 0. 0. 0.]
  [0. 4. 0. 0. 0. 0.]
  [0. 0. 3. 0. 0. 0.]
  [0. 0. 0. 6. 0. 0.]]

 [[3. 0. 0. 0. 0. 0.]
  [0. 5. 0. 0. 0. 0.]
  [0. 0. 1. 0. 0. 0.]
  [0. 0. 0. 3. 0. 0.]]]
zynd9foi

zynd9foi3#

您可以手工创建索引以进行直接索引:

# example input
a = np.zeros((3,3,3), dtype=int)

N = min(a.shape)
idx1 = np.repeat(range(N), N)
idx2 = np.tile(range(N), N)

a[idx1, idx2, idx2] = np.arange(1, 10)

输出量:

array([[[1, 0, 0],
        [0, 2, 0],
        [0, 0, 3]],

       [[4, 0, 0],
        [0, 5, 0],
        [0, 0, 6]],

       [[7, 0, 0],
        [0, 8, 0],
        [0, 0, 9]]])

中间体:

# idx1
array([0, 0, 0, 1, 1, 1, 2, 2, 2])

# idx2
array([0, 1, 2, 0, 1, 2, 0, 1, 2])
lg40wkob

lg40wkob4#

将来,当numpy实现下面描述的功能时,这将是一行程序。在此之前,您需要使用其他答案中所示的索引或掩码方法。
numpy.diagonal的文档(从v1.25开始)说:
从NumPy 1.9开始,它返回原始数组的只读视图。尝试写入结果数组将产生错误。
在未来的版本中,它将返回一个读/写视图,写入返回的数组将改变您的原始数组。返回的数组将具有与输入数组相同的类型。
在numpy的未来版本中,您应该能够直接写入np.diagonal返回的视图。

np.diagonal(zeros, axis1=-2, axis2=-1) = vectors

相关问题