Don't recompute generator outputs for D in standard operation
This should significantly improve training performance with negligible differences in results.
parent 11b227edfc
commit 4bfbdaf94f
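In outline, the change replaces a second generator forward pass with the output already cached during the generator step; the GAN-LQ path, which uses a different input, still recomputes under no_grad. A minimal sketch of that pattern (hypothetical wrapper function and argument names; the non-SPSR two-output generator signature is assumed):

import random
import torch

def fake_for_disc(fake_GenOut, gan_lq_input, netG, gan_lq_img_use_prob=0.0):
    # Standard operation: reuse the cached generator output; detach it so the
    # discriminator's backward pass cannot reach the generator graph.
    if random.random() > gan_lq_img_use_prob:
        return fake_GenOut.clone().detach().requires_grad_(False)
    # GAN-LQ path: the cached output does not correspond to this input, so a
    # fresh (gradient-free) forward pass is still required.
    with torch.no_grad():
        _, fake_H = netG(gan_lq_input)
    return fake_H.detach()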
@@ -47,7 +47,7 @@ if __name__ == '__main__':
         im = rgb2ycbcr(train_data['GT'].double())
         im_LR = rgb2ycbcr(F.interpolate(train_data['LQ'].double(),
                                         size=im.shape[2:],
-                                        mode="bicubic"))
+                                        mode="bicubic", align_corners=False))
         patches_hr = extract_patches_2d(img=im, patch_shape=(patch_size,patch_size), batch_first=True)
         patches_hr = dct_2d(patches_hr, norm='ortho')
         patches_lr = extract_patches_2d(img=im_LR, patch_shape=(patch_size,patch_size), batch_first=True)
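The align_corners additions throughout this commit do not change numerics: for bilinear and bicubic modes, F.interpolate already defaults to align_corners=False, so passing it explicitly only documents the behavior (and avoids the UserWarning older PyTorch releases emit when it is left unspecified). A quick standalone check:

import torch
import torch.nn.functional as F

x = torch.rand(1, 3, 8, 8).double()
implicit = F.interpolate(x, size=(16, 16), mode="bicubic")
explicit = F.interpolate(x, size=(16, 16), mode="bicubic", align_corners=False)
print(torch.equal(implicit, explicit))  # True: identical outputs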
@@ -506,22 +506,21 @@ class SRGANModel(BaseModel):
             noise.to(self.device)
             real_disc_images = []
             fake_disc_images = []
-            for var_L, var_LGAN, var_H, var_ref, pix in zip(self.var_L, self.gan_img, self.var_H, self.var_ref, self.pix):
+            for fake_GenOut, var_LGAN, var_H, var_ref, pix in zip(self.fake_GenOut, self.gan_img, self.var_H, self.var_ref, self.pix):
                 if random.random() > self.gan_lq_img_use_prob:
-                    gen_input = var_L
+                    fake_H = fake_GenOut.clone().detach().requires_grad_(False)
                 else:
-                    gen_input = var_LGAN
-                # Re-compute generator outputs (post-update).
-                with torch.no_grad():
-                    if self.spsr_enabled:
-                        _, fake_H, _ = self.netG(gen_input)
-                    else:
-                        _, fake_H = self.netG(gen_input)
-                    fake_H = fake_H.detach()
+                    # Re-compute generator outputs with the GAN inputs.
+                    with torch.no_grad():
+                        if self.spsr_enabled:
+                            _, fake_H, _ = self.netG(var_LGAN)
+                        else:
+                            _, fake_H = self.netG(var_LGAN)
+                        fake_H = fake_H.detach()

                 if _profile:
                     print("Gen forward for disc %f" % (time() - _t,))
                     _t = time()

                 # Apply noise to the inputs to slow discriminator convergence.
                 var_ref = var_ref + noise
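Reusing the cached output is safe for the discriminator update because the tensor is detached first: D's backward pass never produces gradients for G's parameters. A self-contained check with toy modules (not the repo's networks):

import torch
import torch.nn as nn

G, D = nn.Linear(4, 4), nn.Linear(4, 1)
fake = G(torch.randn(2, 4))                            # produced during the G step
fake_for_D = fake.clone().detach().requires_grad_(False)

D(fake_for_D).mean().backward()
print(G.weight.grad is None)                           # True: the D step leaves G untouched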
@@ -583,8 +582,8 @@ class SRGANModel(BaseModel):
                 fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0

                 # Interpolate down to the dimensionality that the discriminator uses.
-                real = F.interpolate(real, size=disc_output_shape[2:], mode="bilinear")
-                fake = F.interpolate(fake, size=disc_output_shape[2:], mode="bilinear")
+                real = F.interpolate(real, size=disc_output_shape[2:], mode="bilinear", align_corners=False)
+                fake = F.interpolate(fake, size=disc_output_shape[2:], mode="bilinear", align_corners=False)

                 # We're also assuming that this is exactly how the flattened discriminator output is generated.
                 real = real.view(-1, 1)
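The per-pixel real/fake targets above are built at image resolution and then downsampled to the discriminator's output resolution before being flattened. A minimal sketch of just that downsampling step, with assumed shapes:

import torch
import torch.nn.functional as F

disc_output_shape = (8, 1, 24, 24)     # assumed discriminator output shape
real = torch.ones(8, 1, 96, 96)        # per-pixel "real" targets at image resolution
fake = torch.zeros(8, 1, 96, 96)       # per-pixel "fake" targets at image resolution

real = F.interpolate(real, size=disc_output_shape[2:], mode="bilinear", align_corners=False)
fake = F.interpolate(fake, size=disc_output_shape[2:], mode="bilinear", align_corners=False)
real, fake = real.view(-1, 1), fake.view(-1, 1)   # match the flattened discriminator output
print(real.shape)                      # torch.Size([4608, 1])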
@@ -266,7 +266,7 @@ class GrowingUnetDiscBase(nn.Module):
            disc_age = self.latest_step - self.progressive_schedule[i]
            fade_in = min(1, disc_age * self.growth_fade_in_per_step)
            mean_weight += fade_in
-            base_loss += F.interpolate(l, size=res, mode="bilinear") * fade_in
+            base_loss += F.interpolate(l, size=res, mode="bilinear", align_corners=False) * fade_in
        base_loss /= mean_weight

        return base_loss.view(-1, 1)
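For reference, the loop above amounts to a fade-in-weighted average of per-head loss maps upsampled to a common resolution. A toy illustration with assumed shapes and fade-in values:

import torch
import torch.nn.functional as F

res = (16, 16)
losses = [torch.rand(1, 1, 16, 16), torch.rand(1, 1, 8, 8)]   # per-head loss maps
fade_ins = [1.0, 0.25]                                        # newest head still fading in

base_loss, mean_weight = 0.0, 0.0
for l, fade_in in zip(losses, fade_ins):
    mean_weight += fade_in
    base_loss = base_loss + F.interpolate(l, size=res, mode="bilinear",
                                          align_corners=False) * fade_in
base_loss = base_loss / mean_weight
print(base_loss.view(-1, 1).shape)   # torch.Size([256, 1])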