Allow SRG checkpointing to be toggled
parent e6207d4c50
commit dffbfd2ec4
@@ -7,6 +7,11 @@ from collections import OrderedDict
 from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu
 from switched_conv_util import save_attention_to_image_rgb
 import os
+from torch.utils.checkpoint import checkpoint
+
+
+# Set to true to relieve memory pressure by using torch.utils.checkpoint in several memory-critical locations.
+memory_checkpointing_enabled = False
 
 
 # VGG-style layer with Conv(stride2)->BN->Activation->Conv->BN->Activation
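For context, a minimal standalone sketch (not code from this repository) of the trade-off the new memory_checkpointing_enabled flag controls: torch.utils.checkpoint drops intermediate activations during the forward pass and recomputes them during backward, lowering peak memory at the cost of extra compute. The module and tensor shapes below are illustrative only.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Illustrative toggle mirroring the module-level flag added above.
memory_checkpointing_enabled = True

# Any callable submodule can be checkpointed; a small conv block stands in
# for the SRG transforms here.
block = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(64, 3, 3, padding=1),
)
x = torch.randn(1, 3, 32, 32, requires_grad=True)

if memory_checkpointing_enabled:
    # Activations inside `block` are not stored; they are recomputed when
    # backward() reaches this segment.
    y = checkpoint(block, x)
else:
    y = block(x)

y.mean().backward()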
@@ -219,11 +224,17 @@ class ConfigurableSwitchComputer(nn.Module):
             x = self.pre_transform(*x)
         if not isinstance(x, tuple):
             x = (x,)
-        xformed = [torch.utils.checkpoint.checkpoint(t, *x) for t in self.transforms]
+        if memory_checkpointing_enabled:
+            xformed = [checkpoint(t, *x) for t in self.transforms]
+        else:
+            xformed = [t(*x) for t in self.transforms]
 
         if not isinstance(att_in, tuple):
             att_in = (att_in,)
-        m = torch.utils.checkpoint.checkpoint(self.multiplexer, *att_in)
+        if memory_checkpointing_enabled:
+            m = checkpoint(self.multiplexer, *att_in)
+        else:
+            m = self.multiplexer(*att_in)
 
         # It is assumed that [xformed] and [m] are collapsed into tensors at this point.
         outputs, attention = self.switch(xformed, m, True)
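Because the toggle is a module-level global rather than a constructor argument, callers flip it on the defining module before the forward pass runs. A hedged usage sketch follows; the import path is a placeholder, since the file name is not shown in this diff.

# Placeholder path: substitute the module that actually defines
# memory_checkpointing_enabled / ConfigurableSwitchComputer.
import models.archs.SwitchedResidualGenerator_arch as srg_arch

srg_arch.memory_checkpointing_enabled = True  # trade compute for lower peak VRAM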