diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py
index d70a2ee6..c5419a78 100644
--- a/codes/models/steps/losses.py
+++ b/codes/models/steps/losses.py
@@ -392,6 +392,10 @@ class RecurrentLoss(ConfigurableLoss):
         o['fake'] = '_fake'
         o['real'] = '_real'
         self.loss = create_loss(o, self.env)
+        # Use this option to specify a differential weighting scheme for losses inside of the recurrent construct,
+        # for example when later recurrent outputs should contribute more to the loss than earlier ones. When
+        # specified, this must be a list of weights that exactly aligns with the recurrent list fed to forward().
+        self.recurrent_weights = opt['recurrent_weights'] if 'recurrent_weights' in opt.keys() else 1

     def forward(self, net, state):
         total_loss = 0
@@ -400,7 +404,10 @@ class RecurrentLoss(ConfigurableLoss):
         for i in range(real.shape[1]):
             st['_real'] = real[:, i]
             st['_fake'] = state[self.opt['fake']][:, i]
-            total_loss += self.loss(net, st)
+            subloss = self.loss(net, st)
+            if isinstance(self.recurrent_weights, list):
+                subloss = subloss * self.recurrent_weights[i]
+            total_loss += subloss
         return total_loss

     def extra_metrics(self):
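
For context, here is a minimal standalone sketch of the weighting behaviour this diff adds. The opt dict, the per-step loss values, and the example weight list are hypothetical; in the real class the per-step loss comes from create_loss and operates on network state, this only mirrors the weighting logic.

    import torch

    # Hypothetical per-step losses from a recurrent construct with 4 time steps.
    per_step_losses = [torch.tensor(1.0), torch.tensor(0.8), torch.tensor(0.6), torch.tensor(0.5)]

    # Hypothetical config: later steps contribute more. The list length must align
    # with the recurrent dimension fed to forward(), as the diff's comment notes.
    opt = {'recurrent_weights': [0.25, 0.5, 0.75, 1.0]}
    recurrent_weights = opt['recurrent_weights'] if 'recurrent_weights' in opt.keys() else 1

    total_loss = 0
    for i, subloss in enumerate(per_step_losses):
        if isinstance(recurrent_weights, list):
            subloss = subloss * recurrent_weights[i]
        total_loss += subloss

    print(total_loss)  # weighted sum; with the default weight of 1 this reduces to a plain sum

Keeping the default as the scalar 1 (rather than a list of ones) means the isinstance check skips the indexing entirely, so existing configs without recurrent_weights behave exactly as before.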