pytorch中多注意力的重构Tensor- view与transpose

fiei3ece  于 2022-11-29  发布在  其他
关注(0)|答案(1)|浏览(145)

我正在学习深度学习领域中的注意力算子。我了解到,要高效并行计算多个头部注意力,必须对输入Tensor(querykeyvalue)进行适当整形。假设 querykeyvalue 是形状相同的三个Tensor[N, L, D],其中

  • N 为批量
  • L 为序列长度
  • D 是隐藏/嵌入大小,

它们应该被转换成[N*N_H, L, D_H]Tensor,其中 N_H 是关注层的头部的数目,并且 D_H 是每个头部的嵌入大小。
pytorch代码似乎就是这样做的。下面我发布了重塑查询Tensor的代码(键、值被同等地视为)

q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)

我不明白他们为什么同时执行viewtranspose调用,而执行

q = q.contiguous().view(bsz * num_heads, tgt_len, head_dim)

除了避免额外的函数调用外,单独使用view还可以保证得到的Tensor在内存中仍然是连续的,而transpose就不一样了(据我所知)。我认为使用连续数据是有益的,只要可能,可以使计算速度更快(可能导致更少的内存访问,更好地利用数据的空间局部性,等等)。
view之后进行transpose调用的用例是什么?

mzillmmw

mzillmmw1#

结果未必相同:

a = torch.arange(0, 2 * 3 * 4)
b = a.view(2, 3, 4).transpose(1, 0)
#tensor([[[ 0,  1,  2,  3],
     [12, 13, 14, 15]],

    [[ 4,  5,  6,  7],
     [16, 17, 18, 19]],

    [[ 8,  9, 10, 11],
     [20, 21, 22, 23]]])

c = a.view(3, 2, 4)
#tensor([[[ 0,  1,  2,  3],
     [ 4,  5,  6,  7]],

    [[ 8,  9, 10, 11],
     [12, 13, 14, 15]],

    [[16, 17, 18, 19],
     [20, 21, 22, 23]]])

相关问题