Building on the released BERT scripts, I use tf.estimator.train_and_evaluate:
if FLAGS.do_train and FLAGS.do_eval:
    tf.logging.info("***** Running training and evaluation *****")
    tf.logging.info(" Num examples of training data = %d", train_examples_len)
    tf.logging.info(" Batch size of training data = %d", FLAGS.train_batch_size)
    tf.logging.info(" Num steps of training = %d", num_train_steps)
    tf.logging.info(" Num examples of eval data = %d", eval_examples_len)
    tf.logging.info(" Batch size of eval data = %d", FLAGS.eval_batch_size)
    train_input_fn = file_based_input_fn_builder(
        input_file=train_file,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True)
    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=num_train_steps)
    eval_drop_remainder = True if FLAGS.use_tpu else False
    eval_input_fn = file_based_input_fn_builder(
        input_file=eval_file,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=eval_drop_remainder)
    eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, steps=500, start_delay_secs=0,
                                      throttle_secs=200)
    tf.estimator.train_and_evaluate(estimator, train_spec=train_spec, eval_spec=eval_spec)
I noticed that training runs on the GPU, but evaluation runs on the CPU. How can I get evaluation to run on the GPU as well?
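Before changing anything, it can help to confirm where the evaluation graph's ops actually land. In TF 1.x, passing `session_config=tf.ConfigProto(log_device_placement=True)` to the `RunConfig` the estimator is built with makes TensorFlow log one line per op together with its assigned device. The helper below is a minimal, hypothetical sketch (`summarize_placement` is not part of TensorFlow or the BERT scripts) that tallies those log lines by device type. It assumes the TF 1.x line shape `name: (OpType): /job:.../device:GPU:0`. It is only a diagnostic, not a fix.

```python
import re
from collections import Counter

# Assumed TF 1.x placement-log line shape, e.g.
#   "bert/.../MatMul: (MatMul): /job:localhost/replica:0/task:0/device:GPU:0"
# Any prefix before the op name (timestamps, logger tags) is ignored by search().
_PLACEMENT_RE = re.compile(r"\((?P<op_type>[^)]+)\):\s*(?P<device>/job:\S+)")


def summarize_placement(log_text, op_types=None):
    """Count placed ops per device type ("CPU", "GPU", ...).

    op_types: optional set of op types (e.g. {"MatMul"}); if given, only
    lines whose op type is in the set are counted.
    """
    counts = Counter()
    for line in log_text.splitlines():
        match = _PLACEMENT_RE.search(line)
        if match is None:
            continue  # not a placement line (e.g. the "Device mapping" header)
        if op_types is not None and match.group("op_type") not in op_types:
            continue
        # Last path component: "device:GPU:0" (or "gpu:0" in very old TF).
        last = match.group("device").rsplit("/", 1)[-1]
        if last.startswith("device:"):
            last = last[len("device:"):]
        counts[last.split(":", 1)[0].upper()] += 1
    return dict(counts)


# Illustrative log excerpt (made up for this example, not real output).
sample = """\
bert/encoder/layer_0/attention/self/query/MatMul: (MatMul): /job:localhost/replica:0/task:0/device:GPU:0
IteratorGetNext: (IteratorGetNext): /job:localhost/replica:0/task:0/device:CPU:0
loss/Softmax: (Softmax): /job:localhost/replica:0/task:0/device:GPU:0
"""
print(summarize_placement(sample))                       # {'GPU': 2, 'CPU': 1}
print(summarize_placement(sample, op_types={"MatMul"}))  # {'GPU': 1}
```

Input-pipeline ops such as `IteratorGetNext` normally stay on the CPU, so filtering on heavy ops like `MatMul` gives a clearer signal about whether the model itself is evaluated on the GPU than the overall count does.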
ct3nt3jp1#
Has this been resolved? (Evaluation runs on the CPU while training runs on the GPU.)
I ran into the same problem on SQuAD. See my comment on #75.
qcuzuvrc2#
I've hit this problem too. How can I fix it?