forked from mrq/DL-Art-School
Allow noise to be added to discriminator inputs
This commit is contained in:
parent
9210a62f58
commit
06d18343f7
|
@ -78,6 +78,8 @@ class SRGANModel(BaseModel):
|
||||||
# D_update_ratio and D_init_iters
|
# D_update_ratio and D_init_iters
|
||||||
self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
|
self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
|
||||||
self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
|
self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
|
||||||
|
self.D_noise_theta = train_opt['D_noise_theta_init'] if train_opt['D_noise_theta_init'] else 0
|
||||||
|
self.D_noise_final = train_opt['D_noise_final_it'] if train_opt['D_noise_final_it'] else 0
|
||||||
|
|
||||||
# optimizers
|
# optimizers
|
||||||
# G
|
# G
|
||||||
|
@ -165,7 +167,14 @@ class SRGANModel(BaseModel):
|
||||||
for p in self.netG.parameters():
|
for p in self.netG.parameters():
|
||||||
p.requires_grad = False
|
p.requires_grad = False
|
||||||
|
|
||||||
|
# Calculate a standard deviation for the gaussian noise to be applied to the discriminator, termed noise-theta.
|
||||||
|
if step >= self.D_noise_final:
|
||||||
|
noise_theta = 0
|
||||||
|
else:
|
||||||
|
noise_theta = self.D_noise_theta * (self.D_noise_final - step) / self.D_noise_final
|
||||||
|
|
||||||
self.fake_GenOut = []
|
self.fake_GenOut = []
|
||||||
|
var_ref_skips = []
|
||||||
for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
|
for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
|
||||||
fake_GenOut = self.netG(var_L)
|
fake_GenOut = self.netG(var_L)
|
||||||
|
|
||||||
|
@ -177,9 +186,11 @@ class SRGANModel(BaseModel):
|
||||||
self.fake_GenOut.append((fake_GenOut[0].detach(),
|
self.fake_GenOut.append((fake_GenOut[0].detach(),
|
||||||
fake_GenOut[1].detach(),
|
fake_GenOut[1].detach(),
|
||||||
fake_GenOut[2].detach()))
|
fake_GenOut[2].detach()))
|
||||||
|
var_ref = (var_ref,) + self.create_artificial_skips(var_H)
|
||||||
else:
|
else:
|
||||||
gen_img = fake_GenOut
|
gen_img = fake_GenOut
|
||||||
self.fake_GenOut.append(fake_GenOut.detach())
|
self.fake_GenOut.append(fake_GenOut.detach())
|
||||||
|
var_ref_skips.append(var_ref)
|
||||||
|
|
||||||
l_g_total = 0
|
l_g_total = 0
|
||||||
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
|
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
|
||||||
|
@ -219,17 +230,13 @@ class SRGANModel(BaseModel):
|
||||||
for p in self.netD.parameters():
|
for p in self.netD.parameters():
|
||||||
p.requires_grad = True
|
p.requires_grad = True
|
||||||
|
|
||||||
# Convert var_ref to have the same output format as the generator. This generally means interpolating the
|
noise = torch.randn_like(var_ref[0]) * noise_theta
|
||||||
# HR images to have the same output dimensions as each generator skip connection.
|
noise.to(self.device)
|
||||||
if isinstance(self.fake_GenOut[0], tuple):
|
|
||||||
var_ref_skips = []
|
|
||||||
for ref, hi_res in zip(self.var_ref, self.var_H):
|
|
||||||
var_ref_skips.append((ref,) + self.create_artificial_skips(hi_res))
|
|
||||||
else:
|
|
||||||
var_ref_skips = self.var_ref
|
|
||||||
|
|
||||||
self.optimizer_D.zero_grad()
|
self.optimizer_D.zero_grad()
|
||||||
for var_L, var_H, var_ref, pix, fake_H in zip(self.var_L, self.var_H, var_ref_skips, self.pix, self.fake_GenOut):
|
for var_L, var_H, var_ref, pix, fake_H in zip(self.var_L, self.var_H, var_ref_skips, self.pix, self.fake_GenOut):
|
||||||
|
# Apply noise to the inputs to slow discriminator convergence.
|
||||||
|
var_ref = (var_ref[0] + noise,) + var_ref[1:]
|
||||||
|
fake_H = (fake_H[0] + noise,) + fake_H[1:]
|
||||||
if self.opt['train']['gan_type'] == 'gan':
|
if self.opt['train']['gan_type'] == 'gan':
|
||||||
# need to forward and backward separately, since batch norm statistics differ
|
# need to forward and backward separately, since batch norm statistics differ
|
||||||
# real
|
# real
|
||||||
|
@ -297,6 +304,7 @@ class SRGANModel(BaseModel):
|
||||||
self.add_log_entry('l_d_real', l_d_real.item() * self.mega_batch_factor)
|
self.add_log_entry('l_d_real', l_d_real.item() * self.mega_batch_factor)
|
||||||
self.add_log_entry('l_d_fake', l_d_fake.item() * self.mega_batch_factor)
|
self.add_log_entry('l_d_fake', l_d_fake.item() * self.mega_batch_factor)
|
||||||
self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))
|
self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))
|
||||||
|
self.add_log_entry('noise_theta', noise_theta)
|
||||||
|
|
||||||
# Allows the log to serve as an easy-to-use rotating buffer.
|
# Allows the log to serve as an easy-to-use rotating buffer.
|
||||||
def add_log_entry(self, key, value):
|
def add_log_entry(self, key, value):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user