Tensorflow:如何使用tf.roll而不进行 Package ?

lskq00tm  于 2022-11-25  发布在  其他
关注(0)|答案(2)|浏览(137)

我想独立移动二维Tensor的列或行,如下所示:

a = tf.constant([[1,2,3], [4,5,6]])
shift = tf.constant([2, -1])
b = shift_fn(a, shift)

它给出:

b = [[0, 0, 1], [5, 6, 0]]

我发现tf.roll()也可以做类似的事情,但是会 Package 元素。我如何使用它来填充零呢?

4xrmg8kj

4xrmg8kj1#

一个不太好的解决方案是首先使用tf.pad填充Tensor,然后在tf.map_fn中使用tf.roll独立地移动填充Tensor的每一行(或列)。最后,您可以对结果进行适当的切片。例如:

a = tf.constant([[1,2,3], [4,5,6]])
shift = tf.constant([2, -1])

cols = a.shape[1]
paddings = tf.constant([[0, 0], [cols, cols]])
padded_a = tf.pad(a, paddings)

tf.map_fn(
    lambda x: tf.roll(x[0], x[1], axis=-1),
    elems=(padded_a, shift),
    # The following argument is required here
    fn_output_signature=a.dtype
)[:,cols:cols*2]

"""
<tf.Tensor: shape=(2, 3), dtype=int32, numpy=
array([[0, 0, 1],
       [5, 6, 0]], dtype=int32)>
"""

或者,为了帮助保存一些内存,填充和切片都可以在传递给tf.map_fnfn函数内完成。

erhoui1w

erhoui1w2#

@tf.function
def shift(inputs, shift, axes, pad_val = 0):
    assert len(shift) == len(axes)
    axis_shift = zip(axes, shift)
    axis2shift = dict(axis_shift)  
    old_shape = inputs.shape

    for axis in axis2shift:
        pad_shape = list(inputs.shape)
        pad_shape[axis] = abs(axis2shift[axis])
        input_pad = tf.fill(pad_shape, pad_val)
        inputs = tf.concat((inputs, input_pad), axis = axis) 
    
    
    input_roll = tf.roll(inputs, shift, axes)
    ret = tf.slice(input_roll, [0 for _ in range(len(old_shape))], old_shape)

    return ret

上面的函数用tf.roll类接口实现了shift *(shift和axis应该是列表而不是Tensor)。
希望这对你有帮助

相关问题