tensorflow 数据.数据集.rejection_resample修改了我的数据集的element_spec

aemubtdh  于 2023-02-13  发布在  其他
关注(0)|答案(2)|浏览(88)
    • bounty将在2天后过期**。回答此问题可获得+100声望奖励。Alberto A希望引起更多人关注此问题。

我尝试使用tf.data.Dataset.rejection_resample来平衡我的数据集,但是我遇到了一个问题,即该方法修改了我的数据集的element_spec,使其与我的模型不兼容。
我的数据集的原始元素规范是:

({'input_A': TensorSpec(shape=(None, 900, 1), dtype=tf.float64, name=None),
  'input_B': TensorSpec(shape=(None, 900, 1), dtype=tf.float64, name=None)},
 TensorSpec(shape=(None, 1, 1), dtype=tf.int64, name=None))

这是批处理后的元素规格。
但是,如果我运行rejection_resample(在批处理之前),则末尾的元素规范变为:

(TensorSpec(shape=(None,), dtype=tf.int64, name=None),
 ({'input_A': TensorSpec(shape=(None, 900, 1), dtype=tf.float64, name=None),
   'input_B': TensorSpec(shape=(None, 900, 1), dtype=tf.float64, name=None)},
  TensorSpec(shape=(None, 1, 1), dtype=tf.int64, name=None)))

rejection_resample在我的数据开始处添加了另一个tf.int64Tensor,我不知道它是做什么用的。我的问题是这破坏了输入数据和我的模型之间的兼容性,因为它依赖于原始输入元组。
此外,它还会导致训练数据和验证数据之间的不一致。我希望只对训练数据应用rejection_resample,但如果这样做,训练数据集将添加Tensor,而验证数据集则不会。
所以我的问题是,元素规范中添加的Tensor是什么,以及在构建元素后,是否有任何方法可以将其从数据集中"删除"。

zkure5ic

zkure5ic1#

我不能告诉你添加的Tensor来自哪里,但这里有一个如何从数据集中删除/删除它的示例:

import tensorflow as tf
import numpy as np
# creating a sample dataset that's similar to your 'wrong' output
ds = tf.data.Dataset.from_tensor_slices((np.arange(-10, 0),(tf.constant(np.arange(10)), tf.constant(np.arange(10,20)))))
# remove the new 'wrong' tensor
dds = ds.map(lambda x, y: y)
# check new dataset
for i in dds.take(2):
    print(i)

请记住,这是一种变通方法,不会移除导致附加Tensor的源

lztngnrs

lztngnrs2#

假设我创建了与您的数据集相同的数据集,

x = tf.random.normal((7000, 900,1))
y = tf.random.normal((7000, 900,1))
z = tf.random.uniform((7000, 1,1), 1, 2, dtype=tf.int32)

#Now converting it to Tf.Dataset object
dataset = tf.data.Dataset.from_tensor_slices(((x,y),z))

func = lambda x , y : (({'input_A' : x[0], 'input_B' : x[1]}), y)
dataset = dataset.map(func)

Map后,我的dataset将与您的完全相同

<MapDataset element_spec=({'input_A': TensorSpec(shape=(900, 1), dtype=tf.float32, name=None), 'input_B': TensorSpec(shape=(900, 1), dtype=tf.float32, name=None)}, TensorSpec(shape=(1, 1), dtype=tf.int32, name=None))>

现在,我必须删除最后一个Tensor

disjoint_func = lambda x , y :(x)
dataset = dataset.map(disjoint_func)

现在,extra Tensor已被删除

<MapDataset element_spec={'input_A': TensorSpec(shape=(900, 1), dtype=tf.float32, name=None), 'input_B': TensorSpec(shape=(900, 1), dtype=tf.float32, name=None)}>

相关问题