"Table not initialized" when loading a TensorFlow model in Java

Posted by 2izufjch on 2021-06-26 in Java

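TensorFlow version (use command below): 2.3.0
Python version: 3.7
I am trying to use a TensorFlow model in Java. I converted a text classification model (which uses tf.lookup) to the .pb format and want to load it in Java, but I get a "Table not initialized" error.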

2021-01-04 14:00:10.713588: W tensorflow/core/framework/op_kernel.cc:1651] OP_REQUIRES failed at lookup_table_op.cc:809 : Failed precondition: Table not initialized.
Exception in thread "main" java.lang.IllegalStateException: Table not initialized.
     [[{{node graph/hash_table_Lookup/LookupTableFindV2}}]]
    at org.tensorflow.Session.run(Native Method)
    at org.tensorflow.Session.access$100(Session.java:48)
    at org.tensorflow.Session$Runner.runHelper(Session.java:326)
    at org.tensorflow.Session$Runner.run(Session.java:276)
    at ctest.Ttest.predict(Ttest.java:32)
    at ctest.Ttest.main(Ttest.java:13)

Here is my code, in Python:

import os

# Set before importing TensorFlow so the C++ log level applies from the start.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow.compat.v1 as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
from tensorflow.python.ops.lookup_ops import HashTable, KeyValueTensorInitializer

tf.disable_v2_behavior()
OUTPUT_FOLDER = ''
OUTPUT_NAME = 'hash_table.pb'
OUTPUT_NAMES = ['graph/output', 'init_all_tables']

def build_graph():
    d = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
    init = KeyValueTensorInitializer(list(d.keys()), list(d.values()))
    hash_table = HashTable(init, default_value=-1)
    data = tf.placeholder(tf.string, (None,), name='data')
    values = hash_table.lookup(data)
    output = tf.identity(values * 2, 'output')

def freeze_graph():
    with tf.Graph().as_default() as graph:
        with tf.name_scope('graph'):
            build_graph()

        with tf.Session(graph=graph) as sess:
            sess.run(tf.tables_initializer())
            print(sess.run('graph/output:0', feed_dict={'graph/data:0': ['a', 'b', 'c', 'd', 'e']}))
            frozen_graph = convert_variables_to_constants(sess, sess.graph_def, OUTPUT_NAMES)
            tf.train.write_graph(frozen_graph, OUTPUT_FOLDER, OUTPUT_NAME, as_text=False)

def load_frozen_graph():
    with open(os.path.join(OUTPUT_FOLDER, OUTPUT_NAME), 'rb') as f:
        output_graph_def = tf.GraphDef()
        output_graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(output_graph_def, name='')
        with tf.Session(graph=graph) as sess:
            try:
                sess.run(graph.get_operation_by_name('init_all_tables'))
            except KeyError:
                pass
            print(sess.run('graph/output:0', feed_dict={'graph/data:0': ['a', 'b', 'c', 'd', 'e']}))

if __name__ == '__main__':
    freeze_graph()
    load_frozen_graph()
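
For reference, the behavior the Python script relies on can be modeled without TensorFlow. The sketch below is only a toy: `ToyHashTable` and `predict` are hypothetical names that are not part of any TensorFlow API, and `initialize()` merely stands in for running the `init_all_tables` op that `load_frozen_graph` executes before fetching `graph/output:0`. It shows that a lookup issued before initialization fails with the same message as in the stack trace, and what values the graph should return once the table is filled.

```python
class ToyHashTable:
    """Hypothetical stand-in for a TensorFlow HashTable; not a TensorFlow API."""

    def __init__(self, mapping, default_value):
        self._pending = dict(mapping)   # contents held by the initializer
        self._default = default_value
        self._table = None              # stays empty until initialize() runs

    def initialize(self):
        # Plays the role of running the 'init_all_tables' op in a session.
        self._table = dict(self._pending)

    def lookup(self, keys):
        if self._table is None:
            raise RuntimeError('Table not initialized.')
        return [self._table.get(k, self._default) for k in keys]


def predict(table, keys):
    # Mirrors graph/output: look up each key, then double the value.
    return [v * 2 for v in table.lookup(keys)]


table = ToyHashTable({'a': 1, 'b': 2, 'c': 3, 'd': 4}, default_value=-1)
try:
    predict(table, ['a'])
except RuntimeError as err:
    print(err)  # a lookup before initialization fails

table.initialize()
print(predict(table, ['a', 'b', 'c', 'd', 'e']))  # [2, 4, 6, 8, -2]
```

Under this model the results are integers, and the unknown key 'e' maps to the default -1 before doubling. The toy ignores sessions and graph import entirely; it only illustrates the ordering between table initialization and lookup.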

In Java:

package ctest;

import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import java.nio.file.Files;
import java.nio.file.Paths;

public class Ttest {
    public static void main(String[] args) throws Exception {
        predict();
    }
    public static void predict() throws Exception {
        try (Graph graph = new Graph()) {
            graph.importGraphDef(Files.readAllBytes(Paths.get(
                    "/opt/resources/hash_table.pb"
            )));
            try (Session sess = new Session(graph)) {
                byte[][] matrix = new byte[1][];
                matrix[0] = "a".getBytes("UTF-8");
                Tensor<?> out = sess.runner()
                        .feed("graph/data:0", Tensor.create(matrix))
                        .fetch("graph/output:0")
                        .run()
                        .get(0);
                float[][] output = new float[1][(int) out.shape()[1]];
                out.copyTo(output);
                for (float i : output[0])
                    System.out.println(i);
            }
        }
    }
}

Any suggestions would be greatly appreciated.
