forked from mrq/DL-Art-School
Use tensor checkpointing to drastically reduce memory usage
This comes at the expense of extra computation, but since it lets us use much larger batches, it results in a net speedup.
This commit is contained in:
parent 365813bde3
commit 696242064c
@@ -254,18 +254,14 @@ class ConfigurableSwitchComputer(nn.Module):
         x = x1 + rand_feature
 
         if self.pre_transform:
-            if isinstance(x, tuple):
-                x = self.pre_transform(*x)
-            else:
-                x = self.pre_transform(x)
-        if isinstance(x, tuple):
-            xformed = [t.forward(*x) for t in self.transforms]
-        else:
-            xformed = [t.forward(x) for t in self.transforms]
-        if isinstance(att_in, tuple):
-            m = self.multiplexer(*att_in)
-        else:
-            m = self.multiplexer(att_in)
+            x = self.pre_transform(*x)
+        if not isinstance(x, tuple):
+            x = (x,)
+        xformed = [torch.utils.checkpoint.checkpoint(t, *x) for t in self.transforms]
+
+        if not isinstance(att_in, tuple):
+            att_in = (att_in,)
+        m = torch.utils.checkpoint.checkpoint(self.multiplexer, *att_in)
 
         # It is assumed that [xformed] and [m] are collapsed into tensors at this point.
         outputs, attention = self.switch(xformed, m, True)
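For context, torch.utils.checkpoint.checkpoint(fn, *args) runs fn without storing its intermediate activations and recomputes them during the backward pass, trading compute for memory. It takes its arguments positionally, which is why the diff normalizes x and att_in into tuples and splats them into the call. Below is a minimal, self-contained sketch of the same pattern; the module and tensor shapes are illustrative stand-ins, not taken from this repository.

import torch
import torch.nn as nn
import torch.utils.checkpoint

# Stand-in for one entry of self.transforms.
transform = nn.Sequential(nn.Conv2d(16, 16, 3, padding=1), nn.ReLU())

x = torch.randn(8, 16, 64, 64, requires_grad=True)

# Plain call: every intermediate activation inside `transform` is kept
# alive until the backward pass.
y_plain = transform(x)

# Checkpointed call: intermediates are dropped after the forward pass and
# recomputed on the fly during backward, so only the inputs and outputs of
# the checkpointed segment stay resident in memory.
y_ckpt = torch.utils.checkpoint.checkpoint(transform, x)

assert torch.allclose(y_plain, y_ckpt)  # same values either way
y_ckpt.sum().backward()                 # backward re-runs the forward internally

The memory freed this way is what makes room for the much larger batch sizes the commit message refers to.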