Fix loss indexing

This commit is contained in:
James Betker 2020-10-09 20:20:51 -06:00
parent 202eb11fdc
commit 0011d445c8

View File

@ -413,5 +413,5 @@ class ForElementLoss(ConfigurableLoss):
def forward(self, net, state):
st = state.copy()
st['_real'] = state[self.opt['real']][:, self.index]
st['_fake'] = state[self.opt['fake']][:, self.index]
st['_fake'] = state[self.opt['fake']][self.index]
return self.loss(net, st)