Numpy“put_沿着_axis”,但添加到现有值而不仅仅是put(类似于PyTorch中的scatter_add)

bvk5enib  于 2023-05-07  发布在  其他
关注(0)|答案(1)|浏览(191)

有没有一种方法可以使用np.put_along_axis,但将其添加到现有值而不是替换?
例如,在PyTorch中,这可以实现为:

import torch
frame = torch.zeros(3,2, dtype=torch.double)
updates = torch.tensor([[5,5], [10,10], [3,3]], dtype=torch.double)
indices = torch.tensor([[1,1], [1,1], [2,2]])

frame.scatter_add(0, indices, updates)
OUTPUT: [[0, 0], [15,15], [3,3]]

Numpy的put_along_axis将给予:

import numpy as np
frame = np.zeros(3,2)
updates = np.array([[5,5], [10,10], [3,3]])
indices = np.array([[1,1], [1,1], [2,2]])

np.put_along_axis(frame, indices, update)
OUTPUT: [[0, 0],[10, 10], [3,3]]
xjreopfe

xjreopfe1#

您可以使用numpy.add.at,但需要构建额外的列索引:

>>> frame = np.zeros((3, 2))
>>> updates = np.array([[5,5], [10,10], [3,3]])
>>> indices = np.array([[1,1], [1,1], [2,2]])
>>> np.add.at(frame, (indices, np.arange(frame.shape[1])), updates)
>>> frame
array([[ 0.,  0.],
       [15., 15.],
       [ 3.,  3.]])

相关问题