Improve corruptor logic: switch corruptors randomly

This commit is contained in:
James Betker 2020-05-14 23:05:02 -06:00
parent d72e154442
commit 61ed51d9e4

View File

@ -92,6 +92,7 @@ class SRGANModel(BaseModel):
self.D_noise_theta = train_opt['D_noise_theta_init'] if train_opt['D_noise_theta_init'] else 0
self.D_noise_final = train_opt['D_noise_final_it'] if train_opt['D_noise_final_it'] else 0
self.D_noise_theta_floor = train_opt['D_noise_theta_floor'] if train_opt['D_noise_theta_floor'] else 0
self.corruptor_swapout_steps = train_opt['corruptor_swapout_steps'] if train_opt['corruptor_swapout_steps'] else 500
# optimizers
# G
@ -161,6 +162,7 @@ class SRGANModel(BaseModel):
self.print_network() # print network
self.load() # load G and D if needed
self.load_random_corruptor()
def feed_data(self, data, need_GT=True):
# Corrupt the data with the given corruptor, if specified.
@ -346,6 +348,9 @@ class SRGANModel(BaseModel):
self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))
self.add_log_entry('noise_theta', noise_theta)
if step % self.corruptor_swapout_steps == 0 and step > 0:
self.load_random_corruptor()
# Allows the log to serve as an easy-to-use rotating buffer.
def add_log_entry(self, key, value):
key_it = "%s_it" % (key,)
@ -377,13 +382,13 @@ class SRGANModel(BaseModel):
self.swapout_D_duration -= 1
if self.swapout_D_duration == 0:
# Swap back.
print("Swapping back to current D model: %s" % (self.stashed_D,))
logger.info("Swapping back to current D model: %s" % (self.stashed_D,))
self.load_network(self.stashed_D, self.netD, self.opt['path']['strict_load'])
self.stashed_D = None
elif self.swapout_D_freq != 0 and step % self.swapout_D_freq == 0:
swapped_model = self.pick_rand_prev_model('D')
if swapped_model is not None:
print("Swapping to previous D model: %s" % (swapped_model,))
logger.info("Swapping to previous D model: %s" % (swapped_model,))
self.stashed_D = self.save_network(self.netD, 'D', 'swap_model')
self.load_network(swapped_model, self.netD, self.opt['path']['strict_load'])
self.swapout_D_duration = self.swapout_duration
@ -478,6 +483,14 @@ class SRGANModel(BaseModel):
logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
self.load_network(load_path_D, self.netD, self.opt['path']['strict_load'])
def load_random_corruptor(self):
if self.netC is None:
return
corruptor_files = glob.glob(os.path.join(self.opt['path']['pretrained_corruptors_dir'], "*.pth"))
corruptor_to_load = corruptor_files[random.randint(0, len(corruptor_files)-1)]
logger.info('Swapping corruptor to: %s' % (corruptor_to_load,))
self.load_network(corruptor_to_load, self.netC, self.opt['path']['strict_load'])
def save(self, iter_step):
self.save_network(self.netG, 'G', iter_step)
self.save_network(self.netD, 'D', iter_step)