Add teco_resgen

This commit is contained in:
James Betker 2020-10-27 20:59:55 -06:00
parent 00bb568956
commit 9848f4c6cb
4 changed files with 71 additions and 1 deletions

View File

@ -0,0 +1,60 @@
import torch
import torch.nn as nn
from utils.util import sequential_checkpoint
from models.archs.arch_util import ConvGnSilu, make_layer
class TecoResblock(nn.Module):
def __init__(self, nf):
self.nf = nf
self.conv1 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False, weight_init_factor=.1)
self.conv2 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False, weight_init_factor=.1)
def forward(self, x):
identity = x
x = self.conv1(x)
x = self.conv2(x)
return identity + x
class TecoUpconv(nn.Module):
def __init__(self, nf, scale):
self.nf = nf
self.scale = scale
self.conv1 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
self.conv2 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
self.conv3 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
self.final_conv = ConvGnSilu(nf, 3, kernel_size=1, norm=False, activation=False, bias=False)
def forward(self, x):
identity = x
x = self.conv1(x)
x = self.conv2(x)
x = nn.functional.interpolate(x, scale_factor=self.scale, mode="nearest")
x = self.conv3(x)
return identity + self.final_conv(x)
# Extremely simple resnet based generator that is very similar to the one used in the tecogan paper.
# Main differences:
# - Uses SiLU instead of ReLU
# - Reference input is in HR space (just makes more sense)
# - Doesn't use transposed convolutions - just uses interpolation instead.
# - Upsample block is slightly more complicated.
class TecoGen(nn.Module):
def __init__(self, nf, scale):
self.nf = nf
self.scale = scale
fea_conv = ConvGnSilu(6, nf, kernel_size=7, stride=self.scale, bias=True, norm=False, activation=True)
res_layers = [TecoResblock(nf) for i in range(15)]
upsample = TecoUpconv(nf)
everything = [fea_conv] + res_layers + upsample
self.core = nn.Sequential(*everything)
def forward(self, x, ref=None):
x = nn.functional.interpolate(x, scale_factor=self.scale, mode="bicubic")
if ref is None:
ref = torch.zeros_like(x)
join = torch.cat([x, ref], dim=1)
return sequential_checkpoint(self.core, 6, join)

View File

@ -18,6 +18,7 @@ import models.archs.feature_arch as feature_arch
import models.archs.panet.panet as panet
import models.archs.rcan as rcan
import models.archs.ChainedEmbeddingGen as chained
from models.archs.teco_resgen import TecoGen
logger = logging.getLogger('base')
@ -98,6 +99,8 @@ def define_G(opt, net_key='network_G', scale=None):
netG = SwitchedGen_arch.BackboneSpinenetNoHead()
elif which_model == "backbone_resnet":
netG = SwitchedGen_arch.BackboneResnet()
elif which_model == "tecogen":
netG = TecoGen(opt_net['nf'], opt_net['scale'])
else:
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))

View File

@ -278,7 +278,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_multifaceted_chained4x.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
args = parser.parse_args()
opt = option.parse(args.opt, is_train=True)

View File

@ -51,6 +51,13 @@ def checkpoint(fn, *args):
else:
return fn(*args)
def sequential_checkpoint(fn, partitions, *args):
enabled = loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in loaded_options.keys() else True
if enabled:
return torch.utils.checkpoint.checkpoint_sequential(fn, partitions, *args)
else:
return fn(*args)
# A fancy alternative to if <flag> checkpoint() else <call>
def possible_checkpoint(enabled, fn, *args):
opt_en = loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in loaded_options.keys() else True