Add corruptor_usage_probability

Governs how often a corruptor is used, vs feeding uncorrupted images.
This commit is contained in:
James Betker 2020-05-16 09:05:43 -06:00
parent 635c53475f
commit f911ef0d3e

View File

@ -93,6 +93,7 @@ class SRGANModel(BaseModel):
self.D_noise_final = train_opt['D_noise_final_it'] if train_opt['D_noise_final_it'] else 0
self.D_noise_theta_floor = train_opt['D_noise_theta_floor'] if train_opt['D_noise_theta_floor'] else 0
self.corruptor_swapout_steps = train_opt['corruptor_swapout_steps'] if train_opt['corruptor_swapout_steps'] else 500
self.corruptor_usage_prob = train_opt['corruptor_usage_probability'] if train_opt['corruptor_usage_probability'] else .5
# optimizers
# G
@ -167,7 +168,7 @@ class SRGANModel(BaseModel):
def feed_data(self, data, need_GT=True):
# Corrupt the data with the given corruptor, if specified.
self.fed_LQ = data['LQ'].to(self.device)
if self.netC:
if self.netC and random.random() < self.corruptor_usage_prob:
with torch.no_grad():
corrupted_L = self.netC(self.fed_LQ)[0].detach()
else: