keras:将图层添加到其他模型

yruzcnhs  于 2022-11-24  发布在  其他
关注(0)|答案(2)|浏览(177)

我需要在一个已有的模型中添加层。但是,我需要在“主模型级别”添加层,也就是说,我不能使用经典的函数方法。例如,如果我使用类似如下的方法:

from keras.layers import Dense,Reshape, Input
inp = Input(shape=(15,))
d1 = Dense(224*224*3, activation='linear')(inp)
r1 = Reshape(input_shape)
from keras import Model
model_mod = r1(d1)
model_mod = mobilenet(model_mod)
model_mod = Model(inp, model_mod)

本人获得:

Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         (None, 15)                0         
_________________________________________________________________
dense_4 (Dense)              (None, 150528)            2408448   
_________________________________________________________________
reshape_4 (Reshape)          (None, 224, 224, 3)       0         
_________________________________________________________________
mobilenet_1.00_224 (Model)   (None, 1000)              4253864

因此,我得到了一个带有嵌套子模型的模型。相反,我将嵌套子模型的层(mobilenet)“添加”到新的顶层(即,在reforme_4之后)。我尝试了:

modelB_input = modelB.input
for layer in modelB.layers:
    if layer == modelB_input:
        continue
    modelA.add(layer)

它适用于简单的顺序模型(例如,vgg,mobilenet),但对于连接不是严格顺序的更复杂模型(例如,inception,resnet),这段代码就不好了。有什么想法吗?

7d7tgy0s

7d7tgy0s1#

您可以使用keras.layers.Concatenate来合并两个模型,如下所示:

first = Sequential()
first.add(Dense(1, input_shape=(2,), activation='sigmoid'))

second = Sequential()
second.add(Dense(1, input_shape=(1,), activation='sigmoid'))
 
merged = Concatenate([first, second])

(摘自:(第10页)
虽然此示例使用keras.models.Sequential,但它也适用于其他模型或层。
您还可以查看:https://keras.io/api/layers/merging_layers/concatenate/

x9ybnkn6

x9ybnkn62#

如果你想在已有模型的B层中添加一个A层,你可以把B层输出到A层,然后用tf.keras.model.Model解析到一个新的模型中。这个方法的全面演示在用于目标检测或分割的特征提取器中。你可以在这里找到一个
例如,通过在VGG 16模型底部添加2个新层

full_vgg_model = tf.keras.applications.VGG16(
                            include_top=False,
                            weights="imagenet",
                            input_tensor=None,
                            input_shape=None,
                            pooling=None,
                            classes=1000,
                        )

当前图层:

Model: "vgg16"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, None, None, 3)]   0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, None, None, 64)    1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, None, None, 64)    36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, None, None, 64)    0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, None, None, 128)   73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, None, None, 128)   147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, None, None, 128)   0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, None, None, 256)   295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, None, None, 256)   590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, None, None, 256)   590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, None, None, 256)   0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, None, None, 512)   1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, None, None, 512)   0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, None, None, 512)   0         
=================================================================
Total params: 14,714,688
Trainable params: 14,714,688
Non-trainable params: 0

然后我附加两个新图层:

conv6 = tf.keras.layers.Conv2D(1024, 3, strides=(1, 1), padding='same', activation='relu', dilation_rate=(6,6), name='conv6')(full_vgg_model.layers[-1].output)

conv7 = tf.keras.layers.Conv2D(1024, 1, strides=(1, 1), padding='same', activation='relu', name='conv7')(conv6)
    
classification_backbone = tf.keras.Model(
            inputs=full_vgg_model.inputs,
            outputs=[conv6,conv7])

我们把它们堆在最下面了!

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, None, None, 3)]   0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, None, None, 64)    1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, None, None, 64)    36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, None, None, 64)    0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, None, None, 128)   73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, None, None, 128)   147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, None, None, 128)   0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, None, None, 256)   295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, None, None, 256)   590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, None, None, 256)   590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, None, None, 256)   0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, None, None, 512)   1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, None, None, 512)   0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, None, None, 512)   0         
_________________________________________________________________
conv6 (Conv2D)               (None, None, None, 1024)  4719616   
_________________________________________________________________
conv7 (Conv2D)               (None, None, None, 1024)  1049600   
=================================================================
Total params: 20,483,904
Trainable params: 20,483,904
Non-trainable params: 0

相关问题