PyTorch中的groupby聚合产品

hyrbngr7  于 2022-12-04  发布在  其他
关注(0)|答案(1)|浏览(161)

我遇到了与groupby aggregate mean in pytorch相同的问题。但是,我想在每个组(或标签)中创建Tensor的乘积。不幸的是,我找不到一个本地PyTorch函数来解决我的问题,比如一个假设的scatter_prod_(等价于scatter_add_),它是answers中使用的函数。
再循环@elyase问题的示例代码,考虑2DTensor:

samples = torch.Tensor([
    [0.1, 0.1],    #-> group / class 1
    [0.2, 0.2],    #-> group / class 2
    [0.4, 0.4],    #-> group / class 2
    [0.0, 0.0]     #-> group / class 0
])

其中len(samples) == len(labels)为真

labels = torch.LongTensor([1, 2, 2, 0])

因此,我的预期输出为:

res == torch.Tensor([
    [0.0, 0.0],
    [0.1, 0.1], 
    [0.8, 0.8] # -> PRODUCT of [0.2, 0.2] and [0.4, 0.4]
])

这里的问题是,再一次,在@elyase的问题之后,如何在pure PyTorch(也就是说,没有numpy,这样我就可以autograd)中完成这一点,并且在理想情况下没有for循环?
PyTorch forums中交叉过帐。

ffx8fchx

ffx8fchx1#

您可以使用scatter_函数来计算每个组中Tensor的乘积。

samples = torch.Tensor([
    [0.1, 0.1],    #-> group / class 1
    [0.2, 0.2],    #-> group / class 2
    [0.4, 0.4],    #-> group / class 2
    [0.0, 0.0]     #-> group / class 0
])

labels = torch.LongTensor([1,2,2,0])

label_size = 3
sample_dim = samples.size(1)

index = labels.unsqueeze(1).repeat((1, sample_dim))

res = torch.ones(label_size, sample_dim, dtype=samples.dtype)
res.scatter_(0, index, samples, reduce='multiply')

res

tensor([[0.0000, 0.0000],
        [0.1000, 0.1000],
        [0.0800, 0.0800]])

相关问题