diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index b79fd55f..a753e407 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -75,6 +75,7 @@ def gather_2d(input, index):
     return result
 
 
+from utils.util import checkpoint
 class ConfigurableSwitchComputer(nn.Module):
     def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, attention_norm,
                  init_temp=20, add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=False):
@@ -131,13 +132,13 @@ class ConfigurableSwitchComputer(nn.Module):
             x = self.pre_transform(*x)
         if not isinstance(x, tuple):
             x = (x,)
-        xformed = [t(*x) for t in self.transforms]
+        xformed = [checkpoint(t, *x) for t in self.transforms]
 
         if not isinstance(att_in, tuple):
             att_in = (att_in,)
         if self.feed_transforms_into_multiplexer:
             att_in = att_in + (torch.stack(xformed, dim=1),)
-        m = self.multiplexer(*att_in)
+        m = checkpoint(self.multiplexer, *att_in)
 
         # It is assumed that [xformed] and [m] are collapsed into tensors at this point.
         outputs, attention = self.switch(xformed, m, True, self.update_norm)
@@ -603,4 +604,4 @@ if __name__ == '__main__':
     trans = [torch.randn(4,64,64,64) for t in range(10)]
 
     b = bb(x, r, cp)
-    emb(xu, b, trans)
\ No newline at end of file
+    emb(xu, b, trans)
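
Note on the change: this wraps each transform branch and the multiplexer call in gradient checkpointing, so their intermediate activations are freed after the forward pass and recomputed during backward, trading extra compute for a lower peak-memory footprint. That matters here because every transform in `self.transforms` runs on the full feature map, so activation memory scales with the transform count. The body of `utils.util.checkpoint` is not part of this diff; the following is a minimal, hypothetical sketch of what such a helper typically looks like, assuming it delegates to `torch.utils.checkpoint`:

```python
# Hypothetical sketch of a helper like utils.util.checkpoint
# (the real implementation is not shown in this diff).
import torch
import torch.utils.checkpoint

def checkpoint(fn, *args):
    # Only checkpoint when autograd is active and at least one input
    # carries gradients; otherwise the recomputation machinery is pure
    # overhead and torch.utils.checkpoint.checkpoint emits a warning.
    if torch.is_grad_enabled() and any(
            isinstance(a, torch.Tensor) and a.requires_grad for a in args):
        return torch.utils.checkpoint.checkpoint(fn, *args)
    return fn(*args)
```

With a wrapper of this shape, `checkpoint(t, *x)` returns the same result as `t(*x)`; at inference time it is a plain call, and during training it only changes the memory/compute trade-off, not the computed values.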