forked from mrq/DL-Art-School
Add corruptor_usage_probability
Governs how often a corruptor is used, vs feeding uncorrupted images.
This commit is contained in:
parent
635c53475f
commit
f911ef0d3e
|
@ -93,6 +93,7 @@ class SRGANModel(BaseModel):
|
||||||
self.D_noise_final = train_opt['D_noise_final_it'] if train_opt['D_noise_final_it'] else 0
|
self.D_noise_final = train_opt['D_noise_final_it'] if train_opt['D_noise_final_it'] else 0
|
||||||
self.D_noise_theta_floor = train_opt['D_noise_theta_floor'] if train_opt['D_noise_theta_floor'] else 0
|
self.D_noise_theta_floor = train_opt['D_noise_theta_floor'] if train_opt['D_noise_theta_floor'] else 0
|
||||||
self.corruptor_swapout_steps = train_opt['corruptor_swapout_steps'] if train_opt['corruptor_swapout_steps'] else 500
|
self.corruptor_swapout_steps = train_opt['corruptor_swapout_steps'] if train_opt['corruptor_swapout_steps'] else 500
|
||||||
|
self.corruptor_usage_prob = train_opt['corruptor_usage_probability'] if train_opt['corruptor_usage_probability'] else .5
|
||||||
|
|
||||||
# optimizers
|
# optimizers
|
||||||
# G
|
# G
|
||||||
|
@ -167,7 +168,7 @@ class SRGANModel(BaseModel):
|
||||||
def feed_data(self, data, need_GT=True):
|
def feed_data(self, data, need_GT=True):
|
||||||
# Corrupt the data with the given corruptor, if specified.
|
# Corrupt the data with the given corruptor, if specified.
|
||||||
self.fed_LQ = data['LQ'].to(self.device)
|
self.fed_LQ = data['LQ'].to(self.device)
|
||||||
if self.netC:
|
if self.netC and random.random() < self.corruptor_usage_prob:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
corrupted_L = self.netC(self.fed_LQ)[0].detach()
|
corrupted_L = self.netC(self.fed_LQ)[0].detach()
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user