Fix feature validation, wrong device
Only shows up in distributed training for some reason.
This commit is contained in:
parent
a7541b6d8d
commit
9ccf771629
|
@ -554,8 +554,8 @@ class SRGANModel(BaseModel):
|
||||||
|
|
||||||
def compute_fea_loss(self, real, fake):
|
def compute_fea_loss(self, real, fake):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
real = real.unsqueeze(dim=0)
|
real = real.unsqueeze(dim=0).to(self.device)
|
||||||
fake = fake.unsqueeze(dim=0)
|
fake = fake.unsqueeze(dim=0).to(self.device)
|
||||||
real_fea = self.netF(real).detach()
|
real_fea = self.netF(real).detach()
|
||||||
fake_fea = self.netF(fake)
|
fake_fea = self.netF(fake)
|
||||||
return self.cri_fea(fake_fea, real_fea).item()
|
return self.cri_fea(fake_fea, real_fea).item()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user