From bcebed19b7d0c70e129c7e4e1242bf87e5892627 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 31 Jul 2020 16:38:14 -0600 Subject: [PATCH] Fix pixdisc bugs --- codes/models/SRGAN_model.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 114c3546..cce3f8be 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -455,13 +455,13 @@ class SRGANModel(BaseModel): with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled: l_d_fake_scaled.backward() if 'pixgan' in self.opt['train']['gan_type']: + pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters() + disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction) + b, _, w, h = var_ref.shape + real = torch.ones((b, pixdisc_channels, w, h), device=var_ref.device) + fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref.device) if not self.disjoint_data: # randomly determine portions of the image to swap to keep the discriminator honest. - pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters() - disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction) - b, _, w, h = var_ref.shape - real = torch.ones((b, pixdisc_channels, w, h), device=var_ref.device) - fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref.device) SWAP_MAX_DIM = w // 4 SWAP_MIN_DIM = 16 assert SWAP_MAX_DIM > 0