Use tensor checkpointing to drastically reduce memory usage

This comes at the expense of computation, but since we can use much larger
batches, it results in a net speedup.
James Betker 2020-09-03 11:33:36 -06:00
parent 365813bde3
commit 696242064c
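For readers unfamiliar with the technique: torch.utils.checkpoint.checkpoint runs a callable without storing its intermediate activations and recomputes them during the backward pass, which is what buys the memory headroom the commit message describes. A minimal sketch of that behavior, using a stand-in nn.Sequential block rather than anything from this repository:

import torch
import torch.nn as nn
import torch.utils.checkpoint

# Stand-in block; in the commit the checkpointed callables are the switch's
# transforms and multiplexer.
block = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))

x = torch.randn(8, 64, requires_grad=True)

# The forward pass discards the block's intermediate activations; they are
# recomputed when backward() reaches this segment, trading compute for memory.
y = torch.utils.checkpoint.checkpoint(block, x)
y.sum().backward()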


@@ -254,18 +254,14 @@ class ConfigurableSwitchComputer(nn.Module):
         x = x1 + rand_feature
         if self.pre_transform:
-            if isinstance(x, tuple):
-                x = self.pre_transform(*x)
-            else:
-                x = self.pre_transform(x)
-        if isinstance(x, tuple):
-            xformed = [t.forward(*x) for t in self.transforms]
-        else:
-            xformed = [t.forward(x) for t in self.transforms]
-        if isinstance(att_in, tuple):
-            m = self.multiplexer(*att_in)
-        else:
-            m = self.multiplexer(att_in)
+            x = self.pre_transform(*x)
+        if not isinstance(x, tuple):
+            x = (x,)
+        xformed = [torch.utils.checkpoint.checkpoint(t, *x) for t in self.transforms]
+        if not isinstance(att_in, tuple):
+            att_in = (att_in,)
+        m = torch.utils.checkpoint.checkpoint(self.multiplexer, *att_in)
         # It is assumed that [xformed] and [m] are collapsed into tensors at this point.
         outputs, attention = self.switch(xformed, m, True)
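A note on the shape of the new code: checkpoint() forwards its remaining positional arguments to the wrapped callable, so the diff normalizes x and att_in to tuples and splats them with *, replacing the old isinstance branching. A hypothetical helper (run_checkpointed is not part of the repository) that captures the same pattern:

import torch.utils.checkpoint

def run_checkpointed(fn, args):
    # Accept either a single tensor or a tuple of tensors, mirroring the
    # isinstance checks the old inline code used.
    if not isinstance(args, tuple):
        args = (args,)
    return torch.utils.checkpoint.checkpoint(fn, *args)

Checkpointing each transform and the multiplexer separately means only their inputs and outputs stay resident during the forward pass; everything inside them is recomputed on backward, which is where the savings come from when the switch holds many transforms.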