tensorflow版本(使用下面的命令):2.3.0
pythonversion:3.7
我试图在java中使用tensorflow模型,我将一个文本分类模型(带有tf.lookup)转换为fomat.pb,并希望在java中加载它,但出现“table not initialized”错误。
2021-01-04 14:00:10.713588: W tensorflow/core/framework/op_kernel.cc:1651] OP_REQUIRES failed at lookup_table_op.cc:809 : Failed precondition: Table not initialized.
Exception in thread "main" java.lang.IllegalStateException: Table not initialized.
[[{{node graph/hash_table_Lookup/LookupTableFindV2}}]]
at org.tensorflow.Session.run(Native Method)
at org.tensorflow.Session.access$100(Session.java:48)
at org.tensorflow.Session$Runner.runHelper(Session.java:326)
at org.tensorflow.Session$Runner.run(Session.java:276)
at ctest.Ttest.predict(Ttest.java:32)
at ctest.Ttest.main(Ttest.java:13)
这是我的code:in python
import os
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from tensorflow.python.framework.graph_util import convert_variables_to_constants
from tensorflow.python.ops.lookup_ops import HashTable, KeyValueTensorInitializer
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
OUTPUT_FOLDER = ''
OUTPUT_NAME = 'hash_table.pb'
OUTPUT_NAMES = ['graph/output', 'init_all_tables']
def build_graph():
d = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
init = KeyValueTensorInitializer(list(d.keys()), list(d.values()))
hash_table = HashTable(init, default_value=-1)
data = tf.placeholder(tf.string, (None,), name='data')
values = hash_table.lookup(data)
output = tf.identity(values * 2, 'output')
def freeze_graph():
with tf.Graph().as_default() as graph:
with tf.name_scope('graph'):
build_graph()
with tf.Session(graph=graph) as sess:
sess.run(tf.tables_initializer())
print(sess.run('graph/output:0', feed_dict={'graph/data:0': ['a', 'b', 'c', 'd', 'e']}))
frozen_graph = convert_variables_to_constants(sess, sess.graph_def, OUTPUT_NAMES)
tf.train.write_graph(frozen_graph, OUTPUT_FOLDER, OUTPUT_NAME, as_text=False)
def load_frozen_graph():
with open(os.path.join(OUTPUT_FOLDER, OUTPUT_NAME), 'rb') as f:
output_graph_def = tf.GraphDef()
output_graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
tf.import_graph_def(output_graph_def, name='')
with tf.Session(graph=graph) as sess:
try:
sess.run(graph.get_operation_by_name('init_all_tables'))
except KeyError:
pass
print(sess.run('graph/output:0', feed_dict={'graph/data:0': ['a', 'b', 'c', 'd', 'e']}))
if __name__ == '__main__':
freeze_graph()
load_frozen_graph()
在java中
package ctest;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import java.nio.file.Files;
import java.nio.file.Paths;
public class Ttest {
public static void main(String[] args) throws Exception {
predict();
}
public static void predict() throws Exception {
try (Graph graph = new Graph()) {
graph.importGraphDef(Files.readAllBytes(Paths.get(
"/opt/resources/hash_table.pb"
)));
try (Session sess = new Session(graph)) {
byte[][] matrix = new byte[1][];
matrix[0] = "a".getBytes("UTF-8");
Tensor< ? > out = sess.runner()
.feed("graph/data:0", Tensor.create(matrix)).fetch("graph/output:0").run().get(0);
float[][] output = new float[1][(int) out.shape()[1]];
out.copyTo(output);
for(float i:output[0])
System.out.println(i);
}
}
}
}
如有任何建议,将不胜感激。
暂无答案!
目前还没有任何答案,快来回答吧!