More torch 1.6 fixes
This commit is contained in:
parent
a66fbb32b6
commit
c139f5cd17
|
@ -362,12 +362,12 @@ class SRGANModel(BaseModel):
|
||||||
# equal to this value. If I ever come up with an algorithm that tunes fea/gan weights automatically,
|
# equal to this value. If I ever come up with an algorithm that tunes fea/gan weights automatically,
|
||||||
# it should target this
|
# it should target this
|
||||||
|
|
||||||
l_g_fix_disc = torch.zeros(1, requires_grad=False).squeeze()
|
l_g_fix_disc = torch.zeros(1, requires_grad=False, device=self.device).squeeze()
|
||||||
for fixed_disc in self.fixed_disc_nets:
|
for fixed_disc in self.fixed_disc_nets:
|
||||||
weight = fixed_disc.module.fdisc_weight
|
weight = fixed_disc.module.fdisc_weight
|
||||||
real_fea = fixed_disc(pix).detach()
|
real_fea = fixed_disc(pix).detach()
|
||||||
fake_fea = fixed_disc(fea_GenOut)
|
fake_fea = fixed_disc(fea_GenOut)
|
||||||
l_g_fix_disc += weight * self.cri_fea(fake_fea, real_fea)
|
l_g_fix_disc = l_g_fix_disc + weight * self.cri_fea(fake_fea, real_fea)
|
||||||
l_g_total += l_g_fix_disc
|
l_g_total += l_g_fix_disc
|
||||||
|
|
||||||
if self.l_gan_w > 0:
|
if self.l_gan_w > 0:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user