Add teco_resgen
This commit is contained in:
parent
00bb568956
commit
9848f4c6cb
60
codes/models/archs/teco_resgen.py
Normal file
60
codes/models/archs/teco_resgen.py
Normal file
|
@ -0,0 +1,60 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from utils.util import sequential_checkpoint
|
||||
from models.archs.arch_util import ConvGnSilu, make_layer
|
||||
|
||||
|
||||
class TecoResblock(nn.Module):
|
||||
def __init__(self, nf):
|
||||
self.nf = nf
|
||||
self.conv1 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False, weight_init_factor=.1)
|
||||
self.conv2 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False, weight_init_factor=.1)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
return identity + x
|
||||
|
||||
|
||||
class TecoUpconv(nn.Module):
|
||||
def __init__(self, nf, scale):
|
||||
self.nf = nf
|
||||
self.scale = scale
|
||||
self.conv1 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
|
||||
self.conv2 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
|
||||
self.conv3 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
|
||||
self.final_conv = ConvGnSilu(nf, 3, kernel_size=1, norm=False, activation=False, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
x = nn.functional.interpolate(x, scale_factor=self.scale, mode="nearest")
|
||||
x = self.conv3(x)
|
||||
return identity + self.final_conv(x)
|
||||
|
||||
|
||||
# Extremely simple resnet based generator that is very similar to the one used in the tecogan paper.
|
||||
# Main differences:
|
||||
# - Uses SiLU instead of ReLU
|
||||
# - Reference input is in HR space (just makes more sense)
|
||||
# - Doesn't use transposed convolutions - just uses interpolation instead.
|
||||
# - Upsample block is slightly more complicated.
|
||||
class TecoGen(nn.Module):
|
||||
def __init__(self, nf, scale):
|
||||
self.nf = nf
|
||||
self.scale = scale
|
||||
fea_conv = ConvGnSilu(6, nf, kernel_size=7, stride=self.scale, bias=True, norm=False, activation=True)
|
||||
res_layers = [TecoResblock(nf) for i in range(15)]
|
||||
upsample = TecoUpconv(nf)
|
||||
everything = [fea_conv] + res_layers + upsample
|
||||
self.core = nn.Sequential(*everything)
|
||||
|
||||
def forward(self, x, ref=None):
|
||||
x = nn.functional.interpolate(x, scale_factor=self.scale, mode="bicubic")
|
||||
if ref is None:
|
||||
ref = torch.zeros_like(x)
|
||||
join = torch.cat([x, ref], dim=1)
|
||||
return sequential_checkpoint(self.core, 6, join)
|
||||
|
|
@ -18,6 +18,7 @@ import models.archs.feature_arch as feature_arch
|
|||
import models.archs.panet.panet as panet
|
||||
import models.archs.rcan as rcan
|
||||
import models.archs.ChainedEmbeddingGen as chained
|
||||
from models.archs.teco_resgen import TecoGen
|
||||
|
||||
logger = logging.getLogger('base')
|
||||
|
||||
|
@ -98,6 +99,8 @@ def define_G(opt, net_key='network_G', scale=None):
|
|||
netG = SwitchedGen_arch.BackboneSpinenetNoHead()
|
||||
elif which_model == "backbone_resnet":
|
||||
netG = SwitchedGen_arch.BackboneResnet()
|
||||
elif which_model == "tecogen":
|
||||
netG = TecoGen(opt_net['nf'], opt_net['scale'])
|
||||
else:
|
||||
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
||||
|
||||
|
|
|
@ -278,7 +278,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_multifaceted_chained4x.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
args = parser.parse_args()
|
||||
opt = option.parse(args.opt, is_train=True)
|
||||
|
|
|
@ -51,6 +51,13 @@ def checkpoint(fn, *args):
|
|||
else:
|
||||
return fn(*args)
|
||||
|
||||
def sequential_checkpoint(fn, partitions, *args):
|
||||
enabled = loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in loaded_options.keys() else True
|
||||
if enabled:
|
||||
return torch.utils.checkpoint.checkpoint_sequential(fn, partitions, *args)
|
||||
else:
|
||||
return fn(*args)
|
||||
|
||||
# A fancy alternative to if <flag> checkpoint() else <call>
|
||||
def possible_checkpoint(enabled, fn, *args):
|
||||
opt_en = loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in loaded_options.keys() else True
|
||||
|
|
Loading…
Reference in New Issue
Block a user