diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index 9e7403bd..7ac786c8 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -175,6 +175,7 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module): self.heightened_final_step = heightened_final_step self.attentions = None self.upsample_factor = upsample_factor + assert self.upsample_factor == 2 or self.upsample_factor == 4 def forward(self, x): x = self.initial_conv(x) @@ -186,7 +187,6 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module): self.attentions.append(att) x = swx + self.sw_conv(x) - assert x == 2 or x == 4 x = self.upconv1(F.interpolate(x, scale_factor=2, mode="nearest")) if self.upsample_factor > 2: x = F.interpolate(x, scale_factor=2, mode="nearest")