Keras model gives different results after loading pretrained weights

Posted by myss37ts on 2022-11-24 in Other

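After successfully training a model and saving its weights with a checkpoint, I reload the weights with load_weights and run evaluate, but the results look as if the network were still using its original weights. To rule out a problem with the test set, I also ran evaluate on the training and validation sets, and the same thing happened.
Here is the training code: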

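import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras.applications import InceptionV3
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator

def create_inception(model_name, fold_path, model_path, optimizer=None):
  # Build the optimizer per call: an Adam(...) default in the signature would be
  # created once and shared by every model this function creates.
  if optimizer is None:
    optimizer = Adam(learning_rate=0.0001)

  inputs = tf.keras.Input(shape=(224, 224, 3))
  head_model = InceptionV3(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

  # Fine-tune the whole pretrained base.
  head_model.trainable = True

  # NOTE: training=True is fixed in the graph, so InceptionV3's BatchNormalization
  # layers use per-batch statistics in evaluate()/predict() as well as in fit().
  head_model = head_model(inputs, training=True)
  head_model = tf.keras.layers.Flatten()(head_model)
  head_model = Dense(256, activation='relu')(head_model)

  output = Dense(3, activation='softmax')(head_model)
  model4 = Model(inputs=inputs, outputs=output)

  # Training images are rescaled and augmented.
  train_datagen = ImageDataGenerator(
      rescale=1./255,
      rotation_range=40,
      width_shift_range=0.2,
      height_shift_range=0.2,
      shear_range=0.2,
      zoom_range=0.2,
      horizontal_flip=True,
      fill_mode='nearest'
  )

  # Note that the validation data should not be augmented!
  validation_datagen = ImageDataGenerator(rescale=1./255)

  train_generator = train_datagen.flow_from_directory(fold_path + '/Train',
                                                      batch_size=8,
                                                      class_mode='categorical',
                                                      target_size=(224, 224))

  validation_generator = validation_datagen.flow_from_directory(fold_path + '/Valid',
                                                                batch_size=8,
                                                                class_mode='categorical',
                                                                target_size=(224, 224))

  # Compile the model (training happens in train_inception_model).
  model4.compile(loss="categorical_crossentropy",
                 optimizer=optimizer,
                 metrics=[tfa.metrics.F1Score(num_classes=3, average='micro'), 'accuracy'])

  return model4, train_generator, validation_generator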

def train_inception_model(model, train_generator, validation_generator, model_path, model_name, epochs=100):
  batch_size = 8
  steps_per_epoch = train_generator.n // batch_size
  validation_steps = validation_generator.n // batch_size

  # Early-stopping monitor: stop once the training loss stops improving.
  early_stop = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=10, min_delta=0.001)
  # Checkpoint callback: save only the weights with the best val_f1_score.
  filepath = model_path + model_name + "_best.hdf5"
  checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath, monitor='val_f1_score', verbose=1, save_best_only=True, save_weights_only=True, mode='max')

  model4_history = model.fit(
      train_generator,
      steps_per_epoch = steps_per_epoch,
      epochs = epochs,
      callbacks = [early_stop, checkpoint], 
      validation_data = validation_generator,
      validation_steps = validation_steps
  )
  return model, model4_history
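One detail in create_inception deserves attention: head_model(inputs, training = True) fixes the InceptionV3 base in training mode. For BatchNormalization layers, training mode means each sample is normalized with the mean and variance of its current batch rather than with the moving averages accumulated during training. The pure-Python sketch below illustrates that general behaviour for a single feature; the epsilon, the batches and the moving statistics are hypothetical values, and no Keras code is involved.

```python
import statistics

EPS = 1e-3  # hypothetical epsilon, playing the role of BatchNormalization's epsilon

def batch_norm_training(batch):
    """Normalize a 1-feature batch with its own mean/variance (training mode)."""
    mean = statistics.fmean(batch)
    var = statistics.pvariance(batch)  # biased (population) variance of the batch
    return [(x - mean) / (var + EPS) ** 0.5 for x in batch]

def batch_norm_inference(batch, moving_mean, moving_var):
    """Normalize with fixed, previously stored statistics (inference mode)."""
    return [(x - moving_mean) / (moving_var + EPS) ** 0.5 for x in batch]

sample = 2.0
batch_a = [sample, -1.0, 0.5, 3.0]  # hypothetical mixed batch
batch_b = [sample, 2.1, 1.9, 2.2]   # hypothetical batch of very similar values

# Training mode: the same sample's output depends on its batch-mates.
out_a = batch_norm_training(batch_a)[0]
out_b = batch_norm_training(batch_b)[0]

# Inference mode: the output depends only on the sample and the stored statistics.
inf_a = batch_norm_inference(batch_a, moving_mean=1.0, moving_var=2.0)[0]
inf_b = batch_norm_inference(batch_b, moving_mean=1.0, moving_var=2.0)[0]

print(out_a, out_b, inf_a, inf_b)
```

The value 2.0 comes out positive in the mixed batch and negative in the batch of similar values, while inference mode gives the same result for both.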

These are the training results:

Epoch 40: val_f1_score did not improve from 0.94413
99/99 [==============================] - 54s 540ms/step - loss: 0.0601 - f1_score: 0.9750 - accuracy: 0.9750 - val_loss: 0.1700 - val_f1_score: 0.9347 - val_accuracy: 0.9347

Evaluation code:

name = 'kfold_model_' + str(0)
print(name)
# Rebuild the same architecture; opt, lr and model_path4 are defined elsewhere in the script.
model4, _, _ = create_inception(name, '/content/Fold10', model_path4, opt['type'](learning_rate=lr))
# The optimizer is not used by evaluate(); compiling attaches the loss and metrics.
model4.compile(loss="categorical_crossentropy",
               optimizer=Adam(learning_rate=0.01),
               metrics=[tfa.metrics.F1Score(num_classes=3, average='micro'), 'accuracy'])
# Load the best checkpoint saved by ModelCheckpoint during training.
model4.load_weights(model_path4 + "{}_best.hdf5".format(name))

test_datagen = ImageDataGenerator(rescale=1./255)
# shuffle=False yields the images in directory order, i.e. grouped by class.
test_generator = test_datagen.flow_from_directory('Fold{}/Test'.format(10),
                                                  batch_size=32,
                                                  class_mode='categorical',
                                                  target_size=(224, 224),
                                                  shuffle=False)

test_loss, test_f1, test_acc = model4.evaluate(test_generator)
print("Test f1:", test_f1)
print("Test Accuracy:", test_acc)
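Why shuffle=False matters here: flow_from_directory collects files class by class, and with shuffle=False it yields them in that order, so most 32-image batches of the test set contain a single class. The sketch below simulates only the batching (no images, no Keras). The per-class counts and class names are hypothetical, since the post reports only 530 images in 3 classes.

```python
import random

BATCH_SIZE = 32  # same batch_size as test_generator in the question
# Hypothetical per-class counts; the post only says 530 images in 3 classes.
CLASS_COUNTS = {"class_a": 180, "class_b": 170, "class_c": 180}

def labels_in_directory_order(counts):
    """All labels grouped by class, mimicking flow_from_directory(..., shuffle=False)."""
    return [name for name in sorted(counts) for _ in range(counts[name])]

def batches(labels, batch_size):
    """Split the label sequence into consecutive batches (the last one may be short)."""
    return [labels[i:i + batch_size] for i in range(0, len(labels), batch_size)]

def single_class_batches(batch_list):
    """Count batches in which every sample has the same label."""
    return sum(1 for b in batch_list if len(set(b)) == 1)

ordered = labels_in_directory_order(CLASS_COUNTS)
shuffled = ordered[:]
random.Random(0).shuffle(shuffled)  # fixed seed so the result is reproducible

ordered_batches = batches(ordered, BATCH_SIZE)
shuffled_batches = batches(shuffled, BATCH_SIZE)

print(len(ordered_batches),
      single_class_batches(ordered_batches),
      single_class_batches(shuffled_batches))
```

With these counts, 15 of the 17 ordered batches hold a single class, while the shuffled batches mix all three, which is much closer to the class mix of the shuffled training batches.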

Evaluation output:

Found 530 images belonging to 3 classes.
17/17 [==============================] - 45s 2s/step - loss: 5.2010 - f1_score: 0.5491 - accuracy: 0.5491
Test f1: 0.5490565896034241
Test Accuracy: 0.5490565896034241

This code works perfectly with the VGG16 network, but with InceptionV3 and resnet121 I run into this same problem.
Any suggestions?

Answer #1 (hkmswyz6):

If I remove "shuffle=False" from the test data generator, it works perfectly.
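This workaround fits a batch-normalization explanation. InceptionV3 contains BatchNormalization layers (VGG16 does not), and head_model(inputs, training = True) keeps them in training mode during evaluate(), so each test batch is normalized with its own statistics. With shuffle=False those batches are nearly single-class, unlike the mixed batches seen in training; shuffling only makes the batch statistics resemble the training ones again. Under the same reading, a more direct fix would be to build the model with head_model(inputs), without the hard-coded argument, so Keras uses training mode in fit() and the moving statistics in evaluate()/predict(); whether the existing checkpoint loads unchanged into that rebuilt model is an assumption worth checking. The pure-Python sketch below uses hypothetical single-feature data, with global statistics standing in for the moving averages, to compare training-mode outputs on class-sorted and shuffled batches against inference-mode outputs.

```python
import random
import statistics

EPS = 1e-3        # hypothetical BatchNormalization epsilon
BATCH_SIZE = 32   # same batch size as the test generator in the question

# Hypothetical 1-feature activations: 96 samples per class, stored class by class
# (as with shuffle=False). Class k sits around 5*k with offsets -1, 0, +1.
samples = [(idx, 5.0 * (idx // 96) + (idx % 3 - 1)) for idx in range(3 * 96)]

def normalize(values, mean, var):
    """Batch-norm normalization without the learned scale and offset."""
    return [(v - mean) / (var + EPS) ** 0.5 for v in values]

def training_mode_outputs(ordered_samples):
    """Normalize each batch with its own statistics, as BN does with training=True."""
    outputs = {}
    for start in range(0, len(ordered_samples), BATCH_SIZE):
        batch = ordered_samples[start:start + BATCH_SIZE]
        values = [v for _, v in batch]
        normed = normalize(values, statistics.fmean(values), statistics.pvariance(values))
        for (idx, _), y in zip(batch, normed):
            outputs[idx] = y
    return outputs

# Global statistics as a stand-in for the moving averages used in inference mode.
all_values = [v for _, v in samples]
global_mean = statistics.fmean(all_values)
global_var = statistics.pvariance(all_values)
inference = {idx: y for (idx, _), y in
             zip(samples, normalize(all_values, global_mean, global_var))}

def max_deviation(outputs):
    """Largest gap between training-mode and inference-mode outputs."""
    return max(abs(outputs[idx] - inference[idx]) for idx in inference)

shuffled = samples[:]
random.Random(0).shuffle(shuffled)  # fixed seed for reproducibility

sorted_dev = max_deviation(training_mode_outputs(samples))
shuffled_dev = max_deviation(training_mode_outputs(shuffled))
print(f"sorted batches: {sorted_dev:.2f}, shuffled batches: {shuffled_dev:.2f}")
```

Class-sorted batches push the training-mode outputs far from the inference-mode ones, while shuffled batches stay much closer; inference mode itself does not depend on batch order at all.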
