From 696242064c85f18d5fe2da754013cdd1f264b139 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 3 Sep 2020 11:33:36 -0600 Subject: [PATCH] Use tensor checkpointing to drastically reduce memory usage This comes at the expense of computation, but since we can use much larger batches, it results in a net speedup. --- .../archs/SwitchedResidualGenerator_arch.py | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index 7bcf73ca..aa0d740a 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -254,18 +254,14 @@ class ConfigurableSwitchComputer(nn.Module): x = x1 + rand_feature if self.pre_transform: - if isinstance(x, tuple): - x = self.pre_transform(*x) - else: - x = self.pre_transform(x) - if isinstance(x, tuple): - xformed = [t.forward(*x) for t in self.transforms] - else: - xformed = [t.forward(x) for t in self.transforms] - if isinstance(att_in, tuple): - m = self.multiplexer(*att_in) - else: - m = self.multiplexer(att_in) + x = self.pre_transform(*x) + if not isinstance(x, tuple): + x = (x,) + xformed = [torch.utils.checkpoint.checkpoint(t, *x) for t in self.transforms] + + if not isinstance(att_in, tuple): + att_in = (att_in,) + m = torch.utils.checkpoint.checkpoint(self.multiplexer, *att_in) # It is assumed that [xformed] and [m] are collapsed into tensors at this point. outputs, attention = self.switch(xformed, m, True)