diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index c722161e..c729624f 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -339,7 +339,7 @@ class SRGANModel(BaseModel): l_d_fake_scaled.backward() if self.opt['train']['gan_type'] == 'pixgan': # randomly determine portions of the image to swap to keep the discriminator honest. - pixdisc_channels, pixdisc_output_reduction = self.netD.pixgan_parameters() + pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters() disc_output_shape = (var_ref[0].shape[0], pixdisc_channels, var_ref[0].shape[2] // pixdisc_output_reduction, var_ref[0].shape[3] // pixdisc_output_reduction) b, _, w, h = var_ref[0].shape real = torch.ones((b, pixdisc_channels, w, h), device=var_ref[0].device)