TensorFlow batched scalar multiplication

tcomlyy6 | asked 2023-10-23

I want to create a TensorFlow layer that takes n inputs: the first n-1 are data tensors and the last is a weight vector of length n-1. For each i in 0..n-2, the layer multiplies data tensor i by the value stored at index i of the weight vector, then sums the weighted tensors and returns the result as a single tensor. The problem I am running into is doing this efficiently in TensorFlow: because of the extra batch dimension, I end up needing to multiply n-1 tensors of shape (None, 9, 256, 256, 1) by n-1 weight slices of shape (None,).
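Concretely, for a single batch element with three data tensors, the intended result is w[0]*x1 + w[1]*x2 + w[2]*x3. A minimal sketch of that computation without the batch dimension (shapes and values are illustrative only, not the actual layer):

import tensorflow as tf

x1 = tf.random.normal((9, 256, 256, 1))
x2 = tf.random.normal((9, 256, 256, 1))
x3 = tf.random.normal((9, 256, 256, 1))
w = tf.constant([0.2, 0.3, 0.5])  # one scalar weight per data tensor

# Weighted sum of the data tensors; the result has the same shape as each input
out = w[0] * x1 + w[1] * x2 + w[2] * x3
print(out.shape)  # (9, 256, 256, 1)
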
To achieve this, I tried the functions tf.math.multiply(tensor, weights), tf.math.scalar_mul(weights1, tensor1) and tf.linalg.matvec(tensor1, weights1), with the code below:

import tensorflow as tf


class LinearCombination(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(LinearCombination, self).__init__(**kwargs)

    def build(self, input_shape):
        # Ensure that the input shape matches the expected shape
        print(input_shape)

        num_tensors = len(input_shape[0])
        num_weights = input_shape[1][1]
        assert num_tensors == num_weights, f"Number of tensors{num_tensors} and number of weights{num_weights} must match."

        super(LinearCombination, self).build(input_shape)  # Be sure to call this at the end

    def call(self, inputs):
        # Multiply each tensor by its corresponding weight and sum them up
        print("\n\n\n\nlayer called")
        tensors = inputs[0]
        tensor1 = tensors[0]
        tensor2 = tensors[1]
        tensor3 = tensors[2]

        print(f"\ntensors: {tensors}")
        print(f"t1{tensor1}")
        print(f"t2{tensor2}")
        print(f"t3{tensor3}\n\n")

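        # inputs[1] has shape (None, 3): column i is the per-batch scalar weight for data tensor i+1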
        weights = inputs[1]
        weights1 = weights[:, 0]
        #weights1 = tf.expand_dims(tf.expand_dims(tf.expand_dims(weights1, axis=-1), axis=-1), axis=-1)

        weights2 = weights[:, 1]
        #weights2 = tf.expand_dims(tf.expand_dims(tf.expand_dims(weights2, axis=-1), axis=-1), axis=-1)

        weights3 = weights[:, 2]
        #weights3 = tf.expand_dims(tf.expand_dims(tf.expand_dims(weights3, axis=-1), axis=-1), axis=-1)

        print(f"\n\nweights: {weights}")
        print(f"w1{weights1}")
        print(f"w2{weights2}")
        print(f"w3{weights3}\n\n")


        #tensor1 = tf.math.multiply(tensor1, weights1)
        #tensor2 = tf.math.multiply(tensor2, weights2)
        #tensor3 = tf.math.multiply(tensor3, weights3)
        #tensor1 = tf.math.scalar_mul(weights1, tensor1)
        #tensor2 = tf.math.scalar_mul(weights2, tensor2)
        #tensor3 = tf.math.scalar_mul(weights3, tensor3)
        #tensor1 = tf.linalg.matvec(tensor1, weights1)
        #tensor2 = tf.linalg.matvec(tensor2, weights2)
        #tensor3 = tf.linalg.matvec(tensor3, weights3)
        print(f"\nresults:")
        print(f"r1{tensor1}")
        print(f"r2{tensor2}")
        print(f"r3{tensor3}\n\n")

        out = tf.math.add(tensor1, tensor2)
        out = tf.math.add(tensor3, out)
        print(f"final output: {out}")

        return out

    def compute_output_shape(self, input_shape):
        return input_shape[0][1]  # Output shape matches the shape of each input tensor

With all the multiplications commented out, the tensors keep the correct shapes, and the print statements in the call function produce the following:

tensors: [<tf.Tensor 'Placeholder:0' shape=(None, 9, 256, 256, 1) dtype=float32>, <tf.Tensor 'Placeholder_1:0' shape=(None, 9, 256, 256, 1) dtype=float32>, <tf.Tensor 'Placeholder_2:0' shape=(None, 9, 256, 256, 1) dtype=float32>]
t1Tensor("Placeholder:0", shape=(None, 9, 256, 256, 1), dtype=float32)
t2Tensor("Placeholder_1:0", shape=(None, 9, 256, 256, 1), dtype=float32)
t3Tensor("Placeholder_2:0", shape=(None, 9, 256, 256, 1), dtype=float32)



weights: Tensor("Placeholder_3:0", shape=(None, 3), dtype=float32)
w1Tensor("linear_combination/strided_slice:0", shape=(None,), dtype=float32)
w2Tensor("linear_combination/strided_slice_1:0", shape=(None,), dtype=float32)
w3Tensor("linear_combination/strided_slice_2:0", shape=(None,), dtype=float32)


results:
r1Tensor("Placeholder:0", shape=(None, 9, 256, 256, 1), dtype=float32)
r2Tensor("Placeholder_1:0", shape=(None, 9, 256, 256, 1), dtype=float32)
r3Tensor("Placeholder_2:0", shape=(None, 9, 256, 256, 1), dtype=float32)

final output: Tensor("linear_combination/Add_1:0", shape=(None, 9, 256, 256, 1), dtype=float32)

Bonus points if we don't have to hard-code the three tensors.

ippsafx7

You can use tf.einsum for this. Here is the updated call() function:

def call(self, inputs):
        # Multiply each tensor by its corresponding weight and sum them up
        print("\n\n\n\nlayer called")
        tensors = inputs[0]  # List of M tensors, each with shape (BATCH, 9, 256, 256, 1)
        weights = inputs[1]  # Tensor with shape (BATCH, M)
        
        # [(BATCH, 1, 9, 256, 256, 1)] ==concat=> (BATCH, M, 9, 256, 256, 1)
        big_tensor = tf.concat([tf.expand_dims(t, axis=1) for t in tensors], axis=1)
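        # Multiply each of the M stacked tensors by its per-batch scalar and sum over the M axis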
        out = tf.einsum('nmabcd,nm->nabcd', big_tensor, weights)
        
        print(f"final output: {out}")

        return out
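
For reference, here is a quick standalone shape check of the einsum approach (batch size 2 and the random inputs are just for illustration, not part of the original model):

import tensorflow as tf

batch = 2
tensors = [tf.random.normal((batch, 9, 256, 256, 1)) for _ in range(3)]
weights = tf.random.uniform((batch, 3))

# Same two steps as in the updated call()
big_tensor = tf.concat([tf.expand_dims(t, axis=1) for t in tensors], axis=1)
out = tf.einsum('nmabcd,nm->nabcd', big_tensor, weights)
print(out.shape)  # (2, 9, 256, 256, 1)

# Compare against the hard-coded weighted sum for the first batch element
manual = (weights[0, 0] * tensors[0][0]
          + weights[0, 1] * tensors[1][0]
          + weights[0, 2] * tensors[2][0])
print(tf.reduce_max(tf.abs(out[0] - manual)).numpy())  # close to 0.0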
