Tensorflow2 C++ model_pruner失败:无效参数:图形不包含终端节点detection_boxes

o8x7eapl  于 2023-06-07  发布在  其他
关注(0)|答案(1)|浏览(275)

我通过python导出savemodel.pb,通过C++ API加载。这是我的加载代码

string strModelPath = "/home/hxy/tf/objectrec/train/train/saved_model/";
string inputLayer = "input_tensor:0";
    vector<string> outputLayer = {"detection_boxes", "detection_classes", "detection_scores", "num_detections"};

tensorflow::SessionOptions session_option;
tensorflow::RunOptions run_option;
tensorflow::SavedModelBundle iBundle
Status load_graph_status = tensorflow::LoadSavedModel(session_option, run_option, strModelDir, {"serve"}, &iBundle);
if (!load_graph_status.ok())
{
    return;
}
cv::Mat frame = cv::imread("/home/hxy/tf/objectrec/test_images/raccoon_dt2.jpg");
Tensor input_tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, frame.rows, frame.cols, frame.channels()}));
float* p = input_tensor.flat<float>().data();
cv::Mat m_input(frame.rows, frame.cols, CV_32FC1, p);
frame.convertTo(m_input, CV_32FC1);

std::vector<tensorflow::Tensor> outputs;
outputs.clear();
Status runStatus = session->Run({{inputLayer, input_tensor},}, outputLayer, {}, &outputs);
...

获取错误消息:

2020-07-30 18:25:23.941979: I tensorflow/cc/saved_model/loader.cc:303] SavedModel load for tags { serve }; Status: success: OK. Took 5927735 microseconds.
2020-07-30 18:25:24.226638: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:581] model_pruner failed: Invalid argument: Graph does not contain terminal node detection_boxes.

我想可能是输出层名称错误,所以我运行saved_model_cli看到模型:

The given SavedModel SignatureDef contains the following output(s):
  outputs['detection_boxes'] tensor_info:
      dtype: DT_FLOAT
      shape: (1, 300, 4)
      name: StatefulPartitionedCall:1
  outputs['detection_classes'] tensor_info:
      dtype: DT_FLOAT
      shape: (1, 300)
      name: StatefulPartitionedCall:2
  outputs['detection_scores'] tensor_info:
      dtype: DT_FLOAT
      shape: (1, 300)
      name: StatefulPartitionedCall:4
  outputs['num_detections'] tensor_info:
      dtype: DT_FLOAT
      shape: (1)
      name: StatefulPartitionedCall:5

我看到有检测框,请给予我一些建议,谢谢

2vuwiymt

2vuwiymt1#

我自己解决了它,inputLayeroutputLayer名称来自saved_model_cli信息。在我Pro中,它是serving_default_input_tensor:0StatefulPartitionedCall:1(or x)
如果你得到error : invalid argument specified in either feed_devices or fetch_devices was not found in the graph。可能是输入层名称错误,答案在上面,从saved_model_cli看到输入名称。但这是造成这个错误的一个原因,也许你可以找到其他线索,祝你好运。
如果你有error : Expects arg[0] to be uint8 but float is provided。检查您的Tensor数据类型,必须与saved_model_cli匹配给予dtype
您可以从https://www.tensorflow.org/guide/saved_model了解saved_model_cli
好运

相关问题