我正在尝试使用this github repo,我可以用我自己的数据集运行训练,但我不能找到如何加载和保存权重?
这是我的代码:
from ISR.models import RRDN
from ISR.models import Discriminator
from ISR.models import Cut_VGG19
from keras.callbacks import ModelCheckpoint
lr_train_patch_size = 40
layers_to_extract = [5, 9]
scale = 2
hr_train_patch_size = lr_train_patch_size * scale
rrdn = RRDN(arch_params={'C':4, 'D':3, 'G':64, 'G0':64, 'T':10, 'x':scale}, patch_size=lr_train_patch_size)
f_ext = Cut_VGG19(patch_size=hr_train_patch_size, layers_to_extract=layers_to_extract)
discr = Discriminator(patch_size=hr_train_patch_size, kernel_size=3)
from ISR.train import Trainer
loss_weights = {
'generator': 0.0,
'feature_extractor': 0.0833,
'discriminator': 0.01
}
losses = {
'generator': 'mae',
'feature_extractor': 'mse',
'discriminator': 'binary_crossentropy'
}
log_dirs = {'logs': './logs', 'weights': './weights'}
learning_rate = {'initial_value': 0.0004, 'decay_factor': 0.5, 'decay_frequency': 30}
flatness = {'min': 0.0, 'max': 0.15, 'increase': 0.01, 'increase_frequency': 5}
trainer = Trainer(
generator=rrdn,
discriminator=discr,
feature_extractor=f_ext,
lr_train_dir='lrtrain',
hr_train_dir='hrtrain',
lr_valid_dir='lrval',
hr_valid_dir='hrval',
loss_weights=loss_weights,
learning_rate=learning_rate,
flatness=flatness,
dataname='image_dataset',
log_dirs=log_dirs,
weights_generator=None,
weights_discriminator=None,
n_validation=40,
)
trainer.train(
epochs=80,
steps_per_epoch=100,
batch_size=16,
monitored_metrics={'val_PSNR_Y': 'max'},
)
1条答案
按热度按时间fslejnso1#
我可能不知道答案,但看了这一行之后:
阅读了一些文档之后,我想:
我不太确定如何保存更新的权重,但可以尝试:
我个人的观点是,你应该找到一种方法来保存整个模型,而不是分别保存生成器和鉴别器的权重。我希望我能帮上一点忙。