checkpoint in ssg

This commit is contained in:
James Betker 2020-10-12 17:43:28 -06:00
parent 05377973bf
commit 731700ab2c

View File

@ -75,6 +75,7 @@ def gather_2d(input, index):
return result
from utils.util import checkpoint
class ConfigurableSwitchComputer(nn.Module):
def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, attention_norm,
init_temp=20, add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=False):
@ -131,13 +132,13 @@ class ConfigurableSwitchComputer(nn.Module):
x = self.pre_transform(*x)
if not isinstance(x, tuple):
x = (x,)
xformed = [t(*x) for t in self.transforms]
xformed = [checkpoint(t, *x) for t in self.transforms]
if not isinstance(att_in, tuple):
att_in = (att_in,)
if self.feed_transforms_into_multiplexer:
att_in = att_in + (torch.stack(xformed, dim=1),)
m = self.multiplexer(*att_in)
m = checkpoint(self.multiplexer, *att_in)
# It is assumed that [xformed] and [m] are collapsed into tensors at this point.
outputs, attention = self.switch(xformed, m, True, self.update_norm)
@ -603,4 +604,4 @@ if __name__ == '__main__':
trans = [torch.randn(4,64,64,64) for t in range(10)]
b = bb(x, r, cp)
emb(xu, b, trans)
emb(xu, b, trans)