python 3DTensor中的Pytorch逻辑索引

xpcnnkqh  于 2022-12-17  发布在  Python
关注(0)|答案(1)|浏览(116)

我有三个Tensor,大小为a(n,4,m),b,具有大小(n,4,1),c与大小(n,k,m). a包含我想获取的一批结点特征,b表示有效结点索引(无效索引用999屏蔽),c为准确设置的节点特征批次。nk分别为批次数和节点数。
目标是将a的有效结点要素替换为c的对应结点要素,索引值按b排序
到目前为止,我使用嵌套的for循环来实现它

import torch
n=4
a = torch.arange(n*4*4).view(n,4,4)
value_c = torch.zeros(n,6,4)
b=torch.randint(0,3,(n,4,1))
b[0,1:]=999
b[2,2:] = 999
for i in range(n):
    for j in range(4):
        if b[i,j]<999:
            a[i,j]=value_c[i,b[i,j].long()]

但是对于一个大数据集来说它确实很慢。有什么方法可以加速它吗(eidogg.使用逻辑索引)?

t40tm48m

t40tm48m1#

当然,对于初学者,我建议使用naninf或一些特殊值来屏蔽无效索引;使用一个特定的整数只会在你增加数据大小时引起一些难以捕捉的问题。这确实给予b的类型变得复杂(这样它就可以存储nan的值),而且你必须在使用它来索引之前将它转换为long。我个人认为这是值得的,但你可以随心所欲。
要使用列表式索引,我们需要将a的索引分解为a的每个维度的1DTensor。

i = torch.arange(n).unsqueeze(1).expand(n,4).reshape(-1)  # something like [0,0,0,0,1,1,1,1 ...]
j = torch.arange(4).unsqueeze(0).expand(n,4).reshape(-1)  # something like [0,1,2,3,0,1,2,3,...]

k = b[i,j].squeeze() # assemble desired indices of c into 1D tensor as well

现在ijk每个都包含nx4索引,这大约是要替换的元素数。现在让我们重新索引每个Tensor一次,以删除所有无效索引。

valid = torch.where(torch.isnan(k),0,1).nonzero().squeeze().long()
k = k[valid]
i = i[valid]
j = j[valid]

现在我们准备索引。

a[i,j,:] = c[i,k,:]

您可能需要执行一些类型转换来使所有内容都正常工作(例如,ac需要具有相同的类型)。

相关问题