我目前为train_and_evaluate
进程配置了eval_metric_ops
:
eval_metric_ops = {"mean_squared_error": tf.compat.v1.metrics.mean_squared_error(
labels=features['image'],
predictions= model.denoise_fn(normalize_data(features['image']), features['label'])),
}
我的损失是这样定义的:
def meanflat(x):
return tf.reduce_mean(x, axis=list(range(1, len(x.shape))))
loss = nn.meanflat(tf.squared_difference(noise, x_recon))
如何将tf.squared_difference
应用于eval_metric_ops
定义?
1条答案
按热度按时间5f0d552i1#
可以将tf.math.squared_difference传递给loss参数
有关详细信息,请参阅此gist。谢谢。