如何在Tensorflow中更改已保存的模型输入形状?

goqiplq2  于 2022-11-16  发布在  其他
关注(0)|答案(2)|浏览(166)

我想让这个repo https://github.com/ildoonet/tf-pose-estimation与Intel Movidius一起运行,所以我尝试使用mvNCCompile转换pb模型。
问题是mvNCCompile需要一个固定的输入形状,但我拥有的模型是动态的。
我试过这个

graph_path = 'models/graph/mobilenet_thin/graph_opt.pb'

    with tf.gfile.GFile(graph_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    graph = tf.get_default_graph()
    tf.import_graph_def(graph_def, name='TfPoseEstimator')
    x = graph.get_tensor_by_name('TfPoseEstimator/image:0')
    x.set_shape([1, 368, 368, 3])
    x = graph.get_tensor_by_name('TfPoseEstimator/MobilenetV1/Conv2d_0/Conv2D:0')
    x.set_shape([1, 368, 368, 24])

得到了这个

(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_0/weights:0' shape=(3, 3, 3, 24) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/image:0' shape=(1, 368, 368, 3) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_0/Conv2D:0' shape=(1, 368, 368, 24) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_0/Conv2D_bn_offset:0' shape=(24,) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_0/BatchNorm/FusedBatchNorm:0' shape=(?, ?, ?, 24) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_0/Relu:0' shape=(?, ?, ?, 24) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_1_depthwise/depthwise_weights:0' shape=(3, 3, 24, 1) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_1_pointwise/weights:0' shape=(1, 1, 24, 48) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_1_depthwise/depthwise:0' shape=(?, ?, ?, 24) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_1_pointwise/Conv2D:0' shape=(?, ?, ?, 48) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_1_pointwise/Conv2D_bn_offset:0' shape=(48,) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_1_pointwise/BatchNorm/FusedBatchNorm:0' shape=(?, ?, ?, 48) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_1_pointwise/Relu:0' shape=(?, ?, ?, 48) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_2_depthwise/depthwise_weights:0' shape=(3, 3, 48, 1) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_2_pointwise/weights:0' shape=(1, 1, 48, 96) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_2_depthwise/depthwise:0' shape=(?, ?, ?, 48) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_2_pointwise/Conv2D:0' shape=(?, ?, ?, 96) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_2_pointwise/Conv2D_bn_offset:0' shape=(96,) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_2_pointwise/BatchNorm/FusedBatchNorm:0' shape=(?, ?, ?, 96) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_2_pointwise/Relu:0' shape=(?, ?, ?, 96) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_3_depthwise/depthwise_weights:0' shape=(3, 3, 96, 1) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_3_pointwise/weights:0' shape=(1, 1, 96, 96) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_3_depthwise/depthwise:0' shape=(?, ?, ?, 96) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_3_pointwise/Conv2D:0' shape=(?, ?, ?, 96) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_3_pointwise/Conv2D_bn_offset:0' shape=(96,) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_3_pointwise/BatchNorm/FusedBatchNorm:0' shape=(?, ?, ?, 96) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_3_pointwise/Relu:0' shape=(?, ?, ?, 96) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_4_depthwise/depthwise_weights:0' shape=(3, 3, 96, 1) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_4_pointwise/weights:0' shape=(1, 1, 96, 192) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_4_depthwise/depthwise:0' shape=(?, ?, ?, 96) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_4_pointwise/Conv2D:0' shape=(?, ?, ?, 192) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_4_pointwise/Conv2D_bn_offset:0' shape=(192,) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_4_pointwise/BatchNorm/FusedBatchNorm:0' shape=(?, ?, ?, 192) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_4_pointwise/Relu:0' shape=(?, ?, ?, 192) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_5_depthwise/depthwise_weights:0' shape=(3, 3, 192, 1) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_5_pointwise/weights:0' shape=(1, 1, 192, 192) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_5_depthwise/depthwise:0' shape=(?, ?, ?, 192) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_5_pointwise/Conv2D:0' shape=(?, ?, ?, 192) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_5_pointwise/Conv2D_bn_offset:0' shape=(192,) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_5_pointwise/BatchNorm/FusedBatchNorm:0' shape=(?, ?, ?, 192) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_5_pointwise/Relu:0' shape=(?, ?, ?, 192) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_6_depthwise/depthwise_weights:0' shape=(3, 3, 192, 1) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_6_pointwise/weights:0' shape=(1, 1, 192, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_6_depthwise/depthwise:0' shape=(?, ?, ?, 192) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_6_pointwise/Conv2D:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_6_pointwise/Conv2D_bn_offset:0' shape=(384,) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_6_pointwise/BatchNorm/FusedBatchNorm:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_6_pointwise/Relu:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_7_depthwise/depthwise_weights:0' shape=(3, 3, 384, 1) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_7_pointwise/weights:0' shape=(1, 1, 384, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_7_depthwise/depthwise:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_7_pointwise/Conv2D:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_7_pointwise/Conv2D_bn_offset:0' shape=(384,) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_7_pointwise/BatchNorm/FusedBatchNorm:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_7_pointwise/Relu:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_8_depthwise/depthwise_weights:0' shape=(3, 3, 384, 1) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_8_pointwise/weights:0' shape=(1, 1, 384, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_8_depthwise/depthwise:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_8_pointwise/Conv2D:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_8_pointwise/Conv2D_bn_offset:0' shape=(384,) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_8_pointwise/BatchNorm/FusedBatchNorm:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_8_pointwise/Relu:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_9_depthwise/depthwise_weights:0' shape=(3, 3, 384, 1) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_9_pointwise/weights:0' shape=(1, 1, 384, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_9_depthwise/depthwise:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_9_pointwise/Conv2D:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_9_pointwise/Conv2D_bn_offset:0' shape=(384,) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_9_pointwise/BatchNorm/FusedBatchNorm:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_9_pointwise/Relu:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_10_depthwise/depthwise_weights:0' shape=(3, 3, 384, 1) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_10_pointwise/weights:0' shape=(1, 1, 384, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_10_depthwise/depthwise:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_10_pointwise/Conv2D:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_10_pointwise/Conv2D_bn_offset:0' shape=(384,) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_10_pointwise/BatchNorm/FusedBatchNorm:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_10_pointwise/Relu:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_11_depthwise/depthwise_weights:0' shape=(3, 3, 384, 1) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_11_pointwise/weights:0' shape=(1, 1, 384, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_11_depthwise/depthwise:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_11_pointwise/Conv2D:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_11_pointwise/Conv2D_bn_offset:0' shape=(384,) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_11_pointwise/BatchNorm/FusedBatchNorm:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_11_pointwise/Relu:0' shape=(?, ?, ?, 384) dtype=float32>,)

TfPoseEstimator/image:0TfPoseEstimator/MobilenetV1/Conv2d_0/Conv2D:0之外的其他层仍具有?形状。
我是Tensorflow的新手,所以这可能是个愚蠢的问题,但如何更改已保存模型的输入形状呢?

t8e9dugd

t8e9dugd1#

我设法用这个解决了这个问题。

import tensorflow as tf
if __name__ == '__main__':
    graph_path = 't/tf_model.pb'
    with tf.gfile.GFile(graph_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    graph = tf.get_default_graph()
    tf_new_image = tf.placeholder(shape=(1, 368, 368, 3), dtype='float32', name='new_image')
    tf.import_graph_def(graph_def, name='TfPoseEstimator', input_map={"image:0": tf_new_image})
    tf.train.write_graph(graph, "t", "mobilenet_thin_model.pb", as_text=False)
uxhixvfz

uxhixvfz2#

使用tf2.x,我相信您可以将其更改为具体函数:

imported = tf.saved_model.load('/path/to/saved_model')
concrete_func = imported.signatures["serving_default"]
concrete_func.inputs[0].set_shape([1, 368, 368, 3])

相关问题