Fix loss indexing
This commit is contained in:
parent
202eb11fdc
commit
0011d445c8
|
@ -413,5 +413,5 @@ class ForElementLoss(ConfigurableLoss):
|
|||
def forward(self, net, state):
|
||||
st = state.copy()
|
||||
st['_real'] = state[self.opt['real']][:, self.index]
|
||||
st['_fake'] = state[self.opt['fake']][:, self.index]
|
||||
st['_fake'] = state[self.opt['fake']][self.index]
|
||||
return self.loss(net, st)
|
||||
|
|
Loading…
Reference in New Issue
Block a user