TensorFlow: How do I load and save models with the Image Super-Resolution (ISR) repository?

b4lqfgs4 posted on 2023-05-18 in Other

I'm trying to use this GitHub repo. I can run training on my own dataset, but I can't figure out how to load and save the weights.
Here is my code:

from ISR.models import RRDN
from ISR.models import Discriminator
from ISR.models import Cut_VGG19
from keras.callbacks import ModelCheckpoint

lr_train_patch_size = 40
layers_to_extract = [5, 9]
scale = 2
hr_train_patch_size = lr_train_patch_size * scale

rrdn  = RRDN(arch_params={'C':4, 'D':3, 'G':64, 'G0':64, 'T':10, 'x':scale}, patch_size=lr_train_patch_size)

f_ext = Cut_VGG19(patch_size=hr_train_patch_size, layers_to_extract=layers_to_extract)
discr = Discriminator(patch_size=hr_train_patch_size, kernel_size=3)


from ISR.train import Trainer
loss_weights = {
  'generator': 0.0,
  'feature_extractor': 0.0833,
  'discriminator': 0.01
}
losses = {
  'generator': 'mae',
  'feature_extractor': 'mse',
  'discriminator': 'binary_crossentropy'
}

log_dirs = {'logs': './logs', 'weights': './weights'}

learning_rate = {'initial_value': 0.0004, 'decay_factor': 0.5, 'decay_frequency': 30}

flatness = {'min': 0.0, 'max': 0.15, 'increase': 0.01, 'increase_frequency': 5}

trainer = Trainer(
    generator=rrdn,
    discriminator=discr,
    feature_extractor=f_ext,
    lr_train_dir='lrtrain',
    hr_train_dir='hrtrain',
    lr_valid_dir='lrval',
    hr_valid_dir='hrval',
    loss_weights=loss_weights,
    learning_rate=learning_rate,
    flatness=flatness,
    dataname='image_dataset',
    log_dirs=log_dirs,
    weights_generator=None,
    weights_discriminator=None,
    n_validation=40,
)


trainer.train(
    epochs=80,
    steps_per_epoch=100,
    batch_size=16,
    monitored_metrics={'val_PSNR_Y': 'max'},
)
Answer #1 by fslejnso:

I may not know the answer, but after looking at this part of your code:

trainer = Trainer(
    generator=rrdn,
    discriminator=discr,
    feature_extractor=f_ext,
    lr_train_dir='lrtrain',
    hr_train_dir='hrtrain',
    lr_valid_dir='lrval',
    hr_valid_dir='hrval',
    loss_weights=loss_weights,
    learning_rate=learning_rate,
    flatness=flatness,
    dataname='image_dataset',
    log_dirs=log_dirs,
    weights_generator=None,
    weights_discriminator=None,
    n_validation=40,
)

and after reading some of the documentation, I think:

1- weights_generator and weights_discriminator are set to None in the trainer
2- you call trainer.train()
2.5- the generator's and discriminator's weights are updated during training
3- TODO: save the updated generator and discriminator weights
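As far as I can tell from the ISR documentation, step 3 may already happen inside train(): the trainer writes .hdf5 weight files into subfolders of log_dirs['weights'] (here ./weights), saving when the metric named in monitored_metrics improves. If that is right, a small standard-library helper can locate the newest file. The helper name `find_latest_weights`, the recursive folder layout and the `.hdf5` extension are assumptions for illustration. If nothing is found, it is worth checking the training log that the monitored metric is one the trainer actually reports; I believe the name can gain a `generator_` prefix when a discriminator is attached, but treat that as an assumption to verify.

```python
from pathlib import Path
from typing import Optional


def find_latest_weights(weights_dir: str = "./weights",
                        pattern: str = "*.hdf5") -> Optional[Path]:
    """Return the most recently modified weights file under weights_dir.

    Searches recursively, since the trainer is assumed to create nested
    per-model / per-run subfolders. Returns None if nothing matches.
    """
    root = Path(weights_dir)
    if not root.is_dir():
        return None
    candidates = [p for p in root.rglob(pattern) if p.is_file()]
    if not candidates:
        return None
    # Pick the newest file by modification time
    return max(candidates, key=lambda p: p.stat().st_mtime)
```

With a path in hand, the standard Keras call rrdn.model.load_weights(str(path)) should restore the generator, assuming the ISR RRDN object exposes its underlying Keras model as .model; passing the same path as weights_generator= to Trainer should resume training from it.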

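I'm not quite sure how to save the updated weights, but you could try something like this:

import pickle

# The weights_generator / weights_discriminator arguments of Trainer are paths to
# pre-trained weights to load (None here), not the trained weights, so read them
# from the Keras models instead (assumes the ISR objects expose them as .model).
gener_weights = rrdn.model.get_weights()     # list of NumPy arrays
discrim_weights = discr.model.get_weights()  # list of NumPy arrays

# Pickle them for future use
with open('gener_weights.pkl', 'wb') as f:
    pickle.dump(gener_weights, f)
with open('discrim_weights.pkl', 'wb') as f:
    pickle.dump(discrim_weights, f)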
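If the weights are pickled as the list returned by Keras's get_weights(), loading them back is the mirror image: rebuild the model with the same arch_params, then call set_weights() with that list. A minimal sketch, assuming the pickle holds exactly that list and that you pass the Keras model itself (for ISR, presumably rrdn.model); `load_pickled_weights` is a hypothetical helper name.

```python
import pickle
from typing import Any


def load_pickled_weights(keras_model: Any, path: str) -> None:
    """Restore weights saved with pickle.dump(model.get_weights(), f).

    keras_model only needs get_weights()/set_weights(), which every
    Keras model provides.
    """
    with open(path, "rb") as f:
        weights = pickle.load(f)
    expected = len(keras_model.get_weights())
    if len(weights) != expected:
        # A mismatch usually means the model was built with different arch_params
        raise ValueError(
            f"{path} holds {len(weights)} weight arrays, model expects {expected}"
        )
    keras_model.set_weights(weights)
```

Since pickle.load can execute arbitrary code, only load files you created yourself.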

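Personally, I think you should find a way to save the whole model rather than saving the generator and discriminator weights separately. I hope this helps a little.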
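One way to get close to saving the whole model with ISR is to store the architecture parameters next to the weights, because the generator can be rebuilt from arch_params and then have its weights loaded. A minimal sketch, assuming the ISR objects expose their Keras models as `.model` (which provide the standard `save_weights`); the helper names and the folder layout (`arch_params.json`, `generator.hdf5`, `discriminator.hdf5`) are hypothetical, and the `.hdf5` extension matches the older Keras versions ISR was built on.

```python
import json
from pathlib import Path
from typing import Any, Dict, Optional


def save_bundle(out_dir: str, arch_params: Dict[str, Any],
                generator: Any, discriminator: Optional[Any] = None) -> Path:
    """Save arch_params plus generator/discriminator weights in one folder.

    generator / discriminator are ISR-style wrappers holding a Keras .model.
    """
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    # Architecture: enough to call RRDN(arch_params=...) again later
    (out / "arch_params.json").write_text(json.dumps(arch_params, indent=2))
    # Weights: standard Keras save_weights on the wrapped models
    generator.model.save_weights(str(out / "generator.hdf5"))
    if discriminator is not None:
        discriminator.model.save_weights(str(out / "discriminator.hdf5"))
    return out


def load_arch_params(bundle_dir: str) -> Dict[str, Any]:
    """Read back the architecture parameters written by save_bundle."""
    return json.loads((Path(bundle_dir) / "arch_params.json").read_text())
```

After training, save_bundle('my_rrdn', {'C': 4, 'D': 3, 'G': 64, 'G0': 64, 'T': 10, 'x': 2}, rrdn, discr) would write everything to one folder; later, RRDN(arch_params=load_arch_params('my_rrdn')) followed by .model.load_weights('my_rrdn/generator.hdf5') should give back the trained generator.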
