This commit is contained in:
James Betker 2020-07-10 23:02:56 -06:00
parent 902527dfaa
commit ba6187859a

View File

@ -339,7 +339,7 @@ class SRGANModel(BaseModel):
l_d_fake_scaled.backward()
if self.opt['train']['gan_type'] == 'pixgan':
# randomly determine portions of the image to swap to keep the discriminator honest.
pixdisc_channels, pixdisc_output_reduction = self.netD.pixgan_parameters()
pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters()
disc_output_shape = (var_ref[0].shape[0], pixdisc_channels, var_ref[0].shape[2] // pixdisc_output_reduction, var_ref[0].shape[3] // pixdisc_output_reduction)
b, _, w, h = var_ref[0].shape
real = torch.ones((b, pixdisc_channels, w, h), device=var_ref[0].device)