Fix up fea_loss calculator (for validation)
Not sure how this was working in regular training mode, but it was failing in DDP.
This commit is contained in:
parent
21d3bb83b2
commit
3561cc164d
|
@ -240,8 +240,8 @@ class ExtensibleTrainer(BaseModel):
|
||||||
|
|
||||||
def compute_fea_loss(self, real, fake):
|
def compute_fea_loss(self, real, fake):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logits_real = self.netF(real)
|
logits_real = self.netF(real.to(self.device))
|
||||||
logits_fake = self.netF(fake)
|
logits_fake = self.netF(fake.to(self.device))
|
||||||
return nn.L1Loss().to(self.device)(logits_fake, logits_real)
|
return nn.L1Loss().to(self.device)(logits_fake, logits_real)
|
||||||
|
|
||||||
def test(self):
|
def test(self):
|
||||||
|
|
|
@ -883,7 +883,7 @@ class SRGANModel(BaseModel):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
real = real.unsqueeze(dim=0).to(self.device)
|
real = real.unsqueeze(dim=0).to(self.device)
|
||||||
fake = fake.unsqueeze(dim=0).to(self.device)
|
fake = fake.unsqueeze(dim=0).to(self.device)
|
||||||
real_fea = self.netF(real).detach()
|
real_fea = self.netF(real)
|
||||||
fake_fea = self.netF(fake)
|
fake_fea = self.netF(fake)
|
||||||
return self.cri_fea(fake_fea, real_fea).item()
|
return self.cri_fea(fake_fea, real_fea).item()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user