Fix up fea_loss calculator (for validation)
Not sure how this was working in regular training mode, but it was failing in DDP.
This commit is contained in:
parent
21d3bb83b2
commit
3561cc164d
|
@ -240,8 +240,8 @@ class ExtensibleTrainer(BaseModel):
|
|||
|
||||
def compute_fea_loss(self, real, fake):
|
||||
with torch.no_grad():
|
||||
logits_real = self.netF(real)
|
||||
logits_fake = self.netF(fake)
|
||||
logits_real = self.netF(real.to(self.device))
|
||||
logits_fake = self.netF(fake.to(self.device))
|
||||
return nn.L1Loss().to(self.device)(logits_fake, logits_real)
|
||||
|
||||
def test(self):
|
||||
|
|
|
@ -883,7 +883,7 @@ class SRGANModel(BaseModel):
|
|||
with torch.no_grad():
|
||||
real = real.unsqueeze(dim=0).to(self.device)
|
||||
fake = fake.unsqueeze(dim=0).to(self.device)
|
||||
real_fea = self.netF(real).detach()
|
||||
real_fea = self.netF(real)
|
||||
fake_fea = self.netF(fake)
|
||||
return self.cri_fea(fake_fea, real_fea).item()
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user