我尝试使用tf.data.Dataset.rejection_resample
来平衡我的数据集,但是我遇到了一个问题,即该方法修改了我的数据集的element_spec
,使其与我的模型不兼容。
我的数据集的原始元素规范是:
({'input_A': TensorSpec(shape=(None, 900, 1), dtype=tf.float64, name=None),
'input_B': TensorSpec(shape=(None, 900, 1), dtype=tf.float64, name=None)},
TensorSpec(shape=(None, 1, 1), dtype=tf.int64, name=None))
这是批处理后的元素规格。
但是,如果我运行rejection_resample
(在批处理之前),则末尾的元素规范变为:
(TensorSpec(shape=(None,), dtype=tf.int64, name=None),
({'input_A': TensorSpec(shape=(None, 900, 1), dtype=tf.float64, name=None),
'input_B': TensorSpec(shape=(None, 900, 1), dtype=tf.float64, name=None)},
TensorSpec(shape=(None, 1, 1), dtype=tf.int64, name=None)))
rejection_resample
在我的数据开始处添加了另一个tf.int64
Tensor,我不知道它是做什么用的。我的问题是这破坏了输入数据和我的模型之间的兼容性,因为它依赖于原始输入元组。
此外,它还会导致训练数据和验证数据之间的不一致。我希望只对训练数据应用rejection_resample
,但如果这样做,训练数据集将添加Tensor,而验证数据集则不会。
所以我的问题是,元素规范中添加的Tensor是什么,以及在构建元素后,是否有任何方法可以将其从数据集中"删除"。
2条答案
按热度按时间zkure5ic1#
我不能告诉你添加的Tensor来自哪里,但这里有一个如何从数据集中删除/删除它的示例:
请记住,这是一种变通方法,不会移除导致附加Tensor的源
lztngnrs2#
假设我创建了与您的数据集相同的数据集,
Map后,我的
dataset
将与您的完全相同现在,我必须删除最后一个
Tensor
现在,
extra Tensor
已被删除