Don't recompute generator outputs for D in standard operation
Should significantly improve training performance with negligible results differences.
This commit is contained in:
parent
11b227edfc
commit
4bfbdaf94f
|
@ -47,7 +47,7 @@ if __name__ == '__main__':
|
|||
im = rgb2ycbcr(train_data['GT'].double())
|
||||
im_LR = rgb2ycbcr(F.interpolate(train_data['LQ'].double(),
|
||||
size=im.shape[2:],
|
||||
mode="bicubic"))
|
||||
mode="bicubic", align_corners=False))
|
||||
patches_hr = extract_patches_2d(img=im, patch_shape=(patch_size,patch_size), batch_first=True)
|
||||
patches_hr = dct_2d(patches_hr, norm='ortho')
|
||||
patches_lr = extract_patches_2d(img=im_LR, patch_shape=(patch_size,patch_size), batch_first=True)
|
||||
|
|
|
@ -506,17 +506,16 @@ class SRGANModel(BaseModel):
|
|||
noise.to(self.device)
|
||||
real_disc_images = []
|
||||
fake_disc_images = []
|
||||
for var_L, var_LGAN, var_H, var_ref, pix in zip(self.var_L, self.gan_img, self.var_H, self.var_ref, self.pix):
|
||||
for fake_GenOut, var_LGAN, var_H, var_ref, pix in zip(self.fake_GenOut, self.gan_img, self.var_H, self.var_ref, self.pix):
|
||||
if random.random() > self.gan_lq_img_use_prob:
|
||||
gen_input = var_L
|
||||
fake_H = fake_GenOut.clone().detach().requires_grad_(False)
|
||||
else:
|
||||
gen_input = var_LGAN
|
||||
# Re-compute generator outputs (post-update).
|
||||
# Re-compute generator outputs with the GAN inputs.
|
||||
with torch.no_grad():
|
||||
if self.spsr_enabled:
|
||||
_, fake_H, _ = self.netG(gen_input)
|
||||
_, fake_H, _ = self.netG(var_LGAN)
|
||||
else:
|
||||
_, fake_H = self.netG(gen_input)
|
||||
_, fake_H = self.netG(var_LGAN)
|
||||
fake_H = fake_H.detach()
|
||||
|
||||
if _profile:
|
||||
|
@ -583,8 +582,8 @@ class SRGANModel(BaseModel):
|
|||
fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0
|
||||
|
||||
# Interpolate down to the dimensionality that the discriminator uses.
|
||||
real = F.interpolate(real, size=disc_output_shape[2:], mode="bilinear")
|
||||
fake = F.interpolate(fake, size=disc_output_shape[2:], mode="bilinear")
|
||||
real = F.interpolate(real, size=disc_output_shape[2:], mode="bilinear", align_corners=False)
|
||||
fake = F.interpolate(fake, size=disc_output_shape[2:], mode="bilinear", align_corners=False)
|
||||
|
||||
# We're also assuming that this is exactly how the flattened discriminator output is generated.
|
||||
real = real.view(-1, 1)
|
||||
|
|
|
@ -266,7 +266,7 @@ class GrowingUnetDiscBase(nn.Module):
|
|||
disc_age = self.latest_step - self.progressive_schedule[i]
|
||||
fade_in = min(1, disc_age * self.growth_fade_in_per_step)
|
||||
mean_weight += fade_in
|
||||
base_loss += F.interpolate(l, size=res, mode="bilinear") * fade_in
|
||||
base_loss += F.interpolate(l, size=res, mode="bilinear", align_corners=False) * fade_in
|
||||
base_loss /= mean_weight
|
||||
|
||||
return base_loss.view(-1, 1)
|
||||
|
|
Loading…
Reference in New Issue
Block a user