我有三个Tensor,大小为a
(n,4,m),b
,具有大小(n,4,1),c
与大小(n,k,m). a
包含我想获取的一批结点特征,b
表示有效结点索引(无效索引用999屏蔽),c
为准确设置的节点特征批次。n和k分别为批次数和节点数。
目标是将a
的有效结点要素替换为c
的对应结点要素,索引值按b
排序
到目前为止,我使用嵌套的for循环来实现它
import torch
n=4
a = torch.arange(n*4*4).view(n,4,4)
value_c = torch.zeros(n,6,4)
b=torch.randint(0,3,(n,4,1))
b[0,1:]=999
b[2,2:] = 999
for i in range(n):
for j in range(4):
if b[i,j]<999:
a[i,j]=value_c[i,b[i,j].long()]
但是对于一个大数据集来说它确实很慢。有什么方法可以加速它吗(eidogg.使用逻辑索引)?
1条答案
按热度按时间t40tm48m1#
当然,对于初学者,我建议使用
nan
或inf
或一些特殊值来屏蔽无效索引;使用一个特定的整数只会在你增加数据大小时引起一些难以捕捉的问题。这确实给予b
的类型变得复杂(这样它就可以存储nan
的值),而且你必须在使用它来索引之前将它转换为long
。我个人认为这是值得的,但你可以随心所欲。要使用列表式索引,我们需要将
a
的索引分解为a
的每个维度的1DTensor。现在
i
、j
和k
每个都包含nx4
索引,这大约是要替换的元素数。现在让我们重新索引每个Tensor一次,以删除所有无效索引。现在我们准备索引。
您可能需要执行一些类型转换来使所有内容都正常工作(例如,
a
和c
需要具有相同的类型)。