diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py index 65bcc569..0133f68b 100644 --- a/codes/models/archs/RRDBNet_arch.py +++ b/codes/models/archs/RRDBNet_arch.py @@ -5,7 +5,7 @@ import torch.nn.functional as F import models.archs.arch_util as arch_util from models.archs.arch_util import PixelUnshuffle import torchvision -from torch.utils.checkpoint import checkpoint +from utils.util import checkpoint class ResidualDenseBlock_5C(nn.Module): diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py index ac6db451..052bb75d 100644 --- a/codes/models/archs/SPSR_arch.py +++ b/codes/models/archs/SPSR_arch.py @@ -6,7 +6,7 @@ import torch import torch.nn as nn import torch.nn.functional as F import torchvision -from torch.utils.checkpoint import checkpoint +from utils.util import checkpoint from models.archs import SPSR_util as B from models.archs.SwitchedResidualGenerator_arch import ConfigurableSwitchComputer, ReferenceImageBranch, \ diff --git a/codes/models/archs/StructuredSwitchedGenerator.py b/codes/models/archs/StructuredSwitchedGenerator.py index bd41fb70..d92e6fd7 100644 --- a/codes/models/archs/StructuredSwitchedGenerator.py +++ b/codes/models/archs/StructuredSwitchedGenerator.py @@ -10,7 +10,7 @@ from switched_conv_util import save_attention_to_image_rgb from switched_conv import compute_attention_specificity import os import torchvision -from torch.utils.checkpoint import checkpoint +from utils.util import checkpoint # VGG-style layer with Conv(stride2)->BN->Activation->Conv->BN->Activation # Doubles the input filter count. diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index 2f42b2e1..7b6edfc8 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -7,11 +7,11 @@ from collections import OrderedDict from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu, MultiConvBlock, SiLU from switched_conv_util import save_attention_to_image_rgb import os -from torch.utils.checkpoint import checkpoint +from utils.util import checkpoint from models.archs.spinenet_arch import SpineNet -# Set to true to relieve memory pressure by using torch.utils.checkpoint in several memory-critical locations. +# Set to true to relieve memory pressure by using utils.util in several memory-critical locations. memory_checkpointing_enabled = True diff --git a/codes/models/archs/rcan.py b/codes/models/archs/rcan.py index 71d6955c..684727db 100644 --- a/codes/models/archs/rcan.py +++ b/codes/models/archs/rcan.py @@ -4,7 +4,7 @@ import math import torch import torch.nn as nn import torch.nn.functional as F -from torch.utils.checkpoint import checkpoint +from utils.util import checkpoint from torch.autograd import Variable diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py index 421d44cc..f54a81f6 100644 --- a/codes/models/steps/injectors.py +++ b/codes/models/steps/injectors.py @@ -1,7 +1,7 @@ import torch.nn from models.archs.SPSR_arch import ImageGradientNoPadding from data.weight_scheduler import get_scheduler_for_opt -from torch.utils.checkpoint import checkpoint +from utils.util import checkpoint import torchvision.utils as utils #from models.steps.recursive_gen_injectors import ImageFlowInjector diff --git a/codes/options/options.py b/codes/options/options.py index 297426f9..be42de2f 100644 --- a/codes/options/options.py +++ b/codes/options/options.py @@ -5,6 +5,7 @@ import yaml from utils.util import OrderedYaml Loader, Dumper = OrderedYaml() +loaded_options = None def parse(opt_path, is_train=True): with open(opt_path, mode='r') as f: diff --git a/codes/train.py b/codes/train.py index eef9f237..af709a3f 100644 --- a/codes/train.py +++ b/codes/train.py @@ -126,6 +126,9 @@ def main(): # torch.backends.cudnn.deterministic = True # torch.autograd.set_detect_anomaly(True) + # Save the compiled opt dict to the global loaded_options variable. + option.loaded_options = opt + #### create train and val dataloader dataset_ratio = 1 # enlarge the size of each epoch for phase, dataset_opt in opt['datasets'].items(): diff --git a/codes/utils/util.py b/codes/utils/util.py index c663e5d9..2ae60bb2 100644 --- a/codes/utils/util.py +++ b/codes/utils/util.py @@ -14,6 +14,8 @@ from torchvision.utils import make_grid from shutil import get_terminal_size import scp import paramiko +import options.options as options +from utils.util import checkpoint import yaml try: @@ -41,6 +43,13 @@ def OrderedYaml(): # miscellaneous #################### +# Conditionally uses torch's checkpoint functionality if it is enabled in the opt file. +def checkpoint(fn, *args): + enabled = options.loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in options.loaded_options.keys() else True + if enabled: + return checkpoint(fn, *args) + else: + return fn(*args) def get_timestamp(): return datetime.now().strftime('%y%m%d-%H%M%S')