forked from mrq/DL-Art-School
Allow checkpointing to be disabled in the options file
Also makes options a global variable for use in utils.
parent dd9d7b27ac
commit 19a4075e1e
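The bulk of the diff mechanically swaps the direct torch.utils.checkpoint import for the new wrapper in utils.util across the architecture modules; call sites keep the same checkpoint(fn, *args) signature. A minimal caller-side sketch, assuming the repo is on PYTHONPATH (TinyBlock and its shapes are illustrative, not part of the commit):

    import torch
    import torch.nn as nn
    from utils.util import checkpoint  # the new wrapper, not torch's

    class TinyBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 3, padding=1)

        def forward(self, x):
            # Checkpointed when 'checkpointing_enabled' is unset or true in the
            # loaded options; a plain self.conv(x) call when it is false.
            return checkpoint(self.conv, x)

    x = torch.randn(1, 3, 16, 16, requires_grad=True)
    print(TinyBlock()(x).shape)  # torch.Size([1, 3, 16, 16])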
@@ -5,7 +5,7 @@ import torch.nn.functional as F
 import models.archs.arch_util as arch_util
 from models.archs.arch_util import PixelUnshuffle
 import torchvision
-from torch.utils.checkpoint import checkpoint
+from utils.util import checkpoint
 
 
 class ResidualDenseBlock_5C(nn.Module):
@@ -6,7 +6,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision
-from torch.utils.checkpoint import checkpoint
+from utils.util import checkpoint
 
 from models.archs import SPSR_util as B
 from models.archs.SwitchedResidualGenerator_arch import ConfigurableSwitchComputer, ReferenceImageBranch, \
@@ -10,7 +10,7 @@ from switched_conv_util import save_attention_to_image_rgb
 from switched_conv import compute_attention_specificity
 import os
 import torchvision
-from torch.utils.checkpoint import checkpoint
+from utils.util import checkpoint
 
 # VGG-style layer with Conv(stride2)->BN->Activation->Conv->BN->Activation
 # Doubles the input filter count.
@@ -7,11 +7,11 @@ from collections import OrderedDict
 from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu, MultiConvBlock, SiLU
 from switched_conv_util import save_attention_to_image_rgb
 import os
-from torch.utils.checkpoint import checkpoint
+from utils.util import checkpoint
 from models.archs.spinenet_arch import SpineNet
 
 
-# Set to true to relieve memory pressure by using torch.utils.checkpoint in several memory-critical locations.
+# Set to true to relieve memory pressure by using utils.util in several memory-critical locations.
 memory_checkpointing_enabled = True
 
 
@@ -4,7 +4,7 @@ import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.checkpoint import checkpoint
+from utils.util import checkpoint
 
 from torch.autograd import Variable
 
@@ -1,7 +1,7 @@
 import torch.nn
 from models.archs.SPSR_arch import ImageGradientNoPadding
 from data.weight_scheduler import get_scheduler_for_opt
-from torch.utils.checkpoint import checkpoint
+from utils.util import checkpoint
 import torchvision.utils as utils
 #from models.steps.recursive_gen_injectors import ImageFlowInjector
 
@@ -5,6 +5,7 @@ import yaml
 from utils.util import OrderedYaml
 Loader, Dumper = OrderedYaml()
 
+loaded_options = None
 
 def parse(opt_path, is_train=True):
     with open(opt_path, mode='r') as f:
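The new module-level loaded_options global lets utility code see the parsed opt dict without threading it through every call signature. A small sketch of the intended contract (the dict contents here are illustrative):

    import options.options as option

    # train.py assigns the parsed opt dict once (next hunk); afterwards any
    # module that imports options.options can read it.
    option.loaded_options = {'name': 'example_run', 'checkpointing_enabled': False}

    def reads_options_elsewhere():
        opts = option.loaded_options
        return opts.get('checkpointing_enabled', True) if opts is not None else True

    print(reads_options_elsewhere())  # False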
@@ -126,6 +126,9 @@ def main():
     # torch.backends.cudnn.deterministic = True
     # torch.autograd.set_detect_anomaly(True)
 
+    # Save the compiled opt dict to the global loaded_options variable.
+    option.loaded_options = opt
+
     #### create train and val dataloader
     dataset_ratio = 1  # enlarge the size of each epoch
     for phase, dataset_opt in opt['datasets'].items():
@@ -14,6 +14,8 @@ from torchvision.utils import make_grid
 from shutil import get_terminal_size
 import scp
 import paramiko
+import options.options as options
+from torch.utils.checkpoint import checkpoint as torch_checkpoint
 
 import yaml
 try:
@@ -41,6 +43,13 @@ def OrderedYaml():
 # miscellaneous
 ####################
 
+# Conditionally uses torch's checkpoint functionality if it is enabled in the opt file.
+def checkpoint(fn, *args):
+    enabled = options.loaded_options['checkpointing_enabled'] if options.loaded_options is not None and 'checkpointing_enabled' in options.loaded_options.keys() else True
+    if enabled:
+        return torch_checkpoint(fn, *args)
+    else:
+        return fn(*args)
 
 def get_timestamp():
     return datetime.now().strftime('%y%m%d-%H%M%S')
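End to end, a top-level checkpointing_enabled: false entry in the YAML options file now routes every checkpoint(fn, *args) call straight to fn(*args), while omitting the key keeps the old checkpointing behavior. A runnable sketch of both paths, assuming torch's checkpoint is imported under an alias as in the hunk above so the wrapper does not shadow it:

    import torch
    import options.options as options
    from utils.util import checkpoint

    def double(t):
        return t * 2

    t = torch.ones(2, requires_grad=True)

    options.loaded_options = {'checkpointing_enabled': False}
    print(checkpoint(double, t))  # plain call: no activation checkpointing

    options.loaded_options = {'checkpointing_enabled': True}
    print(checkpoint(double, t))  # goes through torch.utils.checkpoint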