是否有适当的方法来子类化Tensorflow的数据集?

x759pob2  于 2022-11-16  发布在  其他
关注(0)|答案(1)|浏览(183)

我一直在研究自定义Tensorflow数据集的不同方法,我习惯于查看PyTorch的数据集,但当我查看Tensorflow's datasets时,我看到了以下示例:

class ArtificialDataset(tf.data.Dataset):
  def _generator(num_samples):
    # Opening the file
    time.sleep(0.03)

    for sample_idx in range(num_samples):
      # Reading data (line, record) from the file
      time.sleep(0.015)

      yield (sample_idx,)

  def __new__(cls, num_samples=3):
    return tf.data.Dataset.from_generator(
        cls._generator,
        output_signature = tf.TensorSpec(shape = (1,), dtype = tf.int64),
        args=(num_samples,)
        )

但两个问题出现了:
1.看起来它所做的只是当对象被示例化时,__new__方法只调用tf.data.Dataset.from_generator静态方法。那么为什么不直接调用它呢?为什么还要对tf.data.Dataset进行子类化呢?有没有从tf.data.Dataset中使用的方法?
1.是否有一种方法可以像数据生成器那样完成它,在这种方法中,用户在从tf.data.Dataset继承的同时填写__iter__方法?

class MyDataLoader(tf.data.Dataset):
  def __init__(self, path, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.data = pd.read_csv(path)

  def __iter__(self):
    for datum in self.data.iterrows():
      yield datum

非常感谢大家!

83qze16e

83qze16e1#

问题1

该示例只是将数据集与生成器封装在类中。它继承自tf.data.Dataset,因为from_generator()返回基于tf.data.Dataset的对象。但是,没有使用tf.data.Dataset的方法,如示例中所示。因此,问题1的答案:是的,它可以直接调用而不使用类。

问题2

是的。可以这样做。
另一种类似的方法是像here一样使用tf.keras.utils.Sequence

相关问题