Allow checkpointing to be disabled in the options file

Also makes the parsed options available as a module-level global so they can be used from utils.
James Betker 2020-10-03 11:03:28 -06:00
parent dd9d7b27ac
commit 19a4075e1e
9 changed files with 20 additions and 7 deletions
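In brief, the change threads the parsed opt dict into utils.util through a new module-level global so a checkpoint() wrapper can honor a checkpointing_enabled flag. A minimal sketch of the flow, assuming the flag sits at the top level of the opt file (the YAML path below is hypothetical):

import options.options as option

# Hypothetical opt file entry (YAML), assumed to live at the top level:
#   checkpointing_enabled: false
opt = option.parse('options/train/train_example.yml', is_train=True)  # hypothetical path
option.loaded_options = opt  # mirrors the new assignment in train.py
# With the flag false, utils.util.checkpoint(fn, *args) simply calls fn(*args);
# otherwise it defers to torch.utils.checkpoint.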

View File

@@ -5,7 +5,7 @@ import torch.nn.functional as F
import models.archs.arch_util as arch_util
from models.archs.arch_util import PixelUnshuffle
import torchvision
from torch.utils.checkpoint import checkpoint
from utils.util import checkpoint
class ResidualDenseBlock_5C(nn.Module):

View File

@@ -6,7 +6,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.checkpoint import checkpoint
from utils.util import checkpoint
from models.archs import SPSR_util as B
from models.archs.SwitchedResidualGenerator_arch import ConfigurableSwitchComputer, ReferenceImageBranch, \

View File

@@ -10,7 +10,7 @@ from switched_conv_util import save_attention_to_image_rgb
from switched_conv import compute_attention_specificity
import os
import torchvision
from torch.utils.checkpoint import checkpoint
from utils.util import checkpoint
# VGG-style layer with Conv(stride2)->BN->Activation->Conv->BN->Activation
# Doubles the input filter count.
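A hedged sketch of the block the comment above describes; the class name, channel counts, and choice of LeakyReLU are placeholders, not taken from this file:

import torch.nn as nn

class VGGStyleDownBlock(nn.Module):
    # Conv(stride 2) -> BN -> Activation -> Conv -> BN -> Activation, doubling the filter count.
    def __init__(self, in_filters):
        super().__init__()
        self.conv1 = nn.Conv2d(in_filters, in_filters * 2, 3, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(in_filters * 2)
        self.conv2 = nn.Conv2d(in_filters * 2, in_filters * 2, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(in_filters * 2)
        self.act = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        x = self.act(self.bn1(self.conv1(x)))
        return self.act(self.bn2(self.conv2(x)))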

View File

@@ -7,11 +7,11 @@ from collections import OrderedDict
from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu, MultiConvBlock, SiLU
from switched_conv_util import save_attention_to_image_rgb
import os
from torch.utils.checkpoint import checkpoint
from utils.util import checkpoint
from models.archs.spinenet_arch import SpineNet
# Set to true to relieve memory pressure by using torch.utils.checkpoint in several memory-critical locations.
# Set to true to relieve memory pressure by using the checkpoint() wrapper from utils.util in several memory-critical locations.
memory_checkpointing_enabled = True
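For context, the pattern this flag gates looks roughly as follows; HeavySubBlock and SwitchComputerSketch are placeholders rather than classes from this file, and the checkpoint call now resolves to the utils.util wrapper:

import torch
import torch.nn as nn
from utils.util import checkpoint

memory_checkpointing_enabled = True  # module-level switch, as above

class HeavySubBlock(nn.Module):  # stand-in for a memory-critical sub-module
    def __init__(self, nf=64):
        super().__init__()
        self.conv = nn.Conv2d(nf, nf, 3, padding=1)

    def forward(self, x):
        return torch.relu(self.conv(x))

class SwitchComputerSketch(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = HeavySubBlock()

    def forward(self, x):
        if memory_checkpointing_enabled:
            # Activations inside self.block are recomputed during backward instead of stored.
            return checkpoint(self.block, x)
        return self.block(x)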

View File

@@ -4,7 +4,7 @@ import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from utils.util import checkpoint
from torch.autograd import Variable

View File

@@ -1,7 +1,7 @@
import torch.nn
from models.archs.SPSR_arch import ImageGradientNoPadding
from data.weight_scheduler import get_scheduler_for_opt
from torch.utils.checkpoint import checkpoint
from utils.util import checkpoint
import torchvision.utils as utils
#from models.steps.recursive_gen_injectors import ImageFlowInjector

View File

@@ -5,6 +5,7 @@ import yaml
from utils.util import OrderedYaml
Loader, Dumper = OrderedYaml()
loaded_options = None
def parse(opt_path, is_train=True):
    with open(opt_path, mode='r') as f:

View File

@@ -126,6 +126,9 @@ def main():
    # torch.backends.cudnn.deterministic = True
    # torch.autograd.set_detect_anomaly(True)
    # Save the compiled opt dict to the global loaded_options variable.
    option.loaded_options = opt
    #### create train and val dataloader
    dataset_ratio = 1  # enlarge the size of each epoch
    for phase, dataset_opt in opt['datasets'].items():
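One consequence of the new global, sketched below: any utility can inspect the compiled opt once main() has executed the assignment above. The helper name and the 'name' key are assumptions for illustration, not taken from this diff:

import options.options as option

def describe_experiment():
    # Hypothetical helper; relies on train.py having populated the global already.
    if option.loaded_options is None:
        return 'options not loaded yet'
    return option.loaded_options.get('name', 'unnamed experiment')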

View File

@@ -14,6 +14,8 @@ from torchvision.utils import make_grid
from shutil import get_terminal_size
import scp
import paramiko
import options.options as options
from torch.utils.checkpoint import checkpoint as torch_checkpoint
import yaml
try:
@@ -41,6 +43,13 @@ def OrderedYaml():
# miscellaneous
####################
# Conditionally applies torch's activation checkpointing; it can be turned off by setting
# 'checkpointing_enabled' to false in the opt file (checkpointing defaults to enabled).
def checkpoint(fn, *args):
    opt = options.loaded_options
    # Default to checkpointing when no options have been loaded or the key is absent.
    enabled = opt.get('checkpointing_enabled', True) if opt is not None else True
    if enabled:
        return torch_checkpoint(fn, *args)
    else:
        return fn(*args)
def get_timestamp():
    return datetime.now().strftime('%y%m%d-%H%M%S')
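Finally, a small usage sketch of the wrapper's two paths; the toy layer, tensors, and flag values are illustrative only:

import torch
import options.options as options
from utils.util import checkpoint

layer = torch.nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)

options.loaded_options = {'checkpointing_enabled': False}
y = checkpoint(layer, x)   # plain call; activations stored as usual

options.loaded_options = {'checkpointing_enabled': True}
y = checkpoint(layer, x)   # routed through torch.utils.checkpoint; recomputed on backward
y.sum().backward()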