diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index 7bcf73ca..aa0d740a 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -254,18 +254,14 @@ class ConfigurableSwitchComputer(nn.Module): x = x1 + rand_feature if self.pre_transform: - if isinstance(x, tuple): - x = self.pre_transform(*x) - else: - x = self.pre_transform(x) - if isinstance(x, tuple): - xformed = [t.forward(*x) for t in self.transforms] - else: - xformed = [t.forward(x) for t in self.transforms] - if isinstance(att_in, tuple): - m = self.multiplexer(*att_in) - else: - m = self.multiplexer(att_in) + x = self.pre_transform(*x) + if not isinstance(x, tuple): + x = (x,) + xformed = [torch.utils.checkpoint.checkpoint(t, *x) for t in self.transforms] + + if not isinstance(att_in, tuple): + att_in = (att_in,) + m = torch.utils.checkpoint.checkpoint(self.multiplexer, *att_in) # It is assumed that [xformed] and [m] are collapsed into tensors at this point. outputs, attention = self.switch(xformed, m, True)