forked from mrq/DL-Art-School
Add CosineEmbeddingLoss to F
This commit is contained in:
parent
f40beb5460
commit
2e18c4c22d
|
@ -51,6 +51,8 @@ def get_basic_criterion_for_name(name, device):
|
||||||
return nn.L1Loss().to(device)
|
return nn.L1Loss().to(device)
|
||||||
elif name == 'l2':
|
elif name == 'l2':
|
||||||
return nn.MSELoss().to(device)
|
return nn.MSELoss().to(device)
|
||||||
|
elif name == 'cosine':
|
||||||
|
return nn.CosineEmbeddingLoss().to(device)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -79,7 +81,10 @@ class FeatureLoss(ConfigurableLoss):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logits_real = self.netF(state[self.opt['real']])
|
logits_real = self.netF(state[self.opt['real']])
|
||||||
logits_fake = self.netF(state[self.opt['fake']])
|
logits_fake = self.netF(state[self.opt['fake']])
|
||||||
return self.criterion(logits_fake, logits_real)
|
if self.opt['criterion'] == 'cosine':
|
||||||
|
return self.criterion(logits_fake, logits_real, 1)
|
||||||
|
else:
|
||||||
|
return self.criterion(logits_fake, logits_real)
|
||||||
|
|
||||||
|
|
||||||
# Special form of feature loss which first computes the feature embedding for the truth space, then uses a second
|
# Special form of feature loss which first computes the feature embedding for the truth space, then uses a second
|
||||||
|
|
Loading…
Reference in New Issue
Block a user