checkpoint in ssg

This commit is contained in:
James Betker 2020-10-12 17:43:28 -06:00
parent 05377973bf
commit 731700ab2c

View File

@@ -75,6 +75,7 @@ def gather_2d(input, index):
return result return result
from utils.util import checkpoint
class ConfigurableSwitchComputer(nn.Module): class ConfigurableSwitchComputer(nn.Module):
def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, attention_norm, def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, attention_norm,
init_temp=20, add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=False): init_temp=20, add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=False):
@@ -131,13 +132,13 @@ class ConfigurableSwitchComputer(nn.Module):
x = self.pre_transform(*x) x = self.pre_transform(*x)
if not isinstance(x, tuple): if not isinstance(x, tuple):
x = (x,) x = (x,)
xformed = [t(*x) for t in self.transforms] xformed = [checkpoint(t, *x) for t in self.transforms]
if not isinstance(att_in, tuple): if not isinstance(att_in, tuple):
att_in = (att_in,) att_in = (att_in,)
if self.feed_transforms_into_multiplexer: if self.feed_transforms_into_multiplexer:
att_in = att_in + (torch.stack(xformed, dim=1),) att_in = att_in + (torch.stack(xformed, dim=1),)
m = self.multiplexer(*att_in) m = checkpoint(self.multiplexer, *att_in)
# It is assumed that [xformed] and [m] are collapsed into tensors at this point. # It is assumed that [xformed] and [m] are collapsed into tensors at this point.
outputs, attention = self.switch(xformed, m, True, self.update_norm) outputs, attention = self.switch(xformed, m, True, self.update_norm)