diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index 34106d1a..ce12a315 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -138,9 +138,9 @@ class ConfigurableSwitchComputer(nn.Module): x = self.pre_transform(x) xformed = [t.forward(x) for t in self.transforms] - m = self.multiplexer(identity) + outputs, attention = self.switch(xformed, m, True) outputs = identity + outputs * self.switch_scale * fixed_scale outputs = outputs + self.post_switch_conv(outputs) * self.psc_scale * fixed_scale @@ -186,6 +186,10 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module): assert self.upsample_factor == 2 or self.upsample_factor == 4 def forward(self, x): + # This is a common bug when evaluating SRG2 generators. It needs to be configured properly in eval mode. Just fail. + if not self.train: + assert self.switches[0].switch.temperature == 1 + x = self.initial_conv(x) self.attentions = []