Use tensor checkpointing to drastically reduce memory usage

This comes at the expense of extra computation, but since the freed memory
allows much larger batches, it results in a net speedup.
James Betker 2020-09-03 11:33:36 -06:00
parent 365813bde3
commit 696242064c
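
As a rough sketch of the mechanism this commit relies on: torch.utils.checkpoint.checkpoint runs the wrapped callable without storing its intermediate activations during the forward pass and recomputes them during backward. The module, tensor sizes, and names below are illustrative placeholders, not code from this repository.

    import torch
    import torch.nn as nn
    import torch.utils.checkpoint

    # Placeholder module standing in for one of the switch's transforms.
    block = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))
    x = torch.randn(64, 512, requires_grad=True)

    # Plain call: every intermediate activation inside `block` is kept
    # alive until backward() runs.
    y = block(x)

    # Checkpointed call: only the input and output are stored; the
    # intermediates are recomputed during backward(), trading extra compute
    # for a much smaller activation footprint (and thus larger batches).
    y = torch.utils.checkpoint.checkpoint(block, x)
    y.sum().backward()

Note that the recomputation happens as part of autograd, so the savings only apply to tensors that participate in the backward pass.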


@@ -254,18 +254,14 @@ class ConfigurableSwitchComputer(nn.Module):
         x = x1 + rand_feature
         if self.pre_transform:
-            if isinstance(x, tuple):
-                x = self.pre_transform(*x)
-            else:
-                x = self.pre_transform(x)
-        if isinstance(x, tuple):
-            xformed = [t.forward(*x) for t in self.transforms]
-        else:
-            xformed = [t.forward(x) for t in self.transforms]
-        if isinstance(att_in, tuple):
-            m = self.multiplexer(*att_in)
-        else:
-            m = self.multiplexer(att_in)
+            x = self.pre_transform(*x)
+        if not isinstance(x, tuple):
+            x = (x,)
+        xformed = [torch.utils.checkpoint.checkpoint(t, *x) for t in self.transforms]
+
+        if not isinstance(att_in, tuple):
+            att_in = (att_in,)
+        m = torch.utils.checkpoint.checkpoint(self.multiplexer, *att_in)
         # It is assumed that [xformed] and [m] are collapsed into tensors at this point.
         outputs, attention = self.switch(xformed, m, True)
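
One reading of the new call pattern (an editor's note, with placeholder names): checkpoint expects the callable followed by its positional tensor arguments, so the inputs are normalized to tuples once and splatted, which replaces the earlier isinstance branching.

    import torch.utils.checkpoint

    # Hypothetical helper illustrating the pattern used above; `fn` and
    # `args` are placeholder names, not identifiers from this repository.
    def run_checkpointed(fn, args):
        if not isinstance(args, tuple):
            args = (args,)  # wrap a bare tensor so it can be splatted
        return torch.utils.checkpoint.checkpoint(fn, *args)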