GPT-2: how do I freeze a checkpoint graph into .pb format?

Posted by k97glaaz 6 months ago in Other

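I'm trying to freeze a fine-tuned GPT-2 model, but I can't figure out what the output node name should be. Using this code as a reference, I put together the following: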

import json
import os
import numpy as np
import tensorflow as tf

import model, sample, encoder

seed=None
length=40
temperature=1
top_k=0

hparams = model.default_hparams()
with open('models/345M/hparams.json') as f:
  hparams.override_from_dict(json.load(f))

with tf.Session(graph=tf.Graph()) as sess:
  context = tf.placeholder(tf.int32, [1, None])
  np.random.seed(seed)
  tf.set_random_seed(seed)
  output = sample.sample_sequence(
      hparams=hparams, length=length,
      context=context,
      batch_size=1,
      temperature=temperature, top_k=top_k
  )

  saver = tf.train.Saver()
  ckpt = tf.train.latest_checkpoint(os.path.join('models', '345M'))
  saver.restore(sess, ckpt)

  # Dump every node name in the graph to help locate the output node
  print([n.name for n in tf.get_default_graph().as_graph_def().node])

  # Freeze the graph
  frozen_graph_def = tf.graph_util.convert_variables_to_constants(sess,sess.graph_def,[output.name])

  # Save the frozen graph
  with open('output_graph.pb', 'wb') as f:
    f.write(frozen_graph_def.SerializeToString())

But I get:
AssertionError: sample_sequence/while/Exit_3:0 is not in graph
So what should I pass as the third argument (the output node names) when freezing the graph?
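The print call in the script dumps every node name, which for GPT-2 runs into the thousands. A minimal sketch for narrowing that list down to the while-loop exit nodes, assuming the names have been collected into a plain Python list. The helper name find_exit_nodes is hypothetical, and the default scope prefix is taken from the error message; in TF1 control flow, tf.while_loop creates one Exit op per loop variable (Exit, Exit_1, ...).

```python
def find_exit_nodes(node_names, scope="sample_sequence/while/"):
    """Return the names of while-loop Exit nodes under `scope` (hypothetical helper)."""
    prefix = scope + "Exit"
    # Keep only names that start with the loop's Exit prefix.
    return [name for name in node_names if name.startswith(prefix)]


# Hand-written list standing in for the real graph's node names.
names = [
    "Placeholder",
    "sample_sequence/while/Enter",
    "sample_sequence/while/Exit",
    "sample_sequence/while/Exit_3",
    "sample_sequence/while/add",
]
print(find_exit_nodes(names))
```

With the real graph, the list would be [n.name for n in sess.graph_def.node]; the candidates it returns are node names, so they carry no ":0" suffix.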

Answer #1 by yeotifhr

output.name gives you a tensor name ('sample_sequence/while/Exit_3:0'), not a node name. I guess you should pass ['sample_sequence/while/Exit_3'] as the third argument (output_node_names) of tf.graph_util.convert_variables_to_constants.
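A tensor name is the producing node's name plus ":<output index>", so the node name can be recovered by stripping that suffix. Inside TensorFlow the same value is available directly as output.op.name, which would make the call convert_variables_to_constants(sess, sess.graph_def, [output.op.name]). A minimal pure-Python sketch of the string conversion (the helper name is hypothetical):

```python
def tensor_name_to_node_name(tensor_name):
    """Convert a tensor name like 'scope/op:0' to its node name 'scope/op'."""
    node, sep, index = tensor_name.rpartition(":")
    # Only strip the suffix when the part after the last ':' is a numeric output index.
    if sep and index.isdigit():
        return node
    return tensor_name


print(tensor_name_to_node_name("sample_sequence/while/Exit_3:0"))
```

This prints sample_sequence/while/Exit_3, which is the name to put in the output_node_names list. Names that are already node names pass through unchanged.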

Answer #2 by csga3l58

Dear @ChintanTrivedi, did you manage to freeze the GPT-2 model correctly?
