diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index 69735a89..cd1795fb 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -82,7 +82,7 @@ class FeatureLoss(ConfigurableLoss): logits_real = self.netF(state[self.opt['real']]) logits_fake = self.netF(state[self.opt['fake']]) if self.opt['criterion'] == 'cosine': - return self.criterion(logits_fake, logits_real, 1) + return self.criterion(logits_fake, logits_real, torch.ones_like(logits_fake)) else: return self.criterion(logits_fake, logits_real)