我想创建一个tf.keras回调函数,以保存训练期间每个批次和每个时期的模型预测
我已经尝试了下面的回调,但是它给出了如下错误
AttributeError: 'PredictionCallback' object has no attribute 'X_train'
我的密码是
class PredictionCallback(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs={}):
y_pred = self.model.predict(self.X_train)
print('prediction: {} at epoch: {}'.format(y_pred, epoch))
pd.DataFrame(y_pred).assign(epoch=epoch).to_csv('{}_{}.csv'.format(filename, epoch))
cnn_model.fit(X_train, y_train,validation_data=[X_valid,y_valid],epochs=epochs,batch_size=batch_size,
callbacks=[model_checkpoint,reduce_lr,csv_logger, early_stopping,PredictionCallback()],
verbose=1)
我也试过Create keras callback to save model predictions and targets for each batch during training,但还没有成功。希望Maven能帮助我。谢谢。
2条答案
按热度按时间ws51t4hk1#
我试过了
但是它没有将结果保存在文件名中,为什么?
eaf3rand2#
你好,你的思路是正确的。你可以使用下面的回调函数通过txt文件存储它:
之后,您可以使用tensorflow的
fit
函数训练模型:在定义了您的体系结构之后:
model = network()
这对我很有效。看看你是否也在你的文件夹的正确路径。