Fix feature validation, wrong device

Only shows up in distributed training for some reason.
This commit is contained in:
James Betker 2020-07-23 10:16:34 -06:00
parent a7541b6d8d
commit 9ccf771629

View File

@ -554,8 +554,8 @@ class SRGANModel(BaseModel):
def compute_fea_loss(self, real, fake):
with torch.no_grad():
real = real.unsqueeze(dim=0)
fake = fake.unsqueeze(dim=0)
real = real.unsqueeze(dim=0).to(self.device)
fake = fake.unsqueeze(dim=0).to(self.device)
real_fea = self.netF(real).detach()
fake_fea = self.netF(fake)
return self.cri_fea(fake_fea, real_fea).item()