diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index f6878aab..aa0d740a 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -254,7 +254,7 @@ class ConfigurableSwitchComputer(nn.Module): x = x1 + rand_feature if self.pre_transform: - x = self.pre_transform(x) + x = self.pre_transform(*x) if not isinstance(x, tuple): x = (x,) xformed = [torch.utils.checkpoint.checkpoint(t, *x) for t in self.transforms]