From 3561cc164d086cc1786306f4ed0a813922ab36b9 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 3 Oct 2020 11:19:20 -0600 Subject: [PATCH] Fix up fea_loss calculator (for validation) Not sure how this was working in regular training mode, but it was failing in DDP. --- codes/models/ExtensibleTrainer.py | 4 ++-- codes/models/SRGAN_model.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py index 86d7e1fd..d17efcf4 100644 --- a/codes/models/ExtensibleTrainer.py +++ b/codes/models/ExtensibleTrainer.py @@ -240,8 +240,8 @@ class ExtensibleTrainer(BaseModel): def compute_fea_loss(self, real, fake): with torch.no_grad(): - logits_real = self.netF(real) - logits_fake = self.netF(fake) + logits_real = self.netF(real.to(self.device)) + logits_fake = self.netF(fake.to(self.device)) return nn.L1Loss().to(self.device)(logits_fake, logits_real) def test(self): diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 3af72ecd..14b2a461 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -883,7 +883,7 @@ class SRGANModel(BaseModel): with torch.no_grad(): real = real.unsqueeze(dim=0).to(self.device) fake = fake.unsqueeze(dim=0).to(self.device) - real_fea = self.netF(real).detach() + real_fea = self.netF(real) fake_fea = self.netF(fake) return self.cri_fea(fake_fea, real_fea).item()