Allow SRG checkpointing to be toggled

This commit is contained in:
James Betker 2020-09-08 15:14:43 -06:00
parent e6207d4c50
commit dffbfd2ec4

View File

@ -7,6 +7,11 @@ from collections import OrderedDict
from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu
from switched_conv_util import save_attention_to_image_rgb from switched_conv_util import save_attention_to_image_rgb
import os import os
from torch.utils.checkpoint import checkpoint
# Set to true to relieve memory pressure by using torch.utils.checkpoint in several memory-critical locations.
memory_checkpointing_enabled = False
# VGG-style layer with Conv(stride2)->BN->Activation->Conv->BN->Activation # VGG-style layer with Conv(stride2)->BN->Activation->Conv->BN->Activation
@ -219,11 +224,17 @@ class ConfigurableSwitchComputer(nn.Module):
x = self.pre_transform(*x) x = self.pre_transform(*x)
if not isinstance(x, tuple): if not isinstance(x, tuple):
x = (x,) x = (x,)
xformed = [torch.utils.checkpoint.checkpoint(t, *x) for t in self.transforms] if memory_checkpointing_enabled:
xformed = [checkpoint(t, *x) for t in self.transforms]
else:
xformed = [t(*x) for t in self.transforms]
if not isinstance(att_in, tuple): if not isinstance(att_in, tuple):
att_in = (att_in,) att_in = (att_in,)
m = torch.utils.checkpoint.checkpoint(self.multiplexer, *att_in) if memory_checkpointing_enabled:
m = checkpoint(self.multiplexer, *att_in)
else:
m = self.multiplexer(*att_in)
# It is assumed that [xformed] and [m] are collapsed into tensors at this point. # It is assumed that [xformed] and [m] are collapsed into tensors at this point.
outputs, attention = self.switch(xformed, m, True) outputs, attention = self.switch(xformed, m, True)