diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 411cc13a..d88fa46c 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -554,8 +554,8 @@ class SRGANModel(BaseModel): def compute_fea_loss(self, real, fake): with torch.no_grad(): - real = real.unsqueeze(dim=0) - fake = fake.unsqueeze(dim=0) + real = real.unsqueeze(dim=0).to(self.device) + fake = fake.unsqueeze(dim=0).to(self.device) real_fea = self.netF(real).detach() fake_fea = self.netF(fake) return self.cri_fea(fake_fea, real_fea).item()