From dffbfd2ec40bfb4c38c7ffd7ee3129f6bd227781 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Tue, 8 Sep 2020 15:14:43 -0600
Subject: [PATCH] Allow SRG checkpointing to be toggled

---
 .../archs/SwitchedResidualGenerator_arch.py   | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index 0832bb0e..e6f07544 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -7,6 +7,11 @@ from collections import OrderedDict
 from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu
 from switched_conv_util import save_attention_to_image_rgb
 import os
+from torch.utils.checkpoint import checkpoint
+
+
+# Set to true to relieve memory pressure by using torch.utils.checkpoint in several memory-critical locations.
+memory_checkpointing_enabled = False
 
 
 # VGG-style layer with Conv(stride2)->BN->Activation->Conv->BN->Activation
@@ -219,11 +224,17 @@ class ConfigurableSwitchComputer(nn.Module):
             x = self.pre_transform(*x)
         if not isinstance(x, tuple):
             x = (x,)
-        xformed = [torch.utils.checkpoint.checkpoint(t, *x) for t in self.transforms]
+        if memory_checkpointing_enabled:
+            xformed = [checkpoint(t, *x) for t in self.transforms]
+        else:
+            xformed = [t(*x) for t in self.transforms]
 
         if not isinstance(att_in, tuple):
             att_in = (att_in,)
-        m = torch.utils.checkpoint.checkpoint(self.multiplexer, *att_in)
+        if memory_checkpointing_enabled:
+            m = checkpoint(self.multiplexer, *att_in)
+        else:
+            m = self.multiplexer(*att_in)
 
         # It is assumed that [xformed] and [m] are collapsed into tensors at this point.
         outputs, attention = self.switch(xformed, m, True)
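
Note on the change: the patch gates gradient checkpointing of the transforms and the
multiplexer behind a module-level flag instead of always checkpointing them. The sketch
below is a minimal, self-contained illustration of the same toggle pattern with
torch.utils.checkpoint; ToyBlock and use_checkpointing are illustrative names and are
not part of this repository.

# Minimal sketch of the toggle pattern used in the patch: wrap a sub-module call
# in torch.utils.checkpoint.checkpoint only when a flag is set, trading extra
# forward-pass compute during backward for lower activation memory.
# ToyBlock and use_checkpointing are illustrative names, not from the repo.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    def __init__(self, channels=16, use_checkpointing=False):
        super().__init__()
        self.use_checkpointing = use_checkpointing
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3, padding=1),
        )

    def forward(self, x):
        if self.use_checkpointing:
            # Activations inside self.body are recomputed during backward,
            # so they do not need to be kept alive after the forward pass.
            return checkpoint(self.body, x)
        return self.body(x)


if __name__ == "__main__":
    block = ToyBlock(use_checkpointing=True)
    # The input must require grad; otherwise the checkpointed segment has
    # nothing to backpropagate through when it is recomputed.
    inp = torch.randn(2, 16, 32, 32, requires_grad=True)
    out = block(inp)
    out.mean().backward()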