有没有方法访问tensorflow_hub.KerasLayer对象中的层?

qhhrdooz  于 2023-06-23  发布在  其他
关注(0)|答案(4)|浏览(126)

我正在尝试使用来自tensorflow hub的预训练模型到我的对象检测模型中。我按照官方说明将hub中的模型 Package 为KerasLayer对象。然后我意识到我无法访问这个预训练模型中的层。但是我需要使用一些特定层的输出来构建我的模型。有没有方法访问tensorflow_hub.KerasLayer对象中的层?

dgsult0t

dgsult0t1#

有一种未记录的方法可以从TF-Slim导出的某些TF 2 SavedModels中获取中间层,例如https://tfhub.dev/google/imagenet/inception_v1/feature_vector/4:将return_endpoints=True传递给SavedModel的__call__函数会将输出更改为dict
注:此接口可能会更改或删除,并且存在已知问题。

model = tfhub.KerasLayer('https://tfhub.dev/google/imagenet/inception_v1/feature_vector/4', trainable=False, arguments=dict(return_endpoints=True))
input = tf.keras.layers.Input((224, 224, 3))
outputs = model(input)
for k, v in sorted(outputs.items()):
  print(k, v.shape)

此示例的输出:

InceptionV1/Conv2d_1a_7x7 (None, 112, 112, 64)
InceptionV1/Conv2d_2b_1x1 (None, 56, 56, 64)
InceptionV1/Conv2d_2c_3x3 (None, 56, 56, 192)
InceptionV1/MaxPool_2a_3x3 (None, 56, 56, 64)
InceptionV1/MaxPool_3a_3x3 (None, 28, 28, 192)
InceptionV1/MaxPool_4a_3x3 (None, 14, 14, 480)
InceptionV1/MaxPool_5a_2x2 (None, 7, 7, 832)
InceptionV1/Mixed_3b (None, 28, 28, 256)
InceptionV1/Mixed_3c (None, 28, 28, 480)
InceptionV1/Mixed_4b (None, 14, 14, 512)
InceptionV1/Mixed_4c (None, 14, 14, 512)
InceptionV1/Mixed_4d (None, 14, 14, 512)
InceptionV1/Mixed_4e (None, 14, 14, 528)
InceptionV1/Mixed_4f (None, 14, 14, 832)
InceptionV1/Mixed_5b (None, 7, 7, 832)
InceptionV1/Mixed_5c (None, 7, 7, 1024)
InceptionV1/global_pool (None, 1, 1, 1024)
default (None, 1024)

需要注意的问题:

  • 未记录,可能会更改或删除,无法始终提供。
  • __call__计算所有输出(并在训练期间应用所有更新操作),而不管稍后使用的输出。

来源:https://github.com/tensorflow/hub/issues/453

kuuvgm7e

kuuvgm7e2#

因为return_endpoints=True似乎不再工作了。
你可以这样做:

efficientnet_lite0_base_layer = hub.KerasLayer(
    "https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2",
    output_shape=[1280],
    trainable=False
)

print("Thickness of the model:", len(efficientnet_lite0_base_layer.weights))
print ("{:<80} {:<20} {:<10}".format('Layer','Shape','Type'))

for i in range(len(efficientnet_lite0_base_layer.weights)):
    model_weights_raw_string = str(efficientnet_lite0_base_layer.weights[i])
    model_weights_wo_weights = model_weights_raw_string.split(", numpy", 1)[0]
    dtype = model_weights_wo_weights.split(" dtype=")[1]
    shape = model_weights_wo_weights.split(" shape=")[1].split(" dtype=")[0]
    
    print ("{:<80} {:<20} {:<10}".format(efficientnet_lite0_base_layer.weights[i].name, shape, dtype))
wnvonmuf

wnvonmuf3#

为了能够轻松地做到这一点,预训练模型的创建者需要使输出准备好被访问。例如,通过使用额外的函数或额外的签名来输出您想要使用的激活。

7ajki6be

7ajki6be4#

这并不能给予您以编程方式访问这些层,但它确实允许您检查它们。

import tensorflow as tf
import tensorflow_hub as hub

resnet_v2 = hub.load(os.path.join(tfhub_dir, 'imagenet_resnet_v2_50_classification_5'))

print(tf.__version__)
resnet_v2.summary()
single_keras_layer = resnet_v2.layers[0]
variables = single_keras_layer.variables

for i, v in enumerate(variables):
    print('[{:03d}] {} [{}]'.format(i, v.name, v.shape))

输出量

2.13.0
Model: "sequential_6"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 keras_layer_6 (KerasLayer)  (None, 1001)              25615849  
                                                                 
=================================================================
Total params: 25615849 (97.72 MB)
Trainable params: 0 (0.00 Byte)
Non-trainable params: 25615849 (97.72 MB)
_________________________________________________________________
[000] resnet_v2_50/block2/unit_1/bottleneck_v2/shortcut/biases:0 [(512,)]
[001] resnet_v2_50/block2/unit_4/bottleneck_v2/conv1/BatchNorm/gamma:0 [(128,)]
[002] resnet_v2_50/block3/unit_1/bottleneck_v2/conv2/weights:0 [(3, 3, 256, 256)]
[003] resnet_v2_50/block4/unit_1/bottleneck_v2/conv3/biases:0 [(2048,)]
[004] resnet_v2_50/block1/unit_1/bottleneck_v2/shortcut/biases:0 [(256,)]
[005] resnet_v2_50/block3/unit_2/bottleneck_v2/preact/gamma:0 [(1024,)]
[006] resnet_v2_50/block3/unit_3/bottleneck_v2/conv1/BatchNorm/gamma:0 [(256,)]
[007] resnet_v2_50/block4/unit_3/bottleneck_v2/conv1/BatchNorm/gamma:0 [(512,)]
[008] resnet_v2_50/block1/unit_1/bottleneck_v2/preact/gamma:0 [(64,)]
[009] resnet_v2_50/block1/unit_2/bottleneck_v2/conv3/weights:0 [(1, 1, 64, 256)]
[010] resnet_v2_50/block2/unit_1/bottleneck_v2/preact/gamma:0 [(256,)]
[011] resnet_v2_50/block2/unit_1/bottleneck_v2/conv2/BatchNorm/gamma:0 [(128,)]
[012] resnet_v2_50/block2/unit_3/bottleneck_v2/conv3/biases:0 [(512,)]
...
[268] resnet_v2_50/block4/unit_1/bottleneck_v2/preact/moving_variance:0 [(1024,)]
[269] resnet_v2_50/block4/unit_1/bottleneck_v2/conv2/BatchNorm/moving_variance:0 [(512,)]
[270] resnet_v2_50/block2/unit_2/bottleneck_v2/conv1/BatchNorm/moving_variance:0 [(128,)]
[271] resnet_v2_50/block1/unit_3/bottleneck_v2/preact/moving_mean:0 [(256,)]

相关问题