我正在尝试将我的layers.MultiHeadAttention
模块的tensorflow模型从tf.keras
转换为nn.MultiheadAttention
from torch.nn
模块。下面是片段。
- Tensorflow多头注意力
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
x_sfe_tf = np.random.randn(64, 345, 64)
x_te_tf = np.random.randn(64, 200, 64)
tes_mod_tf = layers.MultiHeadAttention(num_heads=2, key_dim=64)
output_tf = tes_mod_tf(x_sfe_tf, x_te_tf)
print(output_tf.shape)
- PyTorch多头注意力
import torch
import torch.nn as nn
x_sfe_torch = torch.randn(64, 345, 64)
x_te_torch = torch.randn(64, 200, 64)
tes_mod_torch = nn.MultiheadAttention(embed_dim=64, num_heads=2)
output_torch = tes_mod_torch(x_sfe_torch, x_sfe_torch, x_te_torch)
print(output_torch.shape)
当我运行tensorflow的mha时,它成功返回(64, 345, 64)
。但是当我运行pytorch的mha时,它返回以下错误:AssertionError: key shape torch.Size([64, 345, 64]) does not match value shape torch.Size([64, 200, 64])
tensorflow版本可以返回大小为x_sfe的输出,忽略其与x_te的大小差异。另一方面,pytorch版本要求x_sfe和x_te必须具有相同的维数。我对tensorflow的多头注意力模块实际上是如何工作的感到困惑?PyTorch和PyTorch之间有什么区别,什么是PyTorch的正确输入?先谢了。
1条答案
按热度按时间lymnna711#
Tensorflow获取的输入类似于'[batch_size,seq_len,embed_dim]',而Pytorch获取的输入类似于'[seq_len,batch_size,embed_dim]'。您可以使用torch.permute()进行此更改,希望它能解决您的问题。