forked from mrq/DL-Art-School
Allow recurrent losses to be weighted
This commit is contained in:
parent
5753e77d67
commit
931aa65dd0
|
@ -392,6 +392,10 @@ class RecurrentLoss(ConfigurableLoss):
|
||||||
o['fake'] = '_fake'
|
o['fake'] = '_fake'
|
||||||
o['real'] = '_real'
|
o['real'] = '_real'
|
||||||
self.loss = create_loss(o, self.env)
|
self.loss = create_loss(o, self.env)
|
||||||
|
# Use this option to specify a differential weighting scheme for losses inside of the recurrent construct. For
|
||||||
|
# example, if later recurrent outputs should contribute more to the loss than earlier ones. When specified,
|
||||||
|
# must be a list of weights that exactly aligns with the recurrent list fed to forward().
|
||||||
|
self.recurrent_weights = opt['recurrent_weights'] if 'recurrent_weights' in opt.keys() else 1
|
||||||
|
|
||||||
def forward(self, net, state):
|
def forward(self, net, state):
|
||||||
total_loss = 0
|
total_loss = 0
|
||||||
|
@ -400,7 +404,10 @@ class RecurrentLoss(ConfigurableLoss):
|
||||||
for i in range(real.shape[1]):
|
for i in range(real.shape[1]):
|
||||||
st['_real'] = real[:, i]
|
st['_real'] = real[:, i]
|
||||||
st['_fake'] = state[self.opt['fake']][:, i]
|
st['_fake'] = state[self.opt['fake']][:, i]
|
||||||
total_loss += self.loss(net, st)
|
subloss = self.loss(net, st)
|
||||||
|
if isinstance(self.recurrent_weights, list);
|
||||||
|
subloss = subloss * self.recurrent_weights[i]
|
||||||
|
total_loss += subloss
|
||||||
return total_loss
|
return total_loss
|
||||||
|
|
||||||
def extra_metrics(self):
|
def extra_metrics(self):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user