Add CosineEmbeddingLoss to F

This commit is contained in:
James Betker 2020-09-22 17:10:29 -06:00
parent f40beb5460
commit 2e18c4c22d

View File

@ -51,6 +51,8 @@ def get_basic_criterion_for_name(name, device):
return nn.L1Loss().to(device)
elif name == 'l2':
return nn.MSELoss().to(device)
elif name == 'cosine':
return nn.CosineEmbeddingLoss().to(device)
else:
raise NotImplementedError
@ -79,7 +81,10 @@ class FeatureLoss(ConfigurableLoss):
with torch.no_grad():
logits_real = self.netF(state[self.opt['real']])
logits_fake = self.netF(state[self.opt['fake']])
return self.criterion(logits_fake, logits_real)
if self.opt['criterion'] == 'cosine':
return self.criterion(logits_fake, logits_real, 1)
else:
return self.criterion(logits_fake, logits_real)
# Special form of feature loss which first computes the feature embedding for the truth space, then uses a second