forked from mrq/DL-Art-School
Allow recurrent losses to be weighted
This commit is contained in:
parent
5753e77d67
commit
931aa65dd0
|
@ -392,6 +392,10 @@ class RecurrentLoss(ConfigurableLoss):
|
|||
o['fake'] = '_fake'
|
||||
o['real'] = '_real'
|
||||
self.loss = create_loss(o, self.env)
|
||||
# Use this option to specify a differential weighting scheme for losses inside of the recurrent construct. For
|
||||
# example, if later recurrent outputs should contribute more to the loss than earlier ones. When specified,
|
||||
# must be a list of weights that exactly aligns with the recurrent list fed to forward().
|
||||
self.recurrent_weights = opt['recurrent_weights'] if 'recurrent_weights' in opt.keys() else 1
|
||||
|
||||
def forward(self, net, state):
|
||||
total_loss = 0
|
||||
|
@ -400,7 +404,10 @@ class RecurrentLoss(ConfigurableLoss):
|
|||
for i in range(real.shape[1]):
|
||||
st['_real'] = real[:, i]
|
||||
st['_fake'] = state[self.opt['fake']][:, i]
|
||||
total_loss += self.loss(net, st)
|
||||
subloss = self.loss(net, st)
|
||||
if isinstance(self.recurrent_weights, list);
|
||||
subloss = subloss * self.recurrent_weights[i]
|
||||
total_loss += subloss
|
||||
return total_loss
|
||||
|
||||
def extra_metrics(self):
|
||||
|
|
Loading…
Reference in New Issue
Block a user