MultiHeadAttention, pytorch中存在key_padding_mask,而paddle中没有,paddle中该如何实现?
q35jwt9p1#
在paddle中,要想实现key_padding_mask,需要自己在dataloader组batch时做处理,这部分功能没有耦合到MultiHeadAttention中
k0pti3hp2#
有没有相关操作过的 参考一下
dhxwm5r43#
看一下这个是不是你需要的,https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/transformers/tokenizer_utils_fast.py#L245
t40tm48m4#
我也遇到了相似的问题,实际上,可以通过修改attn_mask来间接实现key_padding_mask
假设 Query的长度为L,Key, Value的长度为S,则有
根据paddle的 MultiheadAttnetion API,当attn_mask的值为False时,这个位置的kv会被mask掉,因此,我们可以通过一下计算得到新的attn_mask,以间接实现key_padding_mask:
attn_mask = (attn_mask!=float('-inf'))[None,None,:,:] & key_padding_mask[:,None,None,:]==False
上述产生的attn_mask大小为[B, 1, L, S],后续输入到MultiheadAttention中计算时会经过传播变为 [B, H, L, S], H表示多头注意力中head的个数。
4条答案
按热度按时间q35jwt9p1#
在paddle中,要想实现key_padding_mask,需要自己在dataloader组batch时做处理,这部分功能没有耦合到MultiHeadAttention中
k0pti3hp2#
有没有相关操作过的 参考一下
dhxwm5r43#
看一下这个是不是你需要的,https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/transformers/tokenizer_utils_fast.py#L245
t40tm48m4#
我也遇到了相似的问题,实际上,可以通过修改attn_mask来间接实现key_padding_mask
假设 Query的长度为L,Key, Value的长度为S,则有
根据paddle的 MultiheadAttnetion API,当attn_mask的值为False时,这个位置的kv会被mask掉,因此,我们可以通过一下计算得到新的attn_mask,以间接实现key_padding_mask:
attn_mask = (attn_mask!=float('-inf'))[None,None,:,:] & key_padding_mask[:,None,None,:]==False
上述产生的attn_mask大小为[B, 1, L, S],后续输入到MultiheadAttention中计算时会经过传播变为 [B, H, L, S], H表示多头注意力中head的个数。