多头注意力模块的输入和输出不匹配(Tensorflow VS PyTorch)

wfveoks0  于 2023-10-20  发布在  其他
关注(0)|答案(1)|浏览(154)

我正在尝试将我的layers.MultiHeadAttention模块的tensorflow模型从tf.keras转换为nn.MultiheadAttention from torch.nn模块。下面是片段。

  1. Tensorflow多头注意力
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

x_sfe_tf = np.random.randn(64, 345, 64)
x_te_tf = np.random.randn(64, 200, 64)

tes_mod_tf = layers.MultiHeadAttention(num_heads=2, key_dim=64)
output_tf = tes_mod_tf(x_sfe_tf, x_te_tf)

print(output_tf.shape)
  1. PyTorch多头注意力
import torch
import torch.nn as nn

x_sfe_torch = torch.randn(64, 345, 64)
x_te_torch = torch.randn(64, 200, 64)

tes_mod_torch = nn.MultiheadAttention(embed_dim=64, num_heads=2)
output_torch = tes_mod_torch(x_sfe_torch, x_sfe_torch, x_te_torch)
print(output_torch.shape)

当我运行tensorflow的mha时,它成功返回(64, 345, 64)。但是当我运行pytorch的mha时,它返回以下错误:AssertionError: key shape torch.Size([64, 345, 64]) does not match value shape torch.Size([64, 200, 64])
tensorflow版本可以返回大小为x_sfe的输出,忽略其与x_te的大小差异。另一方面,pytorch版本要求x_sfe和x_te必须具有相同的维数。我对tensorflow的多头注意力模块实际上是如何工作的感到困惑?PyTorch和PyTorch之间有什么区别,什么是PyTorch的正确输入?先谢了。

lymnna71

lymnna711#

Tensorflow获取的输入类似于'[batch_size,seq_len,embed_dim]',而Pytorch获取的输入类似于'[seq_len,batch_size,embed_dim]'。您可以使用torch.permute()进行此更改,希望它能解决您的问题。

相关问题