Fix feature validation, wrong device
Only shows up in distributed training for some reason.
This commit is contained in:
parent
a7541b6d8d
commit
9ccf771629
|
@ -554,8 +554,8 @@ class SRGANModel(BaseModel):
|
|||
|
||||
def compute_fea_loss(self, real, fake):
|
||||
with torch.no_grad():
|
||||
real = real.unsqueeze(dim=0)
|
||||
fake = fake.unsqueeze(dim=0)
|
||||
real = real.unsqueeze(dim=0).to(self.device)
|
||||
fake = fake.unsqueeze(dim=0).to(self.device)
|
||||
real_fea = self.netF(real).detach()
|
||||
fake_fea = self.netF(fake)
|
||||
return self.cri_fea(fake_fea, real_fea).item()
|
||||
|
|
Loading…
Reference in New Issue
Block a user