有没有比使用整形更好的方法在Tensor内重复层?

8yoxcaq7  于 2021-07-14  发布在  Java
关注(0)|答案(0)|浏览(192)

我想把形状的Tensor[batch size,1]变成[batch size,channel size]。为此,我编写了以下算法:

@tf.function
def fit_channel_layers(input_tensor, target_tensor):
    final_channel_layers = tf.shape(target_tensor)[-1]
    input_tensor = tf.repeat(input_tensor, repeats=[final_channel_layers])
    input_tensor = tf.reshape(input_tensor, [local_batch_size, final_channel_layers])
    return input_tensor

input_tensor = tf.convert_to_tensor([[1.],[0.],[0.],[1.],[0.],[0.]])
target_tensor = tf.convert_to_tensor([[1.,1.,1.,1.,1.,1.],
 [0.,0.,0.,0.,0.,0.],
 [0.,0.,0.,0.,0.,0.],
 [1.,1.,1.,1.,1.,1.],
 [0.,0.,0.,0.,0.,0.],
 [0.,0.,0.,0.,0.,0.]])
fit_channel_layers(input_tensor, target_tensor)

输出是适当的

array([[1., 1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.]], dtype=float32)>

但是,我不确定使用重塑对性能的影响。我知道使用numpy可以使用逗号和乘法器运算符轻松地扩展numpy数组。但是,我想知道是否有比目前用@tf.function图形代码编写的更好的方法来完成这个任务?

暂无答案!

目前还没有任何答案,快来回答吧!

相关问题