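I'm trying to freeze a fine-tuned GPT-2 model, but I can't figure out what the output node name should be. Using this code as a reference, I put the following together: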
import fire
import json
import os
import numpy as np
import tensorflow as tf
import model, sample, encoder

seed = None
length = 40
temperature = 1
top_k = 0

hparams = model.default_hparams()
with open('models/345M/hparams.json') as f:
    hparams.override_from_dict(json.load(f))

with tf.Session(graph=tf.Graph()) as sess:
    context = tf.placeholder(tf.int32, [1, None])
    np.random.seed(seed)
    tf.set_random_seed(seed)
    output = sample.sample_sequence(
        hparams=hparams, length=length,
        context=context,
        batch_size=1,
        temperature=temperature, top_k=top_k
    )
    saver = tf.train.Saver()
    ckpt = tf.train.latest_checkpoint(os.path.join('models', '345M'))
    saver.restore(sess, ckpt)
    print([n.name for n in tf.get_default_graph().as_graph_def().node])

    # Freeze the graph
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, [output.name])

    # Save the frozen graph
    with open('output_graph.pb', 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())
But I get:
AssertionError: sample_sequence/while/Exit_3:0 is not in graph
So what should I pass to freeze_graph as argument 3, the output node names?
2 Answers
yeotifhr1#
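output.name gives you a tensor name ('sample_sequence/while/Exit_3:0'), not a node name. A tensor name has the form `<node_name>:<output_index>`, while convert_variables_to_constants expects plain node names. I think you should pass ['sample_sequence/while/Exit_3'] as argument 3 to tf.graph_util.convert_variables_to_constants.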
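The conversion described above is just "drop the `:<output_index>` suffix". Below is a minimal sketch of that string manipulation, assuming the TF 1.x naming convention shown in the error message; the helper name `tensor_to_node_name` is hypothetical and not part of TensorFlow. In TF 1.x the same node name should also be available directly as `output.op.name`, since a `tf.Tensor` exposes the `Operation` that produces it via `.op`.

```python
def tensor_to_node_name(tensor_name: str) -> str:
    """Convert a tensor name like 'scope/op:0' to its node name 'scope/op'.

    Hypothetical helper (not a TensorFlow API). Tensor names have the form
    '<node_name>:<output_index>'; GraphDef node names carry no such suffix.
    A leading '^' (used for control inputs in GraphDef input lists) is
    stripped as well.
    """
    name = tensor_name.strip()
    if name.startswith("^"):
        name = name[1:]
    node, sep, index = name.rpartition(":")
    # Only strip the suffix when it is a numeric output index.
    if sep and index.isdigit():
        return node
    return name


if __name__ == "__main__":
    print(tensor_to_node_name("sample_sequence/while/Exit_3:0"))
```

In the question's script this would mean passing `[tensor_to_node_name(output.name)]`, or simply `[output.op.name]`, as the third argument to `tf.graph_util.convert_variables_to_constants` instead of `[output.name]`.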
csga3l582#
Dear @ChintanTrivedi, did you manage to freeze the GPT-2 model correctly?