如何在TensorFlow中将_dims扩展未知次数?

55ooxyrt  于 2022-12-23  发布在  其他
关注(0)|答案(4)|浏览(122)

如果一个元素的维数是已知的,那么我就可以调用tf.expand_dims。如何将tf.expand_dims放入循环中?下面的代码可以在eager模式下工作,但不能在graph模式下工作。

# @tf.function
def broadcast_multiply(x, y):
    print(tf.shape(x)) # [2, 2, ?, ?, ... ?]
    print(tf.shape(y)) # [2, 2]

    # Doesnt work in graph mode but works in eager
    rank_diff = tf.rank(x) - tf.rank(y)
    for _ in tf.range(rank_diff):
        y = tf.expand_dims(y, -1)

    return x * y
5m1hhzi4

5m1hhzi41#

你应该看看tf.broadcast_to

def broadcast_multiply(x, y):
    y = tf.broadcast_to(y, tf.shape(x))
    return x * y
zzlelutf

zzlelutf2#

经过多次的头部撞击,这就是我想出的。不是最好的性能,但它做的工作。我希望tensorflow有内置的支持这一点。Numpy已经这样做了。

@tf.function
def match_shapes(x, y):
    # Find which one needs to be broadcasted
    low, high = (y, x) if tf.rank(x) > tf.rank(y) else (x, y)
    l_rank, l_shape = tf.rank(low), tf.shape(low)
    h_rank, h_shape = tf.rank(high), tf.shape(high)
    
    # Find the difference in ranks
    common_shape = h_shape[:l_rank]
    tf.debugging.assert_equal(common_shape, l_shape, 'No common shape to broadcast')
    padding = tf.ones(h_rank - l_rank, dtype=tf.int32)
    
    # Pad the difference with ones and reshape
    new_shape = tf.concat((common_shape, padding),axis=0)
    low = tf.reshape(low, new_shape)

    return high, low

@tf.function
def broadcast_multiply(x, y):
    x, y = match_shapes(x, y)
    return x * y
    
x = tf.ones((3, 3, 2)) * 3
y = tf.ones((3, 3)) * 2
broadcast_multiply(x, y)

结果

<tf.Tensor: shape=(3, 3, 2), dtype=float32, numpy=
array([[[6., 6.],
        [6., 6.],
        [6., 6.]],

       [[6., 6.],
        [6., 6.],
        [6., 6.]],

       [[6., 6.],
        [6., 6.],
        [6., 6.]]], dtype=float32)>
dsf9zpds

dsf9zpds3#

我遇到了类似的问题,下面的解决方案对我很有效:

rank_diff = tf.rank(x) - tf.rank(y)
y = y[(...,) + rank_diff * (tf.newaxis,)]
3zwjbxry

3zwjbxry4#

您还可以使用整形:

tf.reshape(y, shape=tf.concat([tf.shape(y),  tf.ones(tf.rank(x) - 2)], 0))

相关问题