Tensorflow多探头注意输入:4 x 5 x 20 x 64,其中attention_axis =2抛出掩码尺寸错误(tf 2.11.0)

nx7onnlm  于 2022-11-30  发布在  其他
关注(0)|答案(1)|浏览(111)

这里的期望是注意力应用在第二维(4,5,20,64)。我尝试使用以下代码应用自我注意力(此代码可重现问题):

import numpy as np
import tensorflow as tf
from keras import layers as tfl

class Encoder(tfl.Layer):
    def __init__(self,):
        super().__init__()
        self.embed_layer = tfl.Embedding(4500, 64, mask_zero=True)
        self.attn_layer = tfl.MultiHeadAttention(num_heads=2,
                                                 attention_axes=2,
                                                 key_dim=16)
        return

    def call(self, x):
        # Input shape: (4, 5, 20) (Batch size: 4)
        x = self.embed_layer(x)  # Output: (4, 5, 20, 64)
        x = self.attn_layer(query=x, key=x, value=x)  # Output: (4, 5, 20, 64)
        return x

eg_input = tf.constant(np.random.randint(0, 150, (4, 5, 20)))
enc = Encoder()
enc(eg_input)

然而,上面定义的层引发了以下错误。有人能解释为什么会发生这种情况以及如何修复这种情况吗?

{{function_node __wrapped__AddV2_device_/job:localhost/replica:0/task:0/device:CPU:0}} Incompatible shapes: [4,5,2,20,20] vs. [4,5,1,5,20] [Op:AddV2]

Call arguments received by layer 'softmax_2' (type Softmax):
  • inputs=tf.Tensor(shape=(4, 5, 2, 20, 20), dtype=float32)
  • mask=tf.Tensor(shape=(4, 5, 1, 5, 20), dtype=bool)

PS:如果我在定义嵌入层时设置mask_zero = False,代码运行正常,没有任何问题。

elcex8rz

elcex8rz1#

只需将输入沿着axis=0连接即可
第一个

相关问题