diff --git a/codes/models/archs/StructuredSwitchedGenerator.py b/codes/models/archs/StructuredSwitchedGenerator.py index bea9035d..5b9857c5 100644 --- a/codes/models/archs/StructuredSwitchedGenerator.py +++ b/codes/models/archs/StructuredSwitchedGenerator.py @@ -281,7 +281,7 @@ class StackedSwitchGenerator(nn.Module): x_out = checkpoint(self.final_hr_conv2, x_out) self.attentions = [a1, a3, a3] - return x_out + return x_out, def set_temperature(self, temp): [sw.set_temperature(temp) for sw in self.switches]