tensorflow 如何使用predict_on_batch避免DataGenerator出现GPU内存不足错误

n7taea2i  于 12个月前  发布在  其他
关注(0)|答案(1)|浏览(114)

我有一个Keras模型,它由两个部分组成(左和右),两者通常独立工作(由于实际情况),但这些部分在中间步骤中交换模型生成的一些潜在数据。我想用自动编码器压缩这些潜在数据。因此,我在模型中引入了自动编码器作为额外的子模型。为了训练这些自动编码器,我使用类似于

submodel = tf.keras.Model(inputs=[model.input], outputs=[model.get_layer(submodel_name).output])

这是工作正常。然而,我现在通过预测整个数据集来生成潜在数据,我需要使用L2损失来训练我的自动编码器:

train = submodel.predict(train_ds)

然而,由于原始数据集并不小,而且子模型的输出维度相当大,因此在为我的一个子模型运行这一行时,我耗尽了GPU内存。训练自动编码器的整个过程如下:

submodel = tf.keras.Model(inputs=[model.input], outputs=[model.get_layer(submodel_name).output])
    AE_name = 'AE_' + submodel_name
    AE = model.get_layer(AE_name)
    
    train = submodel.predict(train_ds)        
    valid = submodel.predict(valid_ds)
    
    AE.bypass = False
    AE.compile(loss='mse', run_eagerly=False)

    AE.fit(x=train,y=train, validation_data=(valid,valid),
                    epochs=epochs, verbose=0, callbacks = [callbacks[0]])

最初,为了生成数据,自动编码器被设置为旁路(1:1Map),这样我就可以在没有自动编码器的情况下获得原始模型的正确潜在数据。
如何将预测拆分为更小的步骤,从而不需要那么多GPU RAM?我的问题是,我对用于生成训练和验证数据的tensorflow Datagenerator类相当不熟悉。到目前为止,我的尝试失败了。
为了给予上下文,用于创建train_ds和valid_ds的datagenerator使用以下函数来编码并(随后)获取数据:

def fetch(self):
        dataset = tf.data.TFRecordDataset(self.tfr).map(self._decode,
                                                        num_parallel_calls=tf.data.experimental.AUTOTUNE)
        if self.mode == "train":
            dataset = dataset.shuffle(2000, reshuffle_each_iteration=False) 
            train_dataset = dataset.batch(1, drop_remainder=True)#dataset.batch(self.batch_size, drop_remainder=True)
            train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)
            return train_dataset
        
        if self.mode == "valid":
            valid_dataset = dataset.batch(1, drop_remainder=False)
            valid_dataset = valid_dataset.prefetch(tf.data.experimental.AUTOTUNE) 
            return valid_dataset
        
        else:
            dataset = dataset.batch(1, drop_remainder=True)
            dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
            return dataset

    def _encode(self, mode):                 
        writer = tf.io.TFRecordWriter(self.tfr)
        
        if self.mode != "test":
            mix_filenames = glob.glob(os.path.join(self.wav_dir, "*mixed*.wav"))
            target_filenames = glob.glob(os.path.join(self.wav_dir, "*target*.wav"))
            sys.stdout.flush()  
            
            for mix_filename, target_filename in tqdm(zip(mix_filenames, 
                                                          target_filenames), total = len(mix_filenames)):
                mix, _ = librosa.load(mix_filename, self.sample_rate, mono = False)
                clean, _ = librosa.load(target_filename, self.sample_rate, mono = False)
    
                def write(a, b):
                    example = tf.train.Example(
                        features=tf.train.Features(
                            feature={
                                "noisy_left" : self._float_list_feature(mix[0, a:b]),
                                "noisy_right": self._float_list_feature(mix[1, a:b]),
                                "clean_left" : self._float_list_feature(clean[0, a:b]),
                                "clean_right": self._float_list_feature(clean[1, a:b])}))
                    
                    writer.write(example.SerializeToString())
                
                now_length = mix.shape[-1]
                target_length = int(self.duration * self.sample_rate)
    
                if now_length < target_length:
                    continue 
                
                stride = int(self.duration * self.sample_rate)
                for i in range(0, now_length - target_length, stride):
                    write(i, i + target_length)

我尝试迭代train_ds,并在每个元素上独立地调用模型(而不是预测),这个想法是我用一个epoch的每个预测来训练自动编码器。然而,我注意到,当我调用model. predict时,model(SingleElement)会产生一个非常轻微的(大约10 - 15)的模型输出差异。原因可能是使用了一些规范化层(我检查了工作中没有dropout)。因为我不想冒错过这些模型细节的风险(我从一个同事那里得到了代码),所以我更愿意避免这种方法,并以某种方式对批进行预测。然而,这是我无法去工作。

whlutmcx

whlutmcx1#

我解决了我的问题,切换到CPU的RAM密集型任务。这是相当慢,但在我的情况下不太慢,所以我可以腾出时间。我切换到CPU使用

with tf.device('/cpu:0'):
    TrainModelOnLargeData

相关问题