Use tensor checkpointing to drastically reduce memory usage
This comes at the expense of computation, but since we can use much larger batches, it results in a net speedup.
This commit is contained in:
parent
365813bde3
commit
696242064c
|
@ -254,18 +254,14 @@ class ConfigurableSwitchComputer(nn.Module):
|
|||
x = x1 + rand_feature
|
||||
|
||||
if self.pre_transform:
|
||||
if isinstance(x, tuple):
|
||||
x = self.pre_transform(*x)
|
||||
else:
|
||||
x = self.pre_transform(x)
|
||||
if isinstance(x, tuple):
|
||||
xformed = [t.forward(*x) for t in self.transforms]
|
||||
else:
|
||||
xformed = [t.forward(x) for t in self.transforms]
|
||||
if isinstance(att_in, tuple):
|
||||
m = self.multiplexer(*att_in)
|
||||
else:
|
||||
m = self.multiplexer(att_in)
|
||||
x = self.pre_transform(*x)
|
||||
if not isinstance(x, tuple):
|
||||
x = (x,)
|
||||
xformed = [torch.utils.checkpoint.checkpoint(t, *x) for t in self.transforms]
|
||||
|
||||
if not isinstance(att_in, tuple):
|
||||
att_in = (att_in,)
|
||||
m = torch.utils.checkpoint.checkpoint(self.multiplexer, *att_in)
|
||||
|
||||
# It is assumed that [xformed] and [m] are collapsed into tensors at this point.
|
||||
outputs, attention = self.switch(xformed, m, True)
|
||||
|
|
Loading…
Reference in New Issue
Block a user