From d43f25cc208a26dc4b70a7f069bde4ada437c5bf Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 8 Nov 2021 20:10:07 -0700 Subject: [PATCH] Update losses --- codes/trainer/losses.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/codes/trainer/losses.py b/codes/trainer/losses.py index e8567812..980262c0 100644 --- a/codes/trainer/losses.py +++ b/codes/trainer/losses.py @@ -26,6 +26,8 @@ def create_loss(opt_loss, env): return LightweightGanDivergenceLoss(opt_loss, env) elif type == 'crossentropy': return CrossEntropy(opt_loss, env) + elif type == 'distillation': + return Distillation(opt_loss, env) elif type == 'pix': return PixLoss(opt_loss, env) elif type == 'sr_pix': @@ -122,7 +124,7 @@ class CrossEntropy(ConfigurableLoss): super().__init__(opt, env) self.opt = opt self.subtype = opt_get(opt, ['subtype'], 'ce') - if self.subtype == 'ce': + if self.subtype == 'ce' or self.subtype == 'soft_ce': self.ce = nn.CrossEntropyLoss() elif self.subtype == 'bce': self.ce = nn.BCEWithLogitsLoss() @@ -144,13 +146,33 @@ class CrossEntropy(ConfigurableLoss): if self.subtype == 'bce': logits = logits.reshape(-1, 1) labels = labels.reshape(-1, 1) - else: + elif self.subtype == 'ce': logits = logits.view(-1, logits.size(-1)) labels = labels.view(-1) assert labels.max()+1 <= logits.shape[-1] + elif self.subtype == 'soft_ce': + labels = F.softmax(labels, dim=1) + return F.cross_entropy(logits, labels) return self.ce(logits, labels) +class Distillation(ConfigurableLoss): + def __init__(self, opt, env): + super().__init__(opt, env) + self.opt = opt + self.teacher = opt['teacher'] + self.student = opt['student'] + self.loss = nn.KLDivLoss(reduction='batchmean') + self.temperature = opt_get(opt, ['temperature'], 1.0) + + def forward(self, _, state): + # Current assumption is that both logits are of shape [b,C,d], b=batch,C=class_logits,d=sequence_len + teacher = state[self.teacher].permute(0,2,1) + student = state[self.student].permute(0,2,1) + + return self.loss(input=F.log_softmax(student/self.temperature, dim=-1), target=F.softmax(teacher/self.temperature, dim=-1)) + + class PixLoss(ConfigurableLoss): def __init__(self, opt, env): super(PixLoss, self).__init__(opt, env)