Update losses
This commit is contained in:
parent
9b3c3b1227
commit
d43f25cc20
|
@ -26,6 +26,8 @@ def create_loss(opt_loss, env):
|
|||
return LightweightGanDivergenceLoss(opt_loss, env)
|
||||
elif type == 'crossentropy':
|
||||
return CrossEntropy(opt_loss, env)
|
||||
elif type == 'distillation':
|
||||
return Distillation(opt_loss, env)
|
||||
elif type == 'pix':
|
||||
return PixLoss(opt_loss, env)
|
||||
elif type == 'sr_pix':
|
||||
|
@ -122,7 +124,7 @@ class CrossEntropy(ConfigurableLoss):
|
|||
super().__init__(opt, env)
|
||||
self.opt = opt
|
||||
self.subtype = opt_get(opt, ['subtype'], 'ce')
|
||||
if self.subtype == 'ce':
|
||||
if self.subtype == 'ce' or self.subtype == 'soft_ce':
|
||||
self.ce = nn.CrossEntropyLoss()
|
||||
elif self.subtype == 'bce':
|
||||
self.ce = nn.BCEWithLogitsLoss()
|
||||
|
@ -144,13 +146,33 @@ class CrossEntropy(ConfigurableLoss):
|
|||
if self.subtype == 'bce':
|
||||
logits = logits.reshape(-1, 1)
|
||||
labels = labels.reshape(-1, 1)
|
||||
else:
|
||||
elif self.subtype == 'ce':
|
||||
logits = logits.view(-1, logits.size(-1))
|
||||
labels = labels.view(-1)
|
||||
assert labels.max()+1 <= logits.shape[-1]
|
||||
elif self.subtype == 'soft_ce':
|
||||
labels = F.softmax(labels, dim=1)
|
||||
return F.cross_entropy(logits, labels)
|
||||
return self.ce(logits, labels)
|
||||
|
||||
|
||||
class Distillation(ConfigurableLoss):
|
||||
def __init__(self, opt, env):
|
||||
super().__init__(opt, env)
|
||||
self.opt = opt
|
||||
self.teacher = opt['teacher']
|
||||
self.student = opt['student']
|
||||
self.loss = nn.KLDivLoss(reduction='batchmean')
|
||||
self.temperature = opt_get(opt, ['temperature'], 1.0)
|
||||
|
||||
def forward(self, _, state):
|
||||
# Current assumption is that both logits are of shape [b,C,d], b=batch,C=class_logits,d=sequence_len
|
||||
teacher = state[self.teacher].permute(0,2,1)
|
||||
student = state[self.student].permute(0,2,1)
|
||||
|
||||
return self.loss(input=F.log_softmax(student/self.temperature, dim=-1), target=F.softmax(teacher/self.temperature, dim=-1))
|
||||
|
||||
|
||||
class PixLoss(ConfigurableLoss):
|
||||
def __init__(self, opt, env):
|
||||
super(PixLoss, self).__init__(opt, env)
|
||||
|
|
Loading…
Reference in New Issue
Block a user