Fix multistep optimizer (feeding from wrong config params)

This commit is contained in:
James Betker 2020-08-04 16:42:58 -06:00
parent 4bfbdaf94f
commit ec2a795d53
2 changed files with 17 additions and 8 deletions

View File

@ -203,6 +203,7 @@ class SRGANModel(BaseModel):
weight_decay=wd_D,
betas=(train_opt['beta1_D'], train_opt['beta2_D']))
self.optimizers.append(self.optimizer_D)
self.disc_optimizers.append(self.optimizer_D)
if self.spsr_enabled:
# D_grad optimizer
@ -219,6 +220,7 @@ class SRGANModel(BaseModel):
weight_decay=wd_D,
betas=(train_opt['beta1_D'], train_opt['beta2_D']))
self.optimizers.append(self.optimizer_D_grad)
self.disc_optimizers.append(self.optimizer_D_grad)
if self.spsr_enabled:
self.get_grad = ImageGradient().to(self.device)
@ -253,9 +255,18 @@ class SRGANModel(BaseModel):
# schedulers
if train_opt['lr_scheme'] == 'MultiStepLR':
for optimizer in self.optimizers:
# This is a recent change. assert to make sure any legacy configs dont find their way here.
assert 'gen_lr_steps' in train_opt.keys() and 'disc_lr_steps' in train_opt.keys()
self.schedulers.append(
lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
lr_scheduler.MultiStepLR_Restart(self.optimizer_G, train_opt['gen_lr_steps'],
restarts=train_opt['restarts'],
weights=train_opt['restart_weights'],
gamma=train_opt['lr_gamma'],
clear_state=train_opt['clear_state'],
force_lr=train_opt['force_lr']))
for o in self.disc_optimizers:
self.schedulers.append(
lr_scheduler.MultiStepLR_Restart(o, train_opt['disc_lr_steps'],
restarts=train_opt['restarts'],
weights=train_opt['restart_weights'],
gamma=train_opt['lr_gamma'],
@ -267,11 +278,8 @@ class SRGANModel(BaseModel):
self.schedulers.append(lr_scheduler.ProgressiveMultiStepLR(self.optimizer_G, train_opt['gen_lr_steps'],
self.netG.module.get_progressive_starts(),
train_opt['lr_gamma']))
self.schedulers.append(lr_scheduler.ProgressiveMultiStepLR(self.optimizer_D, train_opt['disc_lr_steps'],
[0],
train_opt['lr_gamma']))
if self.spsr_enabled:
self.schedulers.append(lr_scheduler.ProgressiveMultiStepLR(self.optimizer_D_grad, train_opt['disc_lr_steps'],
for o in self.disc_optimizers:
self.schedulers.append(lr_scheduler.ProgressiveMultiStepLR(o, train_opt['disc_lr_steps'],
[0],
train_opt['lr_gamma']))
elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
@ -295,7 +303,7 @@ class SRGANModel(BaseModel):
# GAN LQ image params
self.gan_lq_img_use_prob = train_opt['gan_lowres_use_probability'] if train_opt['gan_lowres_use_probability'] else 0
self.img_debug_steps = train_opt['img_debug_steps'] if train_opt['img_debug_steps'] else 50
self.img_debug_steps = opt['logger']['img_debug_steps'] if 'img_debug_steps' in opt['logger'].keys() else 50
self.print_network() # print network
self.load() # load G and D if needed

View File

@ -15,6 +15,7 @@ class BaseModel():
self.is_train = opt['is_train']
self.schedulers = []
self.optimizers = []
self.disc_optimizers = []
def feed_data(self, data):
pass