More torch 1.6 fixes

This commit is contained in:
James Betker 2020-07-31 17:03:20 -06:00
parent a66fbb32b6
commit c139f5cd17

View File

@ -362,12 +362,12 @@ class SRGANModel(BaseModel):
# equal to this value. If I ever come up with an algorithm that tunes fea/gan weights automatically, # equal to this value. If I ever come up with an algorithm that tunes fea/gan weights automatically,
# it should target this # it should target this
l_g_fix_disc = torch.zeros(1, requires_grad=False).squeeze() l_g_fix_disc = torch.zeros(1, requires_grad=False, device=self.device).squeeze()
for fixed_disc in self.fixed_disc_nets: for fixed_disc in self.fixed_disc_nets:
weight = fixed_disc.module.fdisc_weight weight = fixed_disc.module.fdisc_weight
real_fea = fixed_disc(pix).detach() real_fea = fixed_disc(pix).detach()
fake_fea = fixed_disc(fea_GenOut) fake_fea = fixed_disc(fea_GenOut)
l_g_fix_disc += weight * self.cri_fea(fake_fea, real_fea) l_g_fix_disc = l_g_fix_disc + weight * self.cri_fea(fake_fea, real_fea)
l_g_total += l_g_fix_disc l_g_total += l_g_fix_disc
if self.l_gan_w > 0: if self.l_gan_w > 0: