Fix loss indexing
This commit is contained in:
parent
202eb11fdc
commit
0011d445c8
|
@ -413,5 +413,5 @@ class ForElementLoss(ConfigurableLoss):
|
||||||
def forward(self, net, state):
|
def forward(self, net, state):
|
||||||
st = state.copy()
|
st = state.copy()
|
||||||
st['_real'] = state[self.opt['real']][:, self.index]
|
st['_real'] = state[self.opt['real']][:, self.index]
|
||||||
st['_fake'] = state[self.opt['fake']][:, self.index]
|
st['_fake'] = state[self.opt['fake']][self.index]
|
||||||
return self.loss(net, st)
|
return self.loss(net, st)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user