Fix up fea_loss calculator (for validation)

Not sure how this was working in regular training mode, but it
was failing in DDP.
This commit is contained in:
James Betker 2020-10-03 11:19:20 -06:00
parent 21d3bb83b2
commit 3561cc164d
2 changed files with 3 additions and 3 deletions

View File

@ -240,8 +240,8 @@ class ExtensibleTrainer(BaseModel):
def compute_fea_loss(self, real, fake): def compute_fea_loss(self, real, fake):
with torch.no_grad(): with torch.no_grad():
logits_real = self.netF(real) logits_real = self.netF(real.to(self.device))
logits_fake = self.netF(fake) logits_fake = self.netF(fake.to(self.device))
return nn.L1Loss().to(self.device)(logits_fake, logits_real) return nn.L1Loss().to(self.device)(logits_fake, logits_real)
def test(self): def test(self):

View File

@ -883,7 +883,7 @@ class SRGANModel(BaseModel):
with torch.no_grad(): with torch.no_grad():
real = real.unsqueeze(dim=0).to(self.device) real = real.unsqueeze(dim=0).to(self.device)
fake = fake.unsqueeze(dim=0).to(self.device) fake = fake.unsqueeze(dim=0).to(self.device)
real_fea = self.netF(real).detach() real_fea = self.netF(real)
fake_fea = self.netF(fake) fake_fea = self.netF(fake)
return self.cri_fea(fake_fea, real_fea).item() return self.cri_fea(fake_fea, real_fea).item()