在TensorFlow数据管道中使用不规则Tensor时发生TypeError

46qrfjad  于 2023-10-23  发布在  其他
关注(0)|答案(1)|浏览(86)

我尝试在TensorFlow中使用RaggedTensors创建tf.data管道,但我一直遇到TypeError:

"Cannot convert the argument `type_value`: RaggedTensorSpec(TensorShape(\[None, 3\]), tf.int16, 1, tf.int64) to a TensorFlow DType."

我在下面提供了我的代码以供参考:

def data_generator():
    for label, file in enumerate(['a.npy', 'b.npy']):
        yield generate_sample(file), label

def generate_sample(file_path):
    sequence = load_npy_file(file_path)
    return tf.cast(tf.ragged.constant(generate_windows(sequence)), dtype=tf.int16)

def load_npy_file(file_path):
    data = np.load(file_path, allow_pickle=True)
    return data.astype(np.int16)

def generate_windows(sequence):
    windows = [sequence[i:i + 2] for i in range(len(sequence) - 2 + 1)]
    return np.array(windows, dtype=np.int16)

np.save('a.npy', np.array([1,2,3,4], dtype=np.int16))
np.save('b.npy', np.array([5,6,7,8,9], dtype=np.int16))
output_signature = (
        tf.RaggedTensorSpec(shape=(None, 2), dtype=tf.int16),
        tf.TensorSpec(shape=(), dtype=tf.int32)
)
dataset = tf.data.Dataset.from_generator(data_generator, output_signature)

在检查了生成的RaggedTensors的形状后,我注意到它们与output_signature中定义的形状不同:RaggedTensorSpec(TensorShape([3, None]), tf.int16, 1, tf.int32)RaggedTensorSpec(TensorShape([4, None]), tf.int16, 1, tf.int32)。这种差异会导致错误吗?我将非常感谢任何见解或解决方案来解决这个问题。谢谢你,谢谢!

tyky79it

tyky79it1#

最后一行中对from_generator()的调用似乎有一个参数不匹配。从我的Angular 来看,这解决了问题:

dataset = tf.data.Dataset.from_generator(data_generator, output_signature=output_signature)

请参阅函数的文档。第二个参数是output_types,不是签名。

相关问题