From c139f5cd173c71bbe6e86ba302f96ad7323f15e9 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 31 Jul 2020 17:03:20 -0600 Subject: [PATCH] More torch 1.6 fixes --- codes/models/SRGAN_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 73cdb27d..85000f40 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -362,12 +362,12 @@ class SRGANModel(BaseModel): # equal to this value. If I ever come up with an algorithm that tunes fea/gan weights automatically, # it should target this - l_g_fix_disc = torch.zeros(1, requires_grad=False).squeeze() + l_g_fix_disc = torch.zeros(1, requires_grad=False, device=self.device).squeeze() for fixed_disc in self.fixed_disc_nets: weight = fixed_disc.module.fdisc_weight real_fea = fixed_disc(pix).detach() fake_fea = fixed_disc(fea_GenOut) - l_g_fix_disc += weight * self.cri_fea(fake_fea, real_fea) + l_g_fix_disc = l_g_fix_disc + weight * self.cri_fea(fake_fea, real_fea) l_g_total += l_g_fix_disc if self.l_gan_w > 0: