我遇到了与groupby aggregate mean in pytorch相同的问题。但是,我想在每个组(或标签)中创建Tensor的乘积。不幸的是,我找不到一个本地PyTorch函数来解决我的问题,比如一个假设的scatter_prod_
(等价于scatter_add_
),它是answers中使用的函数。
再循环@elyase问题的示例代码,考虑2DTensor:
samples = torch.Tensor([
[0.1, 0.1], #-> group / class 1
[0.2, 0.2], #-> group / class 2
[0.4, 0.4], #-> group / class 2
[0.0, 0.0] #-> group / class 0
])
其中len(samples) == len(labels)
为真
labels = torch.LongTensor([1, 2, 2, 0])
因此,我的预期输出为:
res == torch.Tensor([
[0.0, 0.0],
[0.1, 0.1],
[0.8, 0.8] # -> PRODUCT of [0.2, 0.2] and [0.4, 0.4]
])
这里的问题是,再一次,在@elyase的问题之后,如何在pure PyTorch(也就是说,没有numpy,这样我就可以autograd)中完成这一点,并且在理想情况下没有for循环?
在PyTorch forums中交叉过帐。
1条答案
按热度按时间ffx8fchx1#
您可以使用
scatter_
函数来计算每个组中Tensor的乘积。res
: