Fix fixed_disc DataParallel issue

This commit is contained in:
James Betker 2020-07-31 16:59:23 -06:00
parent 8dd44182e6
commit a66fbb32b6

View File

@ -364,7 +364,7 @@ class SRGANModel(BaseModel):
l_g_fix_disc = torch.zeros(1, requires_grad=False).squeeze()
for fixed_disc in self.fixed_disc_nets:
weight = fixed_disc.fdisc_weight
weight = fixed_disc.module.fdisc_weight
real_fea = fixed_disc(pix).detach()
fake_fea = fixed_disc(fea_GenOut)
l_g_fix_disc += weight * self.cri_fea(fake_fea, real_fea)