From 2e18c4c22d1f1750c934922d6434dd15f9de1166 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 22 Sep 2020 17:10:29 -0600 Subject: [PATCH] Add CosineEmbeddingLoss to F --- codes/models/steps/losses.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index d367d318..69735a89 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -51,6 +51,8 @@ def get_basic_criterion_for_name(name, device): return nn.L1Loss().to(device) elif name == 'l2': return nn.MSELoss().to(device) + elif name == 'cosine': + return nn.CosineEmbeddingLoss().to(device) else: raise NotImplementedError @@ -79,7 +81,10 @@ class FeatureLoss(ConfigurableLoss): with torch.no_grad(): logits_real = self.netF(state[self.opt['real']]) logits_fake = self.netF(state[self.opt['fake']]) - return self.criterion(logits_fake, logits_real) + if self.opt['criterion'] == 'cosine': + return self.criterion(logits_fake, logits_real, 1) + else: + return self.criterion(logits_fake, logits_real) # Special form of feature loss which first computes the feature embedding for the truth space, then uses a second