diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 3183d4d9..ad151c88 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -92,6 +92,7 @@ class SRGANModel(BaseModel): self.D_noise_theta = train_opt['D_noise_theta_init'] if train_opt['D_noise_theta_init'] else 0 self.D_noise_final = train_opt['D_noise_final_it'] if train_opt['D_noise_final_it'] else 0 self.D_noise_theta_floor = train_opt['D_noise_theta_floor'] if train_opt['D_noise_theta_floor'] else 0 + self.corruptor_swapout_steps = train_opt['corruptor_swapout_steps'] if train_opt['corruptor_swapout_steps'] else 500 # optimizers # G @@ -161,6 +162,7 @@ class SRGANModel(BaseModel): self.print_network() # print network self.load() # load G and D if needed + self.load_random_corruptor() def feed_data(self, data, need_GT=True): # Corrupt the data with the given corruptor, if specified. @@ -346,6 +348,9 @@ class SRGANModel(BaseModel): self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach())) self.add_log_entry('noise_theta', noise_theta) + if step % self.corruptor_swapout_steps == 0 and step > 0: + self.load_random_corruptor() + # Allows the log to serve as an easy-to-use rotating buffer. def add_log_entry(self, key, value): key_it = "%s_it" % (key,) @@ -377,13 +382,13 @@ class SRGANModel(BaseModel): self.swapout_D_duration -= 1 if self.swapout_D_duration == 0: # Swap back. - print("Swapping back to current D model: %s" % (self.stashed_D,)) + logger.info("Swapping back to current D model: %s" % (self.stashed_D,)) self.load_network(self.stashed_D, self.netD, self.opt['path']['strict_load']) self.stashed_D = None elif self.swapout_D_freq != 0 and step % self.swapout_D_freq == 0: swapped_model = self.pick_rand_prev_model('D') if swapped_model is not None: - print("Swapping to previous D model: %s" % (swapped_model,)) + logger.info("Swapping to previous D model: %s" % (swapped_model,)) self.stashed_D = self.save_network(self.netD, 'D', 'swap_model') self.load_network(swapped_model, self.netD, self.opt['path']['strict_load']) self.swapout_D_duration = self.swapout_duration @@ -478,6 +483,14 @@ class SRGANModel(BaseModel): logger.info('Loading model for D [{:s}] ...'.format(load_path_D)) self.load_network(load_path_D, self.netD, self.opt['path']['strict_load']) + def load_random_corruptor(self): + if self.netC is None: + return + corruptor_files = glob.glob(os.path.join(self.opt['path']['pretrained_corruptors_dir'], "*.pth")) + corruptor_to_load = corruptor_files[random.randint(0, len(corruptor_files)-1)] + logger.info('Swapping corruptor to: %s' % (corruptor_to_load,)) + self.load_network(corruptor_to_load, self.netC, self.opt['path']['strict_load']) + def save(self, iter_step): self.save_network(self.netG, 'G', iter_step) self.save_network(self.netD, 'D', iter_step)