Don't recompute generator outputs for D in standard operation

Should significantly improve training performance with a negligible
difference in results.
James Betker 2020-08-04 11:28:52 -06:00
parent 11b227edfc
commit 4bfbdaf94f
3 changed files with 16 additions and 17 deletions
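
In short: the discriminator step now reuses the generator output cached during the generator step instead of running a second netG forward pass. A minimal sketch of the pattern (the function and argument names are illustrative, not the repo's API):

import torch

# Sketch only: `net_g`, `net_d`, and the cached output are stand-ins.
def disc_fake_logits(net_g, net_d, cached_gen_out, gan_input, reuse_cached):
    if reuse_cached:
        # Standard operation: reuse the output from the G step.
        # detach() severs the graph so D's update cannot reach into G.
        fake = cached_gen_out.clone().detach().requires_grad_(False)
    else:
        # Alternate GAN-only inputs were never run through G this step,
        # so they still require a (gradient-free) forward pass.
        with torch.no_grad():
            fake = net_g(gan_input).detach()
    return net_d(fake)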


@@ -47,7 +47,7 @@ if __name__ == '__main__':
         im = rgb2ycbcr(train_data['GT'].double())
         im_LR = rgb2ycbcr(F.interpolate(train_data['LQ'].double(),
                                         size=im.shape[2:],
-                                        mode="bicubic"))
+                                        mode="bicubic", align_corners=False))
         patches_hr = extract_patches_2d(img=im, patch_shape=(patch_size,patch_size), batch_first=True)
         patches_hr = dct_2d(patches_hr, norm='ortho')
         patches_lr = extract_patches_2d(img=im_LR, patch_shape=(patch_size,patch_size), batch_first=True)
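
The align_corners=False arguments added throughout this commit match what PyTorch already does by default for bilinear/bicubic modes; passing the flag explicitly mainly silences the UserWarning recent PyTorch versions emit when it is omitted. A quick sanity check of that equivalence:

import warnings
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # omitting align_corners warns here
    implicit = F.interpolate(x, size=(16, 16), mode="bicubic")
explicit = F.interpolate(x, size=(16, 16), mode="bicubic", align_corners=False)
assert torch.equal(implicit, explicit)  # same values; only the warning differs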


@@ -506,22 +506,21 @@ class SRGANModel(BaseModel):
             noise.to(self.device)
             real_disc_images = []
             fake_disc_images = []
-            for var_L, var_LGAN, var_H, var_ref, pix in zip(self.var_L, self.gan_img, self.var_H, self.var_ref, self.pix):
+            for fake_GenOut, var_LGAN, var_H, var_ref, pix in zip(self.fake_GenOut, self.gan_img, self.var_H, self.var_ref, self.pix):
                 if random.random() > self.gan_lq_img_use_prob:
-                    gen_input = var_L
+                    fake_H = fake_GenOut.clone().detach().requires_grad_(False)
                 else:
-                    gen_input = var_LGAN
-                # Re-compute generator outputs (post-update).
-                with torch.no_grad():
-                    if self.spsr_enabled:
-                        _, fake_H, _ = self.netG(gen_input)
-                    else:
-                        _, fake_H = self.netG(gen_input)
-                    fake_H = fake_H.detach()
-                if _profile:
-                    print("Gen forward for disc %f" % (time() - _t,))
-                    _t = time()
+                    # Re-compute generator outputs with the GAN inputs.
+                    with torch.no_grad():
+                        if self.spsr_enabled:
+                            _, fake_H, _ = self.netG(var_LGAN)
+                        else:
+                            _, fake_H = self.netG(var_LGAN)
+                        fake_H = fake_H.detach()
+                    if _profile:
+                        print("Gen forward for disc %f" % (time() - _t,))
+                        _t = time()
                 # Apply noise to the inputs to slow discriminator convergence.
                 var_ref = var_ref + noise
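
The clone().detach().requires_grad_(False) in the standard path is what makes the reuse safe: the cached fake_GenOut still carries the generator's autograd graph from the G step, so detaching prevents the discriminator's backward pass from flowing into (and retaining activations for) netG. A toy demonstration with stand-in modules:

import torch
import torch.nn as nn

net_g = nn.Linear(4, 4)  # toy stand-ins for the real networks
net_d = nn.Linear(4, 1)

fake_gen_out = net_g(torch.randn(2, 4))  # cached during the G step
fake_h = fake_gen_out.clone().detach().requires_grad_(False)

net_d(fake_h).mean().backward()
assert net_g.weight.grad is None  # D's backward never touched the generator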
@@ -583,8 +582,8 @@ class SRGANModel(BaseModel):
                     fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0
                 # Interpolate down to the dimensionality that the discriminator uses.
-                real = F.interpolate(real, size=disc_output_shape[2:], mode="bilinear")
-                fake = F.interpolate(fake, size=disc_output_shape[2:], mode="bilinear")
+                real = F.interpolate(real, size=disc_output_shape[2:], mode="bilinear", align_corners=False)
+                fake = F.interpolate(fake, size=disc_output_shape[2:], mode="bilinear", align_corners=False)
                 # We're also assuming that this is exactly how the flattened discriminator output is generated.
                 real = real.view(-1, 1)
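
For context, this hunk resizes the per-pixel real/fake target maps (including the swapped regions marked above) down to the discriminator's output resolution, then flattens them to line up with the flattened logits. A rough sketch of the shape bookkeeping, with made-up dimensions:

import torch
import torch.nn.functional as F

disc_output_shape = (4, 1, 16, 16)  # hypothetical D output: (B, 1, H, W)
real = torch.zeros(4, 1, 128, 128)  # per-pixel targets at image resolution
real[:, :, 32:64, 32:64] = 1.0      # e.g. a swapped patch marked as fake

real = F.interpolate(real, size=disc_output_shape[2:], mode="bilinear",
                     align_corners=False)
real = real.view(-1, 1)             # matches the flattened logits
assert real.shape == (4 * 16 * 16, 1)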


@@ -266,7 +266,7 @@ class GrowingUnetDiscBase(nn.Module):
             disc_age = self.latest_step - self.progressive_schedule[i]
             fade_in = min(1, disc_age * self.growth_fade_in_per_step)
             mean_weight += fade_in
-            base_loss += F.interpolate(l, size=res, mode="bilinear") * fade_in
+            base_loss += F.interpolate(l, size=res, mode="bilinear", align_corners=False) * fade_in
         base_loss /= mean_weight
         return base_loss.view(-1, 1)
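
This loop averages the per-scale losses weighted by each sub-discriminator's fade-in factor, so heads added later by the progressive schedule contribute gradually rather than all at once. A standalone sketch with hypothetical numbers (the schedule and fade-in rate are assumptions):

import torch
import torch.nn.functional as F

latest_step = 1200
progressive_schedule = [0, 1000]   # step at which each head came online
growth_fade_in_per_step = 1e-3     # hypothetical fade-in rate
losses = [torch.randn(4, 1, 8, 8), torch.randn(4, 1, 4, 4)]
res = (8, 8)

base_loss, mean_weight = 0, 0
for i, l in enumerate(losses):
    disc_age = latest_step - progressive_schedule[i]
    fade_in = min(1, disc_age * growth_fade_in_per_step)  # 1.0, then 0.2
    mean_weight += fade_in
    base_loss += F.interpolate(l, size=res, mode="bilinear",
                               align_corners=False) * fade_in
base_loss /= mean_weight            # weighted mean over the active heads
print(base_loss.view(-1, 1).shape)  # torch.Size([256, 1])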