From f911ef0d3e9a122c21d4aa701b12128d087733f7 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 16 May 2020 09:05:43 -0600 Subject: [PATCH] Add corruptor_usage_probability Governs how often a corruptor is used, vs feeding uncorrupted images. --- codes/models/SRGAN_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 5634b59a..ba0f3ade 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -93,6 +93,7 @@ class SRGANModel(BaseModel): self.D_noise_final = train_opt['D_noise_final_it'] if train_opt['D_noise_final_it'] else 0 self.D_noise_theta_floor = train_opt['D_noise_theta_floor'] if train_opt['D_noise_theta_floor'] else 0 self.corruptor_swapout_steps = train_opt['corruptor_swapout_steps'] if train_opt['corruptor_swapout_steps'] else 500 + self.corruptor_usage_prob = train_opt['corruptor_usage_probability'] if train_opt['corruptor_usage_probability'] else .5 # optimizers # G @@ -167,7 +168,7 @@ class SRGANModel(BaseModel): def feed_data(self, data, need_GT=True): # Corrupt the data with the given corruptor, if specified. self.fed_LQ = data['LQ'].to(self.device) - if self.netC: + if self.netC and random.random() < self.corruptor_usage_prob: with torch.no_grad(): corrupted_L = self.netC(self.fed_LQ)[0].detach() else: