diff --git a/codes/models/archs/teco_resgen.py b/codes/models/archs/teco_resgen.py
new file mode 100644
index 00000000..85d46502
--- /dev/null
+++ b/codes/models/archs/teco_resgen.py
@@ -0,0 +1,73 @@
+import torch
+import torch.nn as nn
+from utils.util import sequential_checkpoint
+from models.archs.arch_util import ConvGnSilu, make_layer
+
+
+class TecoResblock(nn.Module):
+    def __init__(self, nf):
+        super().__init__()
+        self.nf = nf
+        self.conv1 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False, weight_init_factor=.1)
+        self.conv2 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False, weight_init_factor=.1)
+
+    def forward(self, x):
+        identity = x
+        x = self.conv1(x)
+        x = self.conv2(x)
+        return identity + x
+
+
+class TecoUpconv(nn.Module):
+    def __init__(self, nf, scale):
+        super().__init__()
+        self.nf = nf
+        self.scale = scale
+        self.conv1 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
+        self.conv2 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
+        self.conv3 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
+        self.final_conv = ConvGnSilu(nf, 3, kernel_size=1, norm=False, activation=False, bias=False)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = nn.functional.interpolate(x, scale_factor=self.scale, mode="nearest")
+        x = self.conv3(x)
+        # No residual connection here: the block's input is nf channels at LR resolution
+        # while the output is 3 channels at HR resolution, so the shapes cannot match.
+        return self.final_conv(x)
+
+
+# Extremely simple resnet-based generator that is very similar to the one used in the TecoGAN paper.
+# Main differences:
+# - Uses SiLU instead of ReLU.
+# - Reference input is in HR space (which just makes more sense).
+# - Doesn't use transposed convolutions - just uses interpolation instead.
+# - Upsample block is slightly more complicated.
+class TecoGen(nn.Module):
+    def __init__(self, nf, scale):
+        super().__init__()
+        self.nf = nf
+        self.scale = scale
+        # The LR input and HR reference are concatenated channel-wise (3+3=6 channels),
+        # then strided back down to LR resolution for the residual trunk.
+        fea_conv = ConvGnSilu(6, nf, kernel_size=7, stride=self.scale, bias=True, norm=False, activation=True)
+        res_layers = [TecoResblock(nf) for i in range(15)]
+        upsample = TecoUpconv(nf, scale)
+        everything = [fea_conv] + res_layers + [upsample]
+        self.core = nn.Sequential(*everything)
+
+    def forward(self, x, ref=None):
+        x = nn.functional.interpolate(x, scale_factor=self.scale, mode="bicubic")
+        if ref is None:
+            ref = torch.zeros_like(x)
+        join = torch.cat([x, ref], dim=1)
+        return sequential_checkpoint(self.core, 6, join)
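+
+
+# A minimal usage sketch (illustrative only, not exercised anywhere in this change);
+# assumes a 4x generator with 64 filters and a batch of eight 3-channel LR frames:
+#   gen = TecoGen(nf=64, scale=4)
+#   lr = torch.randn(8, 3, 32, 32)
+#   sr = gen(lr)  # no reference frame given, so a zero tensor is substituted
+#   assert sr.shape == (8, 3, 128, 128)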
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 10da1cd9..e4a4e34f 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -18,6 +18,7 @@ import models.archs.feature_arch as feature_arch
 import models.archs.panet.panet as panet
 import models.archs.rcan as rcan
 import models.archs.ChainedEmbeddingGen as chained
+from models.archs.teco_resgen import TecoGen
 
 logger = logging.getLogger('base')
 
@@ -98,6 +99,8 @@ def define_G(opt, net_key='network_G', scale=None):
         netG = SwitchedGen_arch.BackboneSpinenetNoHead()
     elif which_model == "backbone_resnet":
         netG = SwitchedGen_arch.BackboneResnet()
+    elif which_model == "tecogen":
+        netG = TecoGen(opt_net['nf'], opt_net['scale'])
     else:
         raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
 
diff --git a/codes/train2.py b/codes/train2.py
index 81bd7c35..f9c94a82 100644
--- a/codes/train2.py
+++ b/codes/train2.py
@@ -278,7 +278,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_multifaceted_chained4x.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
diff --git a/codes/utils/util.py b/codes/utils/util.py
index ea271b2e..581e6307 100644
--- a/codes/utils/util.py
+++ b/codes/utils/util.py
@@ -51,6 +51,15 @@ def checkpoint(fn, *args):
     else:
         return fn(*args)
 
+# Like checkpoint() above, but wraps torch.utils.checkpoint.checkpoint_sequential,
+# splitting an nn.Sequential into `partitions` checkpointed segments.
+def sequential_checkpoint(fn, partitions, *args):
+    enabled = loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in loaded_options.keys() else True
+    if enabled:
+        return torch.utils.checkpoint.checkpoint_sequential(fn, partitions, *args)
+    else:
+        return fn(*args)
+
 # A fancy alternative to if checkpoint() else
 def possible_checkpoint(enabled, fn, *args):
     opt_en = loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in loaded_options.keys() else True
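
# Hypothetical usage of the new define_G branch (key names inferred from the opt_net
# lookups above; no option file is added by this change). In an options YAML this
# corresponds to setting which_model_G: tecogen with nf/scale under network_G, plus
# an optional top-level checkpointing_enabled flag consumed by sequential_checkpoint:
#   opt_net = {'which_model_G': 'tecogen', 'nf': 64, 'scale': 4}
#   netG = TecoGen(opt_net['nf'], opt_net['scale'])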