Allow checkpointing to be disabled in the options file
Also makes options a global variable for usage in utils.
This commit is contained in:
parent
dd9d7b27ac
commit
19a4075e1e
|
@ -5,7 +5,7 @@ import torch.nn.functional as F
|
|||
import models.archs.arch_util as arch_util
|
||||
from models.archs.arch_util import PixelUnshuffle
|
||||
import torchvision
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from utils.util import checkpoint
|
||||
|
||||
|
||||
class ResidualDenseBlock_5C(nn.Module):
|
||||
|
|
|
@ -6,7 +6,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from utils.util import checkpoint
|
||||
|
||||
from models.archs import SPSR_util as B
|
||||
from models.archs.SwitchedResidualGenerator_arch import ConfigurableSwitchComputer, ReferenceImageBranch, \
|
||||
|
|
|
@ -10,7 +10,7 @@ from switched_conv_util import save_attention_to_image_rgb
|
|||
from switched_conv import compute_attention_specificity
|
||||
import os
|
||||
import torchvision
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from utils.util import checkpoint
|
||||
|
||||
# VGG-style layer with Conv(stride2)->BN->Activation->Conv->BN->Activation
|
||||
# Doubles the input filter count.
|
||||
|
|
|
@ -7,11 +7,11 @@ from collections import OrderedDict
|
|||
from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu, MultiConvBlock, SiLU
|
||||
from switched_conv_util import save_attention_to_image_rgb
|
||||
import os
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from utils.util import checkpoint
|
||||
from models.archs.spinenet_arch import SpineNet
|
||||
|
||||
|
||||
# Set to true to relieve memory pressure by using torch.utils.checkpoint in several memory-critical locations.
|
||||
# Set to true to relieve memory pressure by using utils.util in several memory-critical locations.
|
||||
memory_checkpointing_enabled = True
|
||||
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ import math
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from utils.util import checkpoint
|
||||
|
||||
from torch.autograd import Variable
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import torch.nn
|
||||
from models.archs.SPSR_arch import ImageGradientNoPadding
|
||||
from data.weight_scheduler import get_scheduler_for_opt
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from utils.util import checkpoint
|
||||
import torchvision.utils as utils
|
||||
#from models.steps.recursive_gen_injectors import ImageFlowInjector
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import yaml
|
|||
from utils.util import OrderedYaml
|
||||
Loader, Dumper = OrderedYaml()
|
||||
|
||||
loaded_options = None
|
||||
|
||||
def parse(opt_path, is_train=True):
|
||||
with open(opt_path, mode='r') as f:
|
||||
|
|
|
@ -126,6 +126,9 @@ def main():
|
|||
# torch.backends.cudnn.deterministic = True
|
||||
# torch.autograd.set_detect_anomaly(True)
|
||||
|
||||
# Save the compiled opt dict to the global loaded_options variable.
|
||||
option.loaded_options = opt
|
||||
|
||||
#### create train and val dataloader
|
||||
dataset_ratio = 1 # enlarge the size of each epoch
|
||||
for phase, dataset_opt in opt['datasets'].items():
|
||||
|
|
|
@ -14,6 +14,8 @@ from torchvision.utils import make_grid
|
|||
from shutil import get_terminal_size
|
||||
import scp
|
||||
import paramiko
|
||||
import options.options as options
|
||||
from utils.util import checkpoint
|
||||
|
||||
import yaml
|
||||
try:
|
||||
|
@ -41,6 +43,13 @@ def OrderedYaml():
|
|||
# miscellaneous
|
||||
####################
|
||||
|
||||
# Conditionally uses torch's checkpoint functionality if it is enabled in the opt file.
|
||||
def checkpoint(fn, *args):
|
||||
enabled = options.loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in options.loaded_options.keys() else True
|
||||
if enabled:
|
||||
return checkpoint(fn, *args)
|
||||
else:
|
||||
return fn(*args)
|
||||
|
||||
def get_timestamp():
|
||||
return datetime.now().strftime('%y%m%d-%H%M%S')
|
||||
|
|
Loading…
Reference in New Issue
Block a user