From 106b8da31523bf706c8ed40def76c2c05f31ec78 Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 22 Jul 2020 20:50:59 -0600 Subject: [PATCH] Assert that temperature is set properly in eval mode. --- codes/models/archs/SwitchedResidualGenerator_arch.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index 34106d1a..ce12a315 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -138,9 +138,9 @@ class ConfigurableSwitchComputer(nn.Module): x = self.pre_transform(x) xformed = [t.forward(x) for t in self.transforms] - m = self.multiplexer(identity) + outputs, attention = self.switch(xformed, m, True) outputs = identity + outputs * self.switch_scale * fixed_scale outputs = outputs + self.post_switch_conv(outputs) * self.psc_scale * fixed_scale @@ -186,6 +186,10 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module): assert self.upsample_factor == 2 or self.upsample_factor == 4 def forward(self, x): + # This is a common bug when evaluating SRG2 generators. It needs to be configured properly in eval mode. Just fail. + if not self.train: + assert self.switches[0].switch.temperature == 1 + x = self.initial_conv(x) self.attentions = []