Fix feature validation, wrong device

Only shows up in distributed training for some reason.
This commit is contained in:
James Betker 2020-07-23 10:16:34 -06:00
parent a7541b6d8d
commit 9ccf771629

View File

@ -554,8 +554,8 @@ class SRGANModel(BaseModel):
def compute_fea_loss(self, real, fake): def compute_fea_loss(self, real, fake):
with torch.no_grad(): with torch.no_grad():
real = real.unsqueeze(dim=0) real = real.unsqueeze(dim=0).to(self.device)
fake = fake.unsqueeze(dim=0) fake = fake.unsqueeze(dim=0).to(self.device)
real_fea = self.netF(real).detach() real_fea = self.netF(real).detach()
fake_fea = self.netF(fake) fake_fea = self.netF(fake)
return self.cri_fea(fake_fea, real_fea).item() return self.cri_fea(fake_fea, real_fea).item()