自定义tensorflow 指标

ql3eal8s  于 2022-12-27  发布在  其他
关注(0)|答案(1)|浏览(178)

我目前为train_and_evaluate进程配置了eval_metric_ops

eval_metric_ops = {"mean_squared_error": tf.compat.v1.metrics.mean_squared_error(
             labels=features['image'],
             predictions= model.denoise_fn(normalize_data(features['image']), features['label'])),
                   }

我的损失是这样定义的:

def meanflat(x):
  return tf.reduce_mean(x, axis=list(range(1, len(x.shape))))

loss = nn.meanflat(tf.squared_difference(noise, x_recon))

如何将tf.squared_difference应用于eval_metric_ops定义?

5f0d552i

5f0d552i1#

可以将tf.math.squared_difference传递给loss参数

def meanflat(x):
  return tf.reduce_mean(x, axis=list(range(1, len(x.shape))))

loss = meanflat(tf.math.squared_difference(x,y))

eval_metric_ops = {'loss': loss}

有关详细信息,请参阅此gist。谢谢。

相关问题