checkpoint in ssg
This commit is contained in:
parent 05377973bf
commit 731700ab2c
@@ -75,6 +75,7 @@ def gather_2d(input, index):
     return result
 
+from utils.util import checkpoint
 
 class ConfigurableSwitchComputer(nn.Module):
     def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, attention_norm,
                  init_temp=20, add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=False):
@@ -131,13 +132,13 @@ class ConfigurableSwitchComputer(nn.Module):
             x = self.pre_transform(*x)
         if not isinstance(x, tuple):
             x = (x,)
-        xformed = [t(*x) for t in self.transforms]
+        xformed = [checkpoint(t, *x) for t in self.transforms]
 
         if not isinstance(att_in, tuple):
             att_in = (att_in,)
         if self.feed_transforms_into_multiplexer:
             att_in = att_in + (torch.stack(xformed, dim=1),)
-        m = self.multiplexer(*att_in)
+        m = checkpoint(self.multiplexer, *att_in)
 
         # It is assumed that [xformed] and [m] are collapsed into tensors at this point.
         outputs, attention = self.switch(xformed, m, True, self.update_norm)
@@ -603,4 +604,4 @@ if __name__ == '__main__':
     trans = [torch.randn(4,64,64,64) for t in range(10)]
 
     b = bb(x, r, cp)
-    emb(xu, b, trans)
+    emb(xu, b, trans)
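The change wraps each transform call and the multiplexer call in checkpoint from utils.util, i.e. activation checkpointing: intermediate activations inside those sub-modules are recomputed during the backward pass instead of being stored, trading extra compute for lower memory use. The helper itself is not part of this diff; below is a minimal sketch of what such a wrapper commonly looks like, assuming it delegates to torch.utils.checkpoint and falls back to a plain call when no input requires grad. The fallback logic and names are assumptions for illustration, not the repository's actual implementation.

# Hypothetical sketch of a checkpoint helper like the one imported from utils.util.
# Assumption: it delegates to torch.utils.checkpoint.checkpoint during training and
# calls the function directly when no tensor argument requires grad (checkpointing
# gives no benefit there and can emit warnings).
import torch
import torch.utils.checkpoint as cp

def checkpoint(fn, *args):
    # Checkpoint only when at least one tensor argument requires grad;
    # otherwise run the function normally.
    needs_grad = any(isinstance(a, torch.Tensor) and a.requires_grad for a in args)
    if needs_grad:
        return cp.checkpoint(fn, *args)
    return fn(*args)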