diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 83c428a5..433a2a2a 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -203,6 +203,7 @@ class SRGANModel(BaseModel): weight_decay=wd_D, betas=(train_opt['beta1_D'], train_opt['beta2_D'])) self.optimizers.append(self.optimizer_D) + self.disc_optimizers.append(self.optimizer_D) if self.spsr_enabled: # D_grad optimizer @@ -219,6 +220,7 @@ class SRGANModel(BaseModel): weight_decay=wd_D, betas=(train_opt['beta1_D'], train_opt['beta2_D'])) self.optimizers.append(self.optimizer_D_grad) + self.disc_optimizers.append(self.optimizer_D_grad) if self.spsr_enabled: self.get_grad = ImageGradient().to(self.device) @@ -253,9 +255,18 @@ class SRGANModel(BaseModel): # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': - for optimizer in self.optimizers: + # This is a recent change. assert to make sure any legacy configs dont find their way here. + assert 'gen_lr_steps' in train_opt.keys() and 'disc_lr_steps' in train_opt.keys() + self.schedulers.append( + lr_scheduler.MultiStepLR_Restart(self.optimizer_G, train_opt['gen_lr_steps'], + restarts=train_opt['restarts'], + weights=train_opt['restart_weights'], + gamma=train_opt['lr_gamma'], + clear_state=train_opt['clear_state'], + force_lr=train_opt['force_lr'])) + for o in self.disc_optimizers: self.schedulers.append( - lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'], + lr_scheduler.MultiStepLR_Restart(o, train_opt['disc_lr_steps'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'], gamma=train_opt['lr_gamma'], @@ -267,11 +278,8 @@ class SRGANModel(BaseModel): self.schedulers.append(lr_scheduler.ProgressiveMultiStepLR(self.optimizer_G, train_opt['gen_lr_steps'], self.netG.module.get_progressive_starts(), train_opt['lr_gamma'])) - self.schedulers.append(lr_scheduler.ProgressiveMultiStepLR(self.optimizer_D, train_opt['disc_lr_steps'], - [0], - train_opt['lr_gamma'])) - if self.spsr_enabled: - self.schedulers.append(lr_scheduler.ProgressiveMultiStepLR(self.optimizer_D_grad, train_opt['disc_lr_steps'], + for o in self.disc_optimizers: + self.schedulers.append(lr_scheduler.ProgressiveMultiStepLR(o, train_opt['disc_lr_steps'], [0], train_opt['lr_gamma'])) elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': @@ -295,7 +303,7 @@ class SRGANModel(BaseModel): # GAN LQ image params self.gan_lq_img_use_prob = train_opt['gan_lowres_use_probability'] if train_opt['gan_lowres_use_probability'] else 0 - self.img_debug_steps = train_opt['img_debug_steps'] if train_opt['img_debug_steps'] else 50 + self.img_debug_steps = opt['logger']['img_debug_steps'] if 'img_debug_steps' in opt['logger'].keys() else 50 self.print_network() # print network self.load() # load G and D if needed diff --git a/codes/models/base_model.py b/codes/models/base_model.py index d011dca1..aeec4a65 100644 --- a/codes/models/base_model.py +++ b/codes/models/base_model.py @@ -15,6 +15,7 @@ class BaseModel(): self.is_train = opt['is_train'] self.schedulers = [] self.optimizers = [] + self.disc_optimizers = [] def feed_data(self, data): pass