Paddle MultiHeadAttention的key_padding_mask

bwitn5fc  于 5个月前  发布在  其他
关注(0)|答案(4)|浏览(55)

请提出你的问题 Please ask your question

MultiHeadAttention, pytorch中存在key_padding_mask,而paddle中没有,paddle中该如何实现?

q35jwt9p

q35jwt9p1#

在paddle中,要想实现key_padding_mask,需要自己在dataloader组batch时做处理,这部分功能没有耦合到MultiHeadAttention中

k0pti3hp

k0pti3hp2#

有没有相关操作过的 参考一下

dhxwm5r4

dhxwm5r43#

看一下这个是不是你需要的,https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/transformers/tokenizer_utils_fast.py#L245

t40tm48m

t40tm48m4#

我也遇到了相似的问题,实际上,可以通过修改attn_mask来间接实现key_padding_mask

假设 Query的长度为L,Key, Value的长度为S,则有

  • attn_mask([L, S], paddle.float,其中值为 float('-inf') 表示要被mask掉的位置)
  • key_padding_mask([N, S], paddle.bool,其中值为 True 表示为padding的位置,也即kv要被mask掉的位置)

根据paddle的 MultiheadAttnetion API,当attn_mask的值为False时,这个位置的kv会被mask掉,因此,我们可以通过一下计算得到新的attn_mask,以间接实现key_padding_mask:

attn_mask = (attn_mask!=float('-inf'))[None,None,:,:] & key_padding_mask[:,None,None,:]==False

上述产生的attn_mask大小为[B, 1, L, S],后续输入到MultiheadAttention中计算时会经过传播变为 [B, H, L, S], H表示多头注意力中head的个数。

相关问题