diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index 85fcf63a..1928f39f 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -316,7 +316,7 @@ class TranslationInvarianceLoss(ConfigurableLoss): trans_output = net(*input) else: trans_output = net(*input) - if self.gen_output_to_use: + if self.gen_output_to_use is not None: fake_shared_output = trans_output[self.gen_output_to_use][:, :, hl:hh, wl:wh] else: fake_shared_output = trans_output[:, :, hl:hh, wl:wh]