diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index 9543ed8f..e2f60783 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -413,5 +413,5 @@ class ForElementLoss(ConfigurableLoss): def forward(self, net, state): st = state.copy() st['_real'] = state[self.opt['real']][:, self.index] - st['_fake'] = state[self.opt['fake']][:, self.index] + st['_fake'] = state[self.opt['fake']][self.index] return self.loss(net, st)