Update losses

James Betker 2021-11-08 20:10:07 -07:00
parent 9b3c3b1227
commit d43f25cc20


@@ -26,6 +26,8 @@ def create_loss(opt_loss, env):
         return LightweightGanDivergenceLoss(opt_loss, env)
     elif type == 'crossentropy':
         return CrossEntropy(opt_loss, env)
+    elif type == 'distillation':
+        return Distillation(opt_loss, env)
     elif type == 'pix':
         return PixLoss(opt_loss, env)
     elif type == 'sr_pix':
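For reference, a hypothetical loss entry that would reach the new 'distillation' branch of create_loss. The 'type', 'teacher', 'student', and 'temperature' keys mirror what the Distillation class reads from its options below; where exactly such an entry sits in the trainer config, and the names of the teacher/student outputs, are assumptions.

# Hypothetical loss options dict; names of the state entries are made up.
distillation_loss_opt = {
    'type': 'distillation',
    'teacher': 'teacher_logits',   # assumed key of the teacher output in the trainer state
    'student': 'student_logits',   # assumed key of the student output in the trainer state
    'temperature': 4.0,            # optional; opt_get falls back to 1.0
}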
@@ -122,7 +124,7 @@ class CrossEntropy(ConfigurableLoss):
         super().__init__(opt, env)
         self.opt = opt
         self.subtype = opt_get(opt, ['subtype'], 'ce')
-        if self.subtype == 'ce':
+        if self.subtype == 'ce' or self.subtype == 'soft_ce':
             self.ce = nn.CrossEntropyLoss()
         elif self.subtype == 'bce':
             self.ce = nn.BCEWithLogitsLoss()
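The new 'soft_ce' subtype reuses the same nn.CrossEntropyLoss module but feeds it probability targets rather than class indices, which CrossEntropyLoss accepts in PyTorch 1.10 and later. A minimal standalone sketch of that path, with made-up shapes:

import torch
import torch.nn as nn
import torch.nn.functional as F

ce = nn.CrossEntropyLoss()
logits = torch.randn(8, 10)        # [batch, classes] predictions
label_logits = torch.randn(8, 10)  # [batch, classes] unnormalized soft labels

# Mirrors the 'soft_ce' branch in the forward hunk below: soften the labels
# into a distribution, then pass class-probability targets to CrossEntropyLoss.
soft_targets = F.softmax(label_logits, dim=1)
loss = ce(logits, soft_targets)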
@@ -144,13 +146,33 @@ class CrossEntropy(ConfigurableLoss):
         if self.subtype == 'bce':
             logits = logits.reshape(-1, 1)
             labels = labels.reshape(-1, 1)
-        else:
+        elif self.subtype == 'ce':
             logits = logits.view(-1, logits.size(-1))
             labels = labels.view(-1)
             assert labels.max()+1 <= logits.shape[-1]
-        return F.cross_entropy(logits, labels)
+        elif self.subtype == 'soft_ce':
+            labels = F.softmax(labels, dim=1)
+        return self.ce(logits, labels)
 
 
+class Distillation(ConfigurableLoss):
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        self.opt = opt
+        self.teacher = opt['teacher']
+        self.student = opt['student']
+        self.loss = nn.KLDivLoss(reduction='batchmean')
+        self.temperature = opt_get(opt, ['temperature'], 1.0)
+
+    def forward(self, _, state):
+        # Current assumption is that both logits are of shape [b,C,d], b=batch, C=class_logits, d=sequence_len
+        teacher = state[self.teacher].permute(0,2,1)
+        student = state[self.student].permute(0,2,1)
+        return self.loss(input=F.log_softmax(student/self.temperature, dim=-1), target=F.softmax(teacher/self.temperature, dim=-1))
+
+
 class PixLoss(ConfigurableLoss):
     def __init__(self, opt, env):
         super(PixLoss, self).__init__(opt, env)
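A self-contained sketch of what the new Distillation loss computes, under the [b,C,d] shape assumption stated in its forward comment; the sizes and tensor names here are made up. The permute moves the class dimension last so both softmaxes normalize over classes, and KLDivLoss expects log-probabilities for the student input and probabilities for the teacher target.

import torch
import torch.nn.functional as F

b, C, d = 4, 512, 100              # batch, class logits, sequence length (assumed sizes)
temperature = 4.0                  # would come from opt_get(opt, ['temperature'], 1.0)

student_logits = torch.randn(b, C, d)
teacher_logits = torch.randn(b, C, d)

# Same steps as Distillation.forward, without the ConfigurableLoss plumbing.
student = student_logits.permute(0, 2, 1)   # [b, d, C]
teacher = teacher_logits.permute(0, 2, 1)   # [b, d, C]
kl = torch.nn.KLDivLoss(reduction='batchmean')
loss = kl(input=F.log_softmax(student / temperature, dim=-1),
          target=F.softmax(teacher / temperature, dim=-1))

Note that reduction='batchmean' divides the summed divergence by the leading batch dimension only, so with these shapes the sequence length still scales the loss value.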