forked from mrq/DL-Art-School
Improve corruptor logic: switch corruptors randomly
This commit is contained in:
parent
d72e154442
commit
61ed51d9e4
|
@ -92,6 +92,7 @@ class SRGANModel(BaseModel):
|
||||||
self.D_noise_theta = train_opt['D_noise_theta_init'] if train_opt['D_noise_theta_init'] else 0
|
self.D_noise_theta = train_opt['D_noise_theta_init'] if train_opt['D_noise_theta_init'] else 0
|
||||||
self.D_noise_final = train_opt['D_noise_final_it'] if train_opt['D_noise_final_it'] else 0
|
self.D_noise_final = train_opt['D_noise_final_it'] if train_opt['D_noise_final_it'] else 0
|
||||||
self.D_noise_theta_floor = train_opt['D_noise_theta_floor'] if train_opt['D_noise_theta_floor'] else 0
|
self.D_noise_theta_floor = train_opt['D_noise_theta_floor'] if train_opt['D_noise_theta_floor'] else 0
|
||||||
|
self.corruptor_swapout_steps = train_opt['corruptor_swapout_steps'] if train_opt['corruptor_swapout_steps'] else 500
|
||||||
|
|
||||||
# optimizers
|
# optimizers
|
||||||
# G
|
# G
|
||||||
|
@ -161,6 +162,7 @@ class SRGANModel(BaseModel):
|
||||||
|
|
||||||
self.print_network() # print network
|
self.print_network() # print network
|
||||||
self.load() # load G and D if needed
|
self.load() # load G and D if needed
|
||||||
|
self.load_random_corruptor()
|
||||||
|
|
||||||
def feed_data(self, data, need_GT=True):
|
def feed_data(self, data, need_GT=True):
|
||||||
# Corrupt the data with the given corruptor, if specified.
|
# Corrupt the data with the given corruptor, if specified.
|
||||||
|
@ -346,6 +348,9 @@ class SRGANModel(BaseModel):
|
||||||
self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))
|
self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))
|
||||||
self.add_log_entry('noise_theta', noise_theta)
|
self.add_log_entry('noise_theta', noise_theta)
|
||||||
|
|
||||||
|
if step % self.corruptor_swapout_steps == 0 and step > 0:
|
||||||
|
self.load_random_corruptor()
|
||||||
|
|
||||||
# Allows the log to serve as an easy-to-use rotating buffer.
|
# Allows the log to serve as an easy-to-use rotating buffer.
|
||||||
def add_log_entry(self, key, value):
|
def add_log_entry(self, key, value):
|
||||||
key_it = "%s_it" % (key,)
|
key_it = "%s_it" % (key,)
|
||||||
|
@ -377,13 +382,13 @@ class SRGANModel(BaseModel):
|
||||||
self.swapout_D_duration -= 1
|
self.swapout_D_duration -= 1
|
||||||
if self.swapout_D_duration == 0:
|
if self.swapout_D_duration == 0:
|
||||||
# Swap back.
|
# Swap back.
|
||||||
print("Swapping back to current D model: %s" % (self.stashed_D,))
|
logger.info("Swapping back to current D model: %s" % (self.stashed_D,))
|
||||||
self.load_network(self.stashed_D, self.netD, self.opt['path']['strict_load'])
|
self.load_network(self.stashed_D, self.netD, self.opt['path']['strict_load'])
|
||||||
self.stashed_D = None
|
self.stashed_D = None
|
||||||
elif self.swapout_D_freq != 0 and step % self.swapout_D_freq == 0:
|
elif self.swapout_D_freq != 0 and step % self.swapout_D_freq == 0:
|
||||||
swapped_model = self.pick_rand_prev_model('D')
|
swapped_model = self.pick_rand_prev_model('D')
|
||||||
if swapped_model is not None:
|
if swapped_model is not None:
|
||||||
print("Swapping to previous D model: %s" % (swapped_model,))
|
logger.info("Swapping to previous D model: %s" % (swapped_model,))
|
||||||
self.stashed_D = self.save_network(self.netD, 'D', 'swap_model')
|
self.stashed_D = self.save_network(self.netD, 'D', 'swap_model')
|
||||||
self.load_network(swapped_model, self.netD, self.opt['path']['strict_load'])
|
self.load_network(swapped_model, self.netD, self.opt['path']['strict_load'])
|
||||||
self.swapout_D_duration = self.swapout_duration
|
self.swapout_D_duration = self.swapout_duration
|
||||||
|
@ -478,6 +483,14 @@ class SRGANModel(BaseModel):
|
||||||
logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
|
logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
|
||||||
self.load_network(load_path_D, self.netD, self.opt['path']['strict_load'])
|
self.load_network(load_path_D, self.netD, self.opt['path']['strict_load'])
|
||||||
|
|
||||||
|
def load_random_corruptor(self):
|
||||||
|
if self.netC is None:
|
||||||
|
return
|
||||||
|
corruptor_files = glob.glob(os.path.join(self.opt['path']['pretrained_corruptors_dir'], "*.pth"))
|
||||||
|
corruptor_to_load = corruptor_files[random.randint(0, len(corruptor_files)-1)]
|
||||||
|
logger.info('Swapping corruptor to: %s' % (corruptor_to_load,))
|
||||||
|
self.load_network(corruptor_to_load, self.netC, self.opt['path']['strict_load'])
|
||||||
|
|
||||||
def save(self, iter_step):
|
def save(self, iter_step):
|
||||||
self.save_network(self.netG, 'G', iter_step)
|
self.save_network(self.netG, 'G', iter_step)
|
||||||
self.save_network(self.netD, 'D', iter_step)
|
self.save_network(self.netD, 'D', iter_step)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user