tensorflow 多输入变压器

ao218c7q  于 2022-11-16  发布在  其他
关注(0)|答案(1)|浏览(175)

我想问一下,有没有人知道如何在一个变压器模型中提供两个序列(多头注意,使其像交叉注意一样工作),我尝试了很多次,但不明白如何执行两个相同类型的输入(类型csv:csv 128、32的数字数据、尺寸,另一个相同)

Model = sequential()

Input1 = tf.keras.input(shape=[128,32])

Input2 = tf.keras.input(shape=[128,32])

Mha = tf.keras.layers.MultiheadAttention(num_heads=2)
Output_tensor = Mha(Input1,Input2)

Retune Model

这只是一个虚拟代码,我从tensorflow理解,如果有人可以提供一个更好的例子,这将是非常有帮助的,我试图执行交叉注意两个输入与多头注意
先谢谢你了

06odsfpq

06odsfpq1#

在官方API文档中可以找到一个交叉关注的例子。https://www.tensorflow.org/api_docs/python/tf/keras/layers/MultiHeadAttention

layer = MultiHeadAttention(num_heads=2, key_dim=2)
target = tf.keras.Input(shape=[8, 16])
source = tf.keras.Input(shape=[4, 16])
output_tensor, weights = layer(target, source,
                               return_attention_scores=True)
print(output_tensor.shape) # (None, 8, 16)
print(weights.shape) # (None, 2, 8, 4)

MultiHeadAttention的调用方法中,第一个参数target为query,第二个参数source为value,当key为None(默认值)时,key和value相同,那么当target和source相同时,这就是自关注。

相关问题