forked from mrq/DL-Art-School
Fix multistep optimizer (feeding from wrong config params)
This commit is contained in:
parent 4bfbdaf94f
commit ec2a795d53
@@ -203,6 +203,7 @@ class SRGANModel(BaseModel):
                                              weight_decay=wd_D,
                                              betas=(train_opt['beta1_D'], train_opt['beta2_D']))
             self.optimizers.append(self.optimizer_D)
+            self.disc_optimizers.append(self.optimizer_D)
 
             if self.spsr_enabled:
                 # D_grad optimizer
@@ -219,6 +220,7 @@ class SRGANModel(BaseModel):
                                                  weight_decay=wd_D,
                                                  betas=(train_opt['beta1_D'], train_opt['beta2_D']))
                 self.optimizers.append(self.optimizer_D_grad)
+                self.disc_optimizers.append(self.optimizer_D_grad)
 
             if self.spsr_enabled:
                 self.get_grad = ImageGradient().to(self.device)
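Note on the appends above: the discriminator optimizers stay in the generic self.optimizers list (presumably used by the shared training plumbing) and are now also tracked in the new self.disc_optimizers list, so the scheduler setup below can address all discriminators as a group. A minimal sketch of that dual-registration pattern with a simplified placeholder class and toy network (both hypothetical, not from the repo):

from torch import nn, optim

class TinyModel:
    # simplified stand-in for SRGANModel's optimizer bookkeeping
    def __init__(self):
        self.optimizers = []         # every optimizer the model owns
        self.disc_optimizers = []    # discriminator optimizers only, for scheduler wiring
        self.optimizer_D = optim.Adam(nn.Linear(4, 1).parameters(), lr=1e-4)
        self.optimizers.append(self.optimizer_D)
        self.disc_optimizers.append(self.optimizer_D)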
@@ -253,9 +255,18 @@ class SRGANModel(BaseModel):
 
         # schedulers
         if train_opt['lr_scheme'] == 'MultiStepLR':
-            for optimizer in self.optimizers:
+            # This is a recent change. assert to make sure any legacy configs dont find their way here.
+            assert 'gen_lr_steps' in train_opt.keys() and 'disc_lr_steps' in train_opt.keys()
+            self.schedulers.append(
+                lr_scheduler.MultiStepLR_Restart(self.optimizer_G, train_opt['gen_lr_steps'],
+                                                 restarts=train_opt['restarts'],
+                                                 weights=train_opt['restart_weights'],
+                                                 gamma=train_opt['lr_gamma'],
+                                                 clear_state=train_opt['clear_state'],
+                                                 force_lr=train_opt['force_lr']))
+            for o in self.disc_optimizers:
                 self.schedulers.append(
-                    lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
+                    lr_scheduler.MultiStepLR_Restart(o, train_opt['disc_lr_steps'],
                                                      restarts=train_opt['restarts'],
                                                      weights=train_opt['restart_weights'],
                                                      gamma=train_opt['lr_gamma'],
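The hunk above is the core of the fix: the old branch looped over self.optimizers and fed every optimizer train_opt['lr_steps'], so generator and discriminator(s) could not get separate milestones; now the generator scheduler reads 'gen_lr_steps' and each optimizer in self.disc_optimizers reads 'disc_lr_steps'. A minimal sketch of the same per-optimizer pattern, using the stock torch.optim.lr_scheduler.MultiStepLR as a stand-in for the repo's MultiStepLR_Restart (networks and step values below are made up):

from torch import nn, optim

train_opt = {'gen_lr_steps': [50000, 100000],    # hypothetical milestones
             'disc_lr_steps': [40000, 80000],
             'lr_gamma': 0.5}

net_G, net_D = nn.Linear(8, 8), nn.Linear(8, 1)  # placeholder networks
optimizer_G = optim.Adam(net_G.parameters(), lr=1e-4)
disc_optimizers = [optim.Adam(net_D.parameters(), lr=1e-4)]

# one scheduler per optimizer, each driven by its own config key
schedulers = [optim.lr_scheduler.MultiStepLR(optimizer_G, train_opt['gen_lr_steps'],
                                             gamma=train_opt['lr_gamma'])]
for o in disc_optimizers:
    schedulers.append(optim.lr_scheduler.MultiStepLR(o, train_opt['disc_lr_steps'],
                                                     gamma=train_opt['lr_gamma']))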
@@ -267,11 +278,8 @@ class SRGANModel(BaseModel):
             self.schedulers.append(lr_scheduler.ProgressiveMultiStepLR(self.optimizer_G, train_opt['gen_lr_steps'],
                                                                        self.netG.module.get_progressive_starts(),
                                                                        train_opt['lr_gamma']))
-            self.schedulers.append(lr_scheduler.ProgressiveMultiStepLR(self.optimizer_D, train_opt['disc_lr_steps'],
-                                                                       [0],
-                                                                       train_opt['lr_gamma']))
-            if self.spsr_enabled:
-                self.schedulers.append(lr_scheduler.ProgressiveMultiStepLR(self.optimizer_D_grad, train_opt['disc_lr_steps'],
+            for o in self.disc_optimizers:
+                self.schedulers.append(lr_scheduler.ProgressiveMultiStepLR(o, train_opt['disc_lr_steps'],
                                                                             [0],
                                                                             train_opt['lr_gamma']))
         elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
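Both scheduler branches now draw generator and discriminator milestones from separate keys, and the assert in the MultiStepLR branch makes 'gen_lr_steps' and 'disc_lr_steps' mandatory there (the single 'lr_steps' key is no longer read in that branch). A hypothetical fragment of the 'train' options showing the keys these branches consume; the values are illustrative only:

train_opt = {
    'lr_scheme': 'MultiStepLR',
    'gen_lr_steps': [50000, 100000, 150000],   # generator milestones
    'disc_lr_steps': [40000, 90000, 140000],   # discriminator milestones
    'lr_gamma': 0.5,
    'restarts': [],            # illustrative placeholders; restart/weight/state
    'restart_weights': [],     # semantics are defined by the repo's
    'clear_state': False,      # MultiStepLR_Restart implementation
    'force_lr': False,
}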
@@ -295,7 +303,7 @@ class SRGANModel(BaseModel):
             # GAN LQ image params
             self.gan_lq_img_use_prob = train_opt['gan_lowres_use_probability'] if train_opt['gan_lowres_use_probability'] else 0
 
-            self.img_debug_steps = train_opt['img_debug_steps'] if train_opt['img_debug_steps'] else 50
+            self.img_debug_steps = opt['logger']['img_debug_steps'] if 'img_debug_steps' in opt['logger'].keys() else 50
 
         self.print_network()  # print network
         self.load()  # load G and D if needed
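The last SRGANModel hunk relocates the image-debug interval: it is now read from the 'logger' section of the top-level options rather than from the 'train' section, with a fallback of 50. A tiny hypothetical options fragment matching the new lookup:

opt = {'logger': {'img_debug_steps': 50}}   # hypothetical; the key may be omitted
img_debug_steps = opt['logger']['img_debug_steps'] if 'img_debug_steps' in opt['logger'].keys() else 50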
@@ -15,6 +15,7 @@ class BaseModel():
         self.is_train = opt['is_train']
         self.schedulers = []
         self.optimizers = []
+        self.disc_optimizers = []
 
     def feed_data(self, data):
         pass