使用Tensorflow沿着后一个维度与自定义图层串联

wr98u20j  于 2022-11-16  发布在  其他
关注(0)|答案(2)|浏览(913)

我正在尝试使用自定义图层将一个数字连接到(None,10,3)Tensor的最后一个维度,使其成为(None,10,4)Tensor。这似乎是不可能的,因为要连接,除了要合并的维度之外,所有维度必须相等,而且我们无法将“None”作为第一个维度来初始化Tensor。
例如,下面的代码给出了以下错误:
第一个
还有别的办法吗?

ua4mk5z4

ua4mk5z41#

您必须确保尊重批维。可能是这样的:

outp = tf.concat([inputs, tf.cast(tf.repeat(self.positional_embeddings_array[None, ...], repeats=tf.shape(inputs)[0], axis=0), dtype=tf.float32)], axis = 2)

同样,tf.shape给出了Tensor的动态形状。

9gm1akwq

9gm1akwq2#

你的问题是连接Tensorshape [10,3]和[10,1],但是你需要执行一个具有特定单位数的Dense函数。你可以将乘法标记为只使用tf.concatenate(),或者将Dense函数更改为具有特定单位数。
示例:位置嵌入不执行连接函数、当前维的尾部或域结果的这两者的传播结果。

import tensorflow as tf

class MyPositionEmbeddedLayer( tf.keras.layers.Concatenate ):
    def __init__( self, units ):
        super(MyPositionEmbeddedLayer, self).__init__( units )
        self.num_units = units

    def build(self, input_shape):
        self.kernel = self.add_weight("kernel",
        shape=[int(input_shape[-1]),
        self.num_units])

    def call(self, inputs, tails):
        ### area to perform pre-calculation or custom algorithms ###
        #                                                          #
        #                                                          #
        ############################################################
        temp = tf.keras.layers.Concatenate(axis=2)([inputs, tails])
        temp = tf.matmul(temp, self.kernel)
        temp = tf.squeeze( temp )
        return temp

#####################################################
        
start = 3
limit = 93
delta = 3
sample = tf.range(start, limit, delta)
sample = tf.cast( sample, dtype=tf.float32 )
sample = tf.constant( sample, shape=( 10, 1, 3, 1 ) )

start = 3
limit = 33
delta = 3
tails = tf.range(start, limit, delta)
tails = tf.cast( tails, dtype=tf.float32 )
tails = tf.constant( tails, shape=( 10, 1, 1, 1 ) )

layer = MyPositionEmbeddedLayer(10)

print( layer(sample, tails) )

输出:您可以看到它学习密集内核,近邻频率别名。

...
 [[-26.67632     35.44779     23.239683    20.374893   -12.882696
    54.963055   -18.531412    -4.589509   -21.722694   -43.44675   ]
  [-27.629044    36.713783    24.069672    21.102568   -13.3427925
    56.92602    -19.193249    -4.7534204  -22.498507   -44.99842   ]
  [-28.58177     37.979774    24.89966     21.830242   -13.802889
    58.88899    -19.855083    -4.917331   -23.274317   -46.55009   ]
  [ -9.527256    12.6599245    8.299887     7.276747    -4.600963
    19.629663    -6.6183615   -1.6391104   -7.7581053  -15.516697  ]]], shape=(10, 4, 10), dtype=float32)

相关问题