checkpoint in ssg
This commit is contained in:
parent
05377973bf
commit
731700ab2c
|
@ -75,6 +75,7 @@ def gather_2d(input, index):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
from utils.util import checkpoint
|
||||||
class ConfigurableSwitchComputer(nn.Module):
|
class ConfigurableSwitchComputer(nn.Module):
|
||||||
def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, attention_norm,
|
def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, attention_norm,
|
||||||
init_temp=20, add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=False):
|
init_temp=20, add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=False):
|
||||||
|
@ -131,13 +132,13 @@ class ConfigurableSwitchComputer(nn.Module):
|
||||||
x = self.pre_transform(*x)
|
x = self.pre_transform(*x)
|
||||||
if not isinstance(x, tuple):
|
if not isinstance(x, tuple):
|
||||||
x = (x,)
|
x = (x,)
|
||||||
xformed = [t(*x) for t in self.transforms]
|
xformed = [checkpoint(t, *x) for t in self.transforms]
|
||||||
|
|
||||||
if not isinstance(att_in, tuple):
|
if not isinstance(att_in, tuple):
|
||||||
att_in = (att_in,)
|
att_in = (att_in,)
|
||||||
if self.feed_transforms_into_multiplexer:
|
if self.feed_transforms_into_multiplexer:
|
||||||
att_in = att_in + (torch.stack(xformed, dim=1),)
|
att_in = att_in + (torch.stack(xformed, dim=1),)
|
||||||
m = self.multiplexer(*att_in)
|
m = checkpoint(self.multiplexer, *att_in)
|
||||||
|
|
||||||
# It is assumed that [xformed] and [m] are collapsed into tensors at this point.
|
# It is assumed that [xformed] and [m] are collapsed into tensors at this point.
|
||||||
outputs, attention = self.switch(xformed, m, True, self.update_norm)
|
outputs, attention = self.switch(xformed, m, True, self.update_norm)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user