Padding a sequence to a multiple of a number in TensorFlow

s3fp2yjn  posted on 2023-03-19  in Other

I want to pad a sequence to a multiple of a given number in TensorFlow. The code I tried is:

import tensorflow as tf
def pad_to_multiple(tensor, multiple, dim=-1, value=0):
    seqlen = tf.shape(tensor)[dim]
    seqlen = tf.cast(seqlen, tf.float32)
    multiple = tf.cast(multiple, tf.float32)
    m = seqlen / multiple
    if tf.math.equal(tf.math.floor(m), m):
        return False, tensor
    remainder = tf.math.ceil(m) * multiple - seqlen
    paddings = tf.zeros([tf.rank(tensor), 2], dtype=tf.int32)
    paddings = tf.tensor_scatter_nd_update(paddings, tf.reshape([dim, 1], [1, 2]), tf.reshape([remainder, 0], [1, 2]))
    tensor = tf.pad(tensor, paddings, constant_values=value)
    return True, tensor

It throws the following error at tf.tensor_scatter_nd_update:

Inner dimensions of output shape must match inner dimensions of updates shape. Output: [4,2] updates: [1,2]

Is there a way to fix the shape mismatch in tensor_scatter_nd_update?
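
The mismatch can be reproduced in isolation, which may help in reading the error message. The following is a minimal sketch, assuming TensorFlow 2.x with eager execution; the index values, the update value 5, and the variable names are illustrative only. It shows that the required shape of `updates` depends on how many axes each index row addresses: `updates` must have shape `indices.shape[:-1] + tensor.shape[indices.shape[-1]:]`.

```python
import tensorflow as tf

# One [before, after] row per axis of a rank-4 tensor, as in the question.
paddings = tf.zeros([4, 2], dtype=tf.int32)

# Full-depth index: [[3, 1]] names the single element paddings[3, 1],
# so updates must hold one scalar per index row, i.e. shape [1].
scalar_update = tf.tensor_scatter_nd_update(
    paddings,
    tf.constant([[3, 1]], dtype=tf.int32),
    tf.constant([5], dtype=tf.int32),
)

# Partial index: [[3]] names the whole row paddings[3, :],
# so updates must have shape [1, 2].
row_update = tf.tensor_scatter_nd_update(
    paddings,
    tf.constant([[3]], dtype=tf.int32),
    tf.constant([[0, 5]], dtype=tf.int32),
)

print(scalar_update.numpy().tolist())
print(row_update.numpy().tolist())
```

Both calls write the same value into the last row. Passing updates of shape [1, 2] together with the full-depth index [[3, 1]] is the combination that produces the reported error.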

4dbbbstv  #1

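There is a shape mismatch between the indices and the updates, which is what raises the error.
Try the following and see whether it resolves the problem: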

import tensorflow as tf

def pad_to_multiple(tensor, multiple, dim=-1, value=0):
    rank = tf.rank(tensor)
    # Negative axes are not valid scatter indices, so normalize dim to [0, rank).
    dim = tf.math.floormod(dim, rank)
    seqlen = tf.shape(tensor)[dim]
    # Integer arithmetic keeps the padding amount int32, as tf.pad requires.
    remainder = tf.math.floormod(-seqlen, multiple)
    if tf.math.equal(remainder, 0):
        return False, tensor
    paddings = tf.zeros([rank, 2], dtype=tf.int32)
    # The index [dim, 1] addresses a single scalar, so updates must have shape [1].
    paddings = tf.tensor_scatter_nd_update(paddings, tf.reshape([dim, 1], [1, 2]), tf.reshape(remainder, [1]))
    tensor = tf.pad(tensor, paddings, constant_values=value)
    return True, tensor
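
If the rank of the tensor is known statically, the scatter update can be avoided altogether by building the paddings as a nested Python list. The following is a sketch under that assumption rather than a drop-in replacement: the name `pad_to_multiple_static` is hypothetical, `multiple` is assumed to be a positive Python int, and the function returns only the padded tensor, without the boolean flag.

```python
import tensorflow as tf

def pad_to_multiple_static(tensor, multiple, dim=-1, value=0):
    """Hypothetical variant: pad the end of axis `dim` up to the next multiple."""
    rank = tensor.shape.rank      # static rank, assumed to be known
    dim = dim % rank              # normalize negative axes
    seqlen = tf.shape(tensor)[dim]
    # Number of elements missing to reach the next multiple (0 if already aligned).
    remainder = tf.math.floormod(-seqlen, multiple)
    # One [before, after] pair per axis; only the chosen axis is padded at its end.
    paddings = [[0, remainder] if axis == dim else [0, 0] for axis in range(rank)]
    return tf.pad(tensor, paddings, constant_values=value)

x = tf.ones([2, 10])
y = pad_to_multiple_static(x, 4)
print(y.shape)
```

Because padding by zero elements leaves the tensor unchanged, this variant needs no branch for the case where the length is already a multiple.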
