How to understand the Tensor output of TensorFlow object detection?

ppcbkaq5  posted on 2022-12-19  in  Other

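My goal is to build a custom object detection web application. I downloaded a TF2 pretrained SSD ResNet101 model from the model zoo. My idea is that if this implementation works, I will train the model on my own data. To find the input and output nodes, I ran $ saved_model_cli show --dir saved_model --tag_set serve --signature_def serving_default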

The given SavedModel SignatureDef contains the following input(s):
  inputs['input_tensor'] tensor_info:
      dtype: DT_UINT8
      shape: (1, -1, -1, 3)
      name: serving_default_input_tensor:0
The given SavedModel SignatureDef contains the following output(s):
  outputs['detection_anchor_indices'] tensor_info:
      dtype: DT_FLOAT
      shape: (1, 100)
      name: StatefulPartitionedCall:0
  outputs['detection_boxes'] tensor_info:
      dtype: DT_FLOAT
      shape: (1, 100, 4)
      name: StatefulPartitionedCall:1
  outputs['detection_classes'] tensor_info:
      dtype: DT_FLOAT
      shape: (1, 100)
      name: StatefulPartitionedCall:2
  outputs['detection_multiclass_scores'] tensor_info:
      dtype: DT_FLOAT
      shape: (1, 100, 91)
      name: StatefulPartitionedCall:3
  outputs['detection_scores'] tensor_info:
      dtype: DT_FLOAT
      shape: (1, 100)
      name: StatefulPartitionedCall:4
  outputs['num_detections'] tensor_info:
      dtype: DT_FLOAT
      shape: (1)
      name: StatefulPartitionedCall:5
  outputs['raw_detection_boxes'] tensor_info:
      dtype: DT_FLOAT
      shape: (1, 51150, 4)
      name: StatefulPartitionedCall:6
  outputs['raw_detection_scores'] tensor_info:
      dtype: DT_FLOAT
      shape: (1, 51150, 91)
      name: StatefulPartitionedCall:7
Method name is: tensorflow/serving/predict

I then converted the model to a TensorFlow.js model by running

tensorflowjs_converter --input_format=tf_saved_model --output_node_names='detection_anchor_indices,detection_boxes,detection_classes,detection_multiclass_scores,detection_scores,num_detections,raw_detection_boxes,raw_detection_scores' --saved_model_tags=serve --output_format=tfjs_graph_model saved_model js_model

This is my JavaScript code (it lives inside the Vue methods):

loadTfModel: async function() {
    try {
        this.model = await tf.loadGraphModel(this.MODEL_URL);
    } catch (error) {
        console.log(error);
    }
},
predictImg: async function() {
    const imgData = document.getElementById('img');
    let tf_img = tf.browser.fromPixels(imgData);
    tf_img = tf_img.expandDims(0);
    const predictions = await this.model.executeAsync(tf_img);
    const data = [];
    for (let i = 0; i < predictions.length; i++) {
        data.push(predictions[i].dataSync());
    }
    console.log(data);
}

The output looks like this:

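My question is: do the eight items in the array correspond to the eight defined output nodes? How do I make sense of this data, and how can I convert it into a human-readable format, like the one you get in Python?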

**Update 1:** I tried this answer and edited my prediction method:

predictImg: async function() {
        const imgData = document.getElementById('img');
        let tf_img = tf.browser.fromPixels(imgData);
        tf_img = tf_img.expandDims(0);
        const predictions = await this.model.executeAsync(tf_img, ['detection_classes']).then(predictions => {
            const data = predictions.dataSync()
            console.log('Predictions: ', data);
        })

    }

I ended up with "Error: The output 'detection_classes' is not found in the graph". I would appreciate any help.


0pizxfdo1#

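The output node specified in this.model.executeAsync(tf_img, ['detection_classes']) is probably wrong. Also, there is no need to combine await with then in await this.model.executeAsync(tf_img, ['detection_classes']).then(...); use either await or then, not both.
Another option for getting detection_classes is to index the output array: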

predictions[i].dataSync()[2]
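
Once the right arrays have been picked out, they can be turned into readable records. The following is only a minimal sketch, not code from the question or from TensorFlow.js: decodeDetections and its parameter names are hypothetical, it works on the flat arrays that dataSync() returns, and it assumes the TF Object Detection API convention that detection_boxes holds normalized [ymin, xmin, ymax, xmax] values, four per detection.

```javascript
// Hypothetical helper: turns flat detection arrays into readable objects.
// Assumption: boxes are normalized [ymin, xmin, ymax, xmax], four values per detection.
function decodeDetections(boxes, classes, scores, imgWidth, imgHeight, minScore = 0.5) {
  const results = [];
  for (let i = 0; i < scores.length; i++) {
    if (scores[i] < minScore) continue; // skip low-confidence detections
    const [ymin, xmin, ymax, xmax] = Array.from(boxes.slice(i * 4, i * 4 + 4));
    results.push({
      classId: classes[i],
      score: scores[i],
      // Scale the normalized coordinates to pixels.
      box: {
        left: xmin * imgWidth,
        top: ymin * imgHeight,
        width: (xmax - xmin) * imgWidth,
        height: (ymax - ymin) * imgHeight,
      },
    });
  }
  return results;
}
```

The classId values still have to be mapped to names using the label map that belongs to the model. Which entry of predictions holds boxes, classes, and scores should be checked first, for example by comparing array lengths against the shapes in the signature (100 * 4 values for detection_boxes).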

c9x0cxw02#

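I think you first need to inspect the web_model/model.json file and look up the names of the outputs. Those are the names you need to use when filtering what to display (my example file is below).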
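
As a rough illustration of that inspection step, the sketch below lists the output entries recorded in a parsed model.json. It rests on assumptions: that the converter wrote a signature.outputs map whose entries carry a name and a tensorShape.dim list, that listSignatureOutputs is a hypothetical helper, and that the sample object is made up for illustration and is not taken from any real model in this thread.

```javascript
// Hypothetical helper: summarizes the outputs declared in a parsed model.json.
// Assumed layout: signature.outputs[key] = { name, tensorShape: { dim: [{ size }] } }.
function listSignatureOutputs(modelJson) {
  const outputs = (modelJson.signature && modelJson.signature.outputs) || {};
  return Object.entries(outputs).map(([key, info]) => ({
    key,
    name: info.name, // tensor name recorded in the file
    shape: ((info.tensorShape && info.tensorShape.dim) || []).map(d => Number(d.size)),
  }));
}

// Made-up sample for illustration only; a real file would be loaded with fetch or fs.
const sample = {
  signature: {
    outputs: {
      'Identity_1:0': {
        name: 'Identity_1:0',
        tensorShape: { dim: [{ size: '1' }, { size: '100' }, { size: '4' }] },
      },
    },
  },
};
console.log(listSignatureOutputs(sample));
```

Going by the signature printed in the question, an entry with shape [1, 100, 4] would be detection_boxes, and the name listed next to it is what to try passing to executeAsync instead of 'detection_classes'-style names.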
