我在我的基于PyTorch的代码中有以下实现,它涉及到一个嵌套的for循环。嵌套的for循环沿着if
条件使代码执行起来非常慢。我试图避免嵌套循环,以涉及NumPy和PyTorch中的广播概念,但没有产生任何结果。任何关于避免for
循环的帮助将不胜感激。
下面是我读过的链接PyTorch,NumPy。
#!/usr/bin/env python
# coding: utf-8
import torch
batch_size=32
mask=torch.FloatTensor(batch_size).uniform_() > 0.8
teacher_count=510
student_count=420
feature_dim=750
student_output=torch.zeros([batch_size,student_count])
teacher_output=torch.zeros([batch_size,teacher_count])
student_adjacency_mat=torch.randint(0,1,(student_count,student_count))
teacher_adjacency_mat=torch.randint(0,1,(teacher_count,teacher_count))
student_feat=torch.rand([batch_size,feature_dim])
student_graph=torch.rand([student_count,feature_dim])
teacher_feat=torch.rand([batch_size,feature_dim])
teacher_graph=torch.rand([teacher_count,feature_dim])
for m in range(batch_size):
if mask[m]==1:
for i in range(student_count):
for j in range(student_count):
student_output[m][i]=student_output[m][i]+student_adjacency_mat[i][j]*torch.dot(student_feat[m],student_graph[j])
if mask[m]==0:
for i in range(teacher_count):
for j in range(teacher_count):
teacher_output[m][i]=teacher_output[m][i]+teacher_adjacency_mat[i][j]*torch.dot(teacher_feat[m],teacher_graph[j])
1条答案
按热度按时间oxcyiej71#
问题陈述
您要执行的操作非常简单。如果您仔细查看循环:
与我们相关的要素:
student_count²
。其中
adj_matrix[i,j]
是标量。使用
torch.einsum
这是
torch.einsum
的一个典型用例,您可以阅读更多关于this thread的内容,我碰巧也写过an answer。如果我们不考虑所有的实现细节,那么
torch.einsum
的公式是相当不言自明的:在伪代码中,这可以归结为:
对于您感兴趣域中的所有
i
、j
、m
和f
。结合使用
M = mask[:,None]
扩展到适当形式的掩码,这为学生Tensor提供:对于教师结果,您可以使用
~M
反转掩码:使用
torch.matmul
或者,由于这是
torch.einsum
的一个相当简单的应用程序,因此您也可以通过两次调用torch.matmul
来避免。给定A
和B
,这两个矩阵分别由ik
和kj
索引,您将得到A@B
,它对应于ik@kj -> ij
。因此,您可以通过下式获得所需的结果:看看这两个步骤与torch的关系。einsum调用'ij,mf,jf-〉mi'。首先是
mf,jf->mj
,然后是mj,ij->mi
。边注您当前的虚拟学生和教师邻接矩阵初始化为零。