Allow recurrent losses to be weighted

This commit is contained in:
James Betker 2020-10-21 16:59:44 -06:00
parent 5753e77d67
commit 931aa65dd0

View File

@ -392,6 +392,10 @@ class RecurrentLoss(ConfigurableLoss):
o['fake'] = '_fake' o['fake'] = '_fake'
o['real'] = '_real' o['real'] = '_real'
self.loss = create_loss(o, self.env) self.loss = create_loss(o, self.env)
# Use this option to specify a differential weighting scheme for losses inside of the recurrent construct. For
# example, if later recurrent outputs should contribute more to the loss than earlier ones. When specified,
# must be a list of weights that exactly aligns with the recurrent list fed to forward().
self.recurrent_weights = opt['recurrent_weights'] if 'recurrent_weights' in opt.keys() else 1
def forward(self, net, state): def forward(self, net, state):
total_loss = 0 total_loss = 0
@ -400,7 +404,10 @@ class RecurrentLoss(ConfigurableLoss):
for i in range(real.shape[1]): for i in range(real.shape[1]):
st['_real'] = real[:, i] st['_real'] = real[:, i]
st['_fake'] = state[self.opt['fake']][:, i] st['_fake'] = state[self.opt['fake']][:, i]
total_loss += self.loss(net, st) subloss = self.loss(net, st)
if isinstance(self.recurrent_weights, list);
subloss = subloss * self.recurrent_weights[i]
total_loss += subloss
return total_loss return total_loss
def extra_metrics(self): def extra_metrics(self):