我想在Tensorflow中将序列填充为数字的倍数。我尝试的代码是:
import tensorflow as tf
def pad_to_multiple(tensor, multiple, dim=-1, value=0):
seqlen = tf.shape(tensor)[dim]
seqlen = tf.cast(seqlen, tf.float32)
multiple = tf.cast(multiple, tf.float32)
m = seqlen / multiple
if tf.math.equal(tf.math.floor(m), m):
return False, tensor
remainder = tf.math.ceil(m) * multiple - seqlen
paddings = tf.zeros([tf.rank(tensor), 2], dtype=tf.int32)
paddings = tf.tensor_scatter_nd_update(paddings, tf.reshape([dim, 1], [1, 2]), tf.reshape([remainder, 0], [1, 2]))
tensor = tf.pad(tensor, paddings, constant_values=value)
return True, tensor
它在tf.tensor_scatter_nd_update
处抛出以下错误:
Inner dimensions of output shape must match inner dimensions of updates shape. Output: [4,2] updates: [1,2]
有没有办法修复tensor_scatter_nd_update
中的尺寸不匹配问题?
1条答案
按热度按时间4dbbbstv1#
存在维度不匹配,这是引发错误的原因
请尝试以下方法,看看是否可以解决此问题: