I exported a saved_model.pb from Python and load it with the C++ API. This is my loading code:
#include <string>
#include <vector>
#include <opencv2/opencv.hpp>
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/core/framework/tensor.h"
using std::string;
using std::vector;
using tensorflow::Status;
using tensorflow::Tensor;

string strModelPath = "/home/hxy/tf/objectrec/train/train/saved_model/";
string inputLayer = "input_tensor:0";
vector<string> outputLayer = {"detection_boxes", "detection_classes", "detection_scores", "num_detections"};
tensorflow::SessionOptions session_option;
tensorflow::RunOptions run_option;
tensorflow::SavedModelBundle iBundle;
Status load_graph_status = tensorflow::LoadSavedModel(session_option, run_option, strModelPath, {"serve"}, &iBundle);
if (!load_graph_status.ok())
{
    return;
}
// cv::imread returns an 8-bit, 3-channel BGR image
cv::Mat frame = cv::imread("/home/hxy/tf/objectrec/test_images/raccoon_dt2.jpg");
Tensor input_tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, frame.rows, frame.cols, frame.channels()}));
float* p = input_tensor.flat<float>().data();
// Wrap the tensor buffer in a 3-channel float Mat so convertTo() fills it in place
cv::Mat m_input(frame.rows, frame.cols, CV_32FC3, p);
frame.convertTo(m_input, CV_32F);
std::vector<tensorflow::Tensor> outputs;
// The session is owned by the loaded bundle
Status runStatus = iBundle.session->Run({{inputLayer, input_tensor}}, outputLayer, {}, &outputs);
// ...
I get this error message:
2020-07-30 18:25:23.941979: I tensorflow/cc/saved_model/loader.cc:303] SavedModel load for tags { serve }; Status: success: OK. Took 5927735 microseconds.
2020-07-30 18:25:24.226638: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:581] model_pruner failed: Invalid argument: Graph does not contain terminal node detection_boxes.
I thought the output layer names might be wrong, so I ran saved_model_cli to inspect the model:
The given SavedModel SignatureDef contains the following output(s):
outputs['detection_boxes'] tensor_info:
dtype: DT_FLOAT
shape: (1, 300, 4)
name: StatefulPartitionedCall:1
outputs['detection_classes'] tensor_info:
dtype: DT_FLOAT
shape: (1, 300)
name: StatefulPartitionedCall:2
outputs['detection_scores'] tensor_info:
dtype: DT_FLOAT
shape: (1, 300)
name: StatefulPartitionedCall:4
outputs['num_detections'] tensor_info:
dtype: DT_FLOAT
shape: (1)
name: StatefulPartitionedCall:5
I can see that detection_boxes is there. Could you give me some advice? Thanks.
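Answer: I solved it myself. The names used for inputLayer and outputLayer have to come from the saved_model_cli output, namely the value on each name: line, not the key in outputs['...']. In my project they are serving_default_input_tensor:0 and StatefulPartitionedCall:1 (or StatefulPartitionedCall:x for the other outputs).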
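If you get
error : invalid argument specified in either feed_devices or fetch_devices was not found in the graph
the input layer name may be wrong. As explained above, look up the input name with saved_model_cli. That is only one possible cause of this error, though, so you may need to look for other clues. Good luck.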
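Because the feed and fetch names are the name: values rather than the signature keys, it can help to derive them from the saved_model_cli text instead of copying them by hand. Below is a minimal sketch, not part of the original answer: ParseSignatureNames and NamesInOrder are hypothetical helpers, and they assume text for a single SignatureDef laid out as in the question (an inputs['key'] or outputs['key'] line followed by a name: line). Inside a TensorFlow C++ program the same mapping should also be reachable through iBundle.meta_graph_def.signature_def().at("serving_default"), whose inputs() and outputs() maps hold TensorInfo entries with a name() field; that route avoids text parsing but is not shown here.

```cpp
#include <cassert>
#include <map>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical helper: maps each signature key (e.g. "detection_boxes") to the
// graph tensor name on its "name:" line (e.g. "StatefulPartitionedCall:1").
// `section` is "inputs" or "outputs"; `cliText` is saved_model_cli output for
// a single SignatureDef.
std::map<std::string, std::string> ParseSignatureNames(const std::string& cliText,
                                                       const std::string& section) {
  assert(section == "inputs" || section == "outputs");
  std::map<std::string, std::string> names;
  std::istringstream in(cliText);
  const std::string prefix = section + "['";
  std::string line;
  std::string currentKey;  // key of the entry being read, empty if none
  while (std::getline(in, line)) {
    // Trim surrounding whitespace so both indented and flattened text parse.
    const auto first = line.find_first_not_of(" \t\r");
    if (first == std::string::npos) continue;
    line = line.substr(first, line.find_last_not_of(" \t\r") - first + 1);

    if (line.rfind(prefix, 0) == 0) {
      // e.g. outputs['detection_boxes'] tensor_info:
      const auto close = line.find("']", prefix.size());
      currentKey = close == std::string::npos
                       ? std::string()
                       : line.substr(prefix.size(), close - prefix.size());
    } else if (line.rfind("inputs['", 0) == 0 || line.rfind("outputs['", 0) == 0) {
      currentKey.clear();  // entry from the other section: ignore it
    } else if (!currentKey.empty() && line.rfind("name: ", 0) == 0) {
      names[currentKey] = line.substr(6);  // text after "name: "
      currentKey.clear();
    }
  }
  return names;
}

// Hypothetical helper: returns the tensor names for `keys` in the same order,
// so outputs[i] from Session::Run corresponds to keys[i].
// Throws std::out_of_range if a key is missing from `names`.
std::vector<std::string> NamesInOrder(const std::map<std::string, std::string>& names,
                                      const std::vector<std::string>& keys) {
  std::vector<std::string> result;
  result.reserve(keys.size());
  for (const auto& key : keys) result.push_back(names.at(key));
  return result;
}
```

For the output listed in the question, ParseSignatureNames(text, "outputs").at("detection_boxes") returns "StatefulPartitionedCall:1", and passing the four keys detection_boxes, detection_classes, detection_scores and num_detections to NamesInOrder yields a fetch list whose order matches the order in which Session::Run fills outputs.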
If you get
error : Expects arg[0] to be uint8 but float is provided
check your Tensor's data type: it must match the dtype that saved_model_cli reports.
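The uint8 error means the image has to be fed as DT_UINT8 rather than as float, so the float conversion in the question's code is not needed. A minimal sketch, assuming the model's input is uint8 with shape [1, height, width, 3] and expects RGB channel order (the uint8 dtype comes from the error message; the RGB order is an assumption to check against how the model was trained): BgrToRgbNhwc is a hypothetical helper that copies OpenCV's interleaved BGR bytes into such a buffer while swapping the channels.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical helper: copies an interleaved 8-bit BGR image (the layout
// cv::imread produces) into `rgb`, laid out as [1, rows, cols, 3] in RGB order.
// Assumes `bgr` is contiguous (no row padding) and `rgb` holds rows*cols*3 bytes.
void BgrToRgbNhwc(const std::uint8_t* bgr, int rows, int cols, std::uint8_t* rgb) {
  assert(bgr != nullptr && rgb != nullptr && rows > 0 && cols > 0);
  const std::size_t pixels = static_cast<std::size_t>(rows) * static_cast<std::size_t>(cols);
  for (std::size_t i = 0; i < pixels; ++i) {
    rgb[3 * i + 0] = bgr[3 * i + 2];  // R comes from the third BGR byte
    rgb[3 * i + 1] = bgr[3 * i + 1];  // G stays in the middle
    rgb[3 * i + 2] = bgr[3 * i + 0];  // B comes from the first BGR byte
  }
}
```

With TensorFlow, the destination would be the tensor's own buffer, for example Tensor input_tensor(tensorflow::DT_UINT8, tensorflow::TensorShape({1, frame.rows, frame.cols, 3})) followed by BgrToRgbNhwc(frame.data, frame.rows, frame.cols, input_tensor.flat<tensorflow::uint8>().data()), after checking that frame.isContinuous() is true. Alternatively, cv::cvtColor(frame, rgbMat, cv::COLOR_BGR2RGB), with rgbMat wrapping that same buffer as CV_8UC3, performs the same swap.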
You can learn more about saved_model_cli at https://www.tensorflow.org/guide/saved_model
Good luck!