From 9ccf771629defff23218d5aae6a3c333bf4716b9 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 23 Jul 2020 10:16:34 -0600 Subject: [PATCH] Fix feature validation, wrong device Only shows up in distributed training for some reason. --- codes/models/SRGAN_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 411cc13a..d88fa46c 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -554,8 +554,8 @@ class SRGANModel(BaseModel): def compute_fea_loss(self, real, fake): with torch.no_grad(): - real = real.unsqueeze(dim=0) - fake = fake.unsqueeze(dim=0) + real = real.unsqueeze(dim=0).to(self.device) + fake = fake.unsqueeze(dim=0).to(self.device) real_fea = self.netF(real).detach() fake_fea = self.netF(fake) return self.cri_fea(fake_fea, real_fea).item()