Use tensor checkpointing to drastically reduce memory usage

This comes at the expense of computation, but since we can use much larger
batches, it results in a net speedup.
James Betker 2020-09-03 11:33:36 -06:00
parent 365813bde3
commit 696242064c


@@ -254,18 +254,14 @@ class ConfigurableSwitchComputer(nn.Module):
             x = x1 + rand_feature
         if self.pre_transform:
-            if isinstance(x, tuple):
-                x = self.pre_transform(*x)
-            else:
-                x = self.pre_transform(x)
-        if isinstance(x, tuple):
-            xformed = [t.forward(*x) for t in self.transforms]
-        else:
-            xformed = [t.forward(x) for t in self.transforms]
-        if isinstance(att_in, tuple):
-            m = self.multiplexer(*att_in)
-        else:
-            m = self.multiplexer(att_in)
+            x = self.pre_transform(*x)
+        if not isinstance(x, tuple):
+            x = (x,)
+        xformed = [torch.utils.checkpoint.checkpoint(t, *x) for t in self.transforms]
+        if not isinstance(att_in, tuple):
+            att_in = (att_in,)
+        m = torch.utils.checkpoint.checkpoint(self.multiplexer, *att_in)
         # It is assumed that [xformed] and [m] are collapsed into tensors at this point.
         outputs, attention = self.switch(xformed, m, True)
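
For context, a minimal self-contained sketch of the pattern this diff applies: wrapping sub-module calls in torch.utils.checkpoint.checkpoint so their intermediate activations are recomputed during backward instead of being stored. The CheckpointedStack module, layer sizes, and input shape below are illustrative placeholders, not code from this repository.

# Sketch of activation checkpointing: checkpointed blocks do not keep their
# intermediate activations in the forward pass; they re-run forward during
# backward, so peak memory drops while compute for those blocks roughly doubles.
import torch
import torch.nn as nn
import torch.utils.checkpoint

class CheckpointedStack(nn.Module):
    def __init__(self, channels=64, depth=4):
        super().__init__()
        self.blocks = nn.ModuleList([
            nn.Sequential(nn.Conv2d(channels, channels, 3, padding=1), nn.ReLU())
            for _ in range(depth)])

    def forward(self, x):
        for block in self.blocks:
            # Only the block's input is saved; its internals are recomputed
            # when gradients are needed.
            x = torch.utils.checkpoint.checkpoint(block, x)
        return x

net = CheckpointedStack()
inp = torch.randn(8, 64, 32, 32, requires_grad=True)
out = net(inp)
out.mean().backward()  # each block's forward is re-executed here

Because only block inputs are retained, activation memory grows much more slowly with depth, which is what permits the larger batch sizes the commit message refers to; the cost is one extra forward pass per checkpointed block during backward.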