tensorflow 如何转置不规则Tensor

t3irkdon  于 2023-10-23  发布在  其他
关注(0)|答案(1)|浏览(142)

一切都在标题中。
我有一个形状为(nsamples, None, M1)的RaggedTensor,我想转置最后两个轴以获得形状为(nsamples, M1, None)的RaggedTensor。
tf.transpose不适用于不规则Tensor,tf.linalg.matrix_transpose也不适用(因为它们是同一个函数)。
有没有一种方法可以完全在Tensorflow中完成它,或者我必须手动执行一个(缓慢的)for循环?

11dmarpk

11dmarpk1#

首先,我希望这个答案可以帮助一些人(包括我自己)在未来,尽管已经晚了很多年的原始问题。

交换轴1和轴2的不规则Tensorragged_rank=1

对于手头的问题,可以用tf.map_fn()实现所需的转置。

def swap_dims_1_2(rt: tf.RaggedTensor):
    dtype = rt.flat_values.dtype
    
    def _fn(_t):
        return tf.RaggedTensor.from_tensor(tf.transpose(_t, [1, 0]))
    
    output_shape = (rt.flat_values.shape[1], None)
    return tf.map_fn(_fn, rt, fn_output_signature=tf.RaggedTensorSpec(shape=output_shape, dtype=dtype))

这里有一个小测试:

data = [
    [[1, 2]], 
    [[3, 4], [5, 6], [7, 8], [9, 10]], 
    [[11, 12], [13, 14]],
]

rt: tf.RaggedTensor = tf.ragged.constant(data, ragged_rank=1)
print(rt.shape)  # (3, None, 2)
print(rt.nested_row_lengths())

r2 = swap_dims_1_2(rt)
print(r2)
print(r2.shape)  # (3, 2, None)
print(r2.nested_row_lengths())

我们也可以将不规则的Tensor填充为规则Tensor,然后转置,这可能不可取,但可以用来检查上面函数的正确性:

# ... 
# `rt` was created same as above 

t = rt.to_tensor(default_value=0)  # shape (3, 4, 2)
t_transpose = tf.transpose(t, perm=[0, 2, 1])  # shape (3, 2, 4)

rt_transpose = tf.ragged.boolean_mask(t_transpose, mask=t_transpose>0)
rt_transpose._set_shape([None, 2, None])
print(rt_transpose)
print(rt_transpose.shape)  # (3, 2, None)
print(rt_transpose.nested_row_lengths())

转置不规则Tensor的注意事项

一般来说,在不填充某些填充值的情况下,不可能在不规则Tensor中转置维度,特别是当涉及多个不规则维度时。我们可以这样想:一个Tensor,不管是否不规则,都是列表的嵌套列表。转置Tensor意味着重新排序嵌套的级别,并且结果Tensor的内部列表可以以统一的步幅包含原始Tensor的元素。这对于不规则的Tensor来说是有问题的,因为“统一步幅的元素”在某些维度上没有很好的定义。
因此,当考虑“转置不规则Tensor”或“交换不规则Tensor中的轴”时,最好写下一个例子,看看假设的转置操作是否真的达到了预期的效果。

相关问题