pytorch 训练时自定义注意力功能缓慢

3pvhb19x  于 2023-11-19  发布在  其他
关注(0)|答案(1)|浏览(92)

我一直在尝试在一个标准的gpt中实现一个自定义的注意函数-2风格的Transformer模型。我用负欧几里德距离替换了缩放点积,除了训练非常慢之外,一切似乎都在工作。使用正常的点积注意力,模型在几分钟内训练。在我的实现中,似乎至少需要一天的时间来训练。我使用的数据集大约是30 MB我对Pytorch不是很熟悉,所以我不知道我的实现是否尽可能高效。
下面是我对自定义注意力功能的尝试。

def CustomAttention(A: Float[Tensor, "batch posn_q n_heads d_head"], 
                    B: Float[Tensor, "batch posn_k n_heads d_head"]) -> Float[Tensor, "batch n_heads posn_q posn_k"]:
    A_cast = t.permute(A, (0, 2, 1, 3)).unsqueeze(-2)
    B_cast = t.permute(B, (0, 2, 1, 3)).unsqueeze(-3)
    diff = A_cast - B_cast
    square = diff**2
    sum = t.sum(square, dim=-1)
    return -sum

字符串
基本上,代码依赖于广播来计算每个查询键对的元素差异。我所有的训练都是在3080 ti上本地完成的,它似乎像预期的那样使用了100%的gpu。我能做些什么来让它运行得更快吗?

juud5qan

juud5qan1#

当你说它训练得很慢时,你的意思是每批次的时间变差了,还是每批次的损失改进变差了?如果是后者,这并不完全令人惊讶。负欧几里得距离注意力并不像标度点积注意力那样产生相同的训练动态。arXiv上最近的预印本在不同的环境中记录了这种现象:
C McCarter,“Inverse distance weighting attention.“联想记忆和Hopfield网络研讨会@ NeurIPS,2023年。

相关问题