forked from mrq/DL-Art-School
Add teco_resgen
parent 00bb568956
commit 9848f4c6cb
codes/models/archs/teco_resgen.py (new file, +60 lines)
@@ -0,0 +1,60 @@
import torch
import torch.nn as nn

from utils.util import sequential_checkpoint
from models.archs.arch_util import ConvGnSilu, make_layer


class TecoResblock(nn.Module):
    # Plain residual block: two 3x3 convs with an identity skip connection.
    def __init__(self, nf):
        super().__init__()
        self.nf = nf
        self.conv1 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False, weight_init_factor=.1)
        self.conv2 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False, weight_init_factor=.1)

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.conv2(x)
        return identity + x


class TecoUpconv(nn.Module):
    # Upsampling block: two convs, nearest-neighbor upscaling, one more conv,
    # then a 1x1 projection down to 3 (RGB) channels.
    def __init__(self, nf, scale):
        super().__init__()
        self.nf = nf
        self.scale = scale
        self.conv1 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
        self.conv2 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
        self.conv3 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
        self.final_conv = ConvGnSilu(nf, 3, kernel_size=1, norm=False, activation=False, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = nn.functional.interpolate(x, scale_factor=self.scale, mode="nearest")
        # The skip is taken after interpolation so both operands of the
        # residual sum share the same spatial size and channel count.
        identity = x
        x = self.conv3(x)
        return self.final_conv(identity + x)


# Extremely simple resnet-based generator that is very similar to the one used
# in the TecoGAN paper. Main differences:
# - Uses SiLU instead of ReLU.
# - Reference input is in HR space (just makes more sense).
# - Doesn't use transposed convolutions - just uses interpolation instead.
# - Upsample block is slightly more complicated.
class TecoGen(nn.Module):
    def __init__(self, nf, scale):
        super().__init__()
        self.nf = nf
        self.scale = scale
        # 6 input channels: the bicubically upscaled input concatenated with the
        # HR reference; the strided conv brings both back down to feature scale.
        fea_conv = ConvGnSilu(6, nf, kernel_size=7, stride=self.scale, bias=True, norm=False, activation=True)
        res_layers = [TecoResblock(nf) for _ in range(15)]
        upsample = TecoUpconv(nf, scale)
        everything = [fea_conv] + res_layers + [upsample]
        self.core = nn.Sequential(*everything)

    def forward(self, x, ref=None):
        x = nn.functional.interpolate(x, scale_factor=self.scale, mode="bicubic")
        if ref is None:
            ref = torch.zeros_like(x)
        join = torch.cat([x, ref], dim=1)
        # Run the sequential core under gradient checkpointing, split into 6 segments.
        return sequential_checkpoint(self.core, 6, join)
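For orientation, here is a minimal sketch of how the new generator can be exercised end to end. The hyperparameters, batch size, and tensor shapes are illustrative assumptions, not values from this commit; the output shape additionally assumes ConvGnSilu pads its convs to kernel_size // 2, and that utils.util's loaded training options have been populated by the entry point before the forward pass.

import torch
from models.archs.teco_resgen import TecoGen

# Assumed settings; in this codebase they come from the training YAML
# (see the define_G wiring below).
gen = TecoGen(nf=64, scale=4)

lr = torch.randn(1, 3, 32, 32)     # low-res input frame
ref = torch.randn(1, 3, 128, 128)  # optional reference frame, already in HR space
sr = gen(lr, ref)                  # expected shape: (1, 3, 128, 128)

# Without a reference, the generator substitutes zeros internally.
sr_no_ref = gen(lr)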
codes/models/networks.py

@@ -18,6 +18,7 @@ import models.archs.feature_arch as feature_arch
 import models.archs.panet.panet as panet
 import models.archs.rcan as rcan
 import models.archs.ChainedEmbeddingGen as chained
+from models.archs.teco_resgen import TecoGen

 logger = logging.getLogger('base')
@@ -98,6 +99,8 @@ def define_G(opt, net_key='network_G', scale=None):
         netG = SwitchedGen_arch.BackboneSpinenetNoHead()
     elif which_model == "backbone_resnet":
         netG = SwitchedGen_arch.BackboneResnet()
+    elif which_model == "tecogen":
+        netG = TecoGen(opt_net['nf'], opt_net['scale'])
     else:
         raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
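With the branch above in place, selecting the new generator is purely a configuration matter. Sketched below as the dict define_G would receive for this branch; the 'which_model_G' key name is an assumption matching the repo's other generators, and the values are only examples.

# Hypothetical network_G options; the tecogen branch above reads only
# 'nf' and 'scale', while the model key (assumed name) selects the branch.
opt_net = {
    'which_model_G': 'tecogen',
    'nf': 64,     # width of the residual trunk
    'scale': 4,   # upsampling factor
}
netG = TecoGen(opt_net['nf'], opt_net['scale'])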
codes/train.py

@@ -278,7 +278,7 @@ class Trainer:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_multifaceted_chained4x.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
codes/utils/util.py

@@ -51,6 +51,13 @@ def checkpoint(fn, *args):
     else:
         return fn(*args)

+def sequential_checkpoint(fn, partitions, *args):
+    enabled = loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in loaded_options.keys() else True
+    if enabled:
+        return torch.utils.checkpoint.checkpoint_sequential(fn, partitions, *args)
+    else:
+        return fn(*args)
+
 # A fancy alternative to if <flag> checkpoint() else <call>
 def possible_checkpoint(enabled, fn, *args):
     opt_en = loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in loaded_options.keys() else True
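The new helper wraps torch.utils.checkpoint.checkpoint_sequential, which trades compute for memory: the nn.Sequential is split into `partitions` segments, only activations at segment boundaries are kept during the forward pass, and everything in between is recomputed during backward. A self-contained sketch of that behavior with a toy model (the layer count and sizes are made up):

import torch
import torch.nn as nn
import torch.utils.checkpoint

# Toy stand-in for TecoGen's core: 12 layers checkpointed as 6 segments,
# mirroring the sequential_checkpoint(self.core, 6, join) call above.
core = nn.Sequential(*[nn.Conv2d(8, 8, kernel_size=3, padding=1) for _ in range(12)])

x = torch.randn(2, 8, 16, 16, requires_grad=True)
y = torch.utils.checkpoint.checkpoint_sequential(core, 6, x)

# Backward works as usual; activations inside each segment were discarded
# after the forward pass and are recomputed here to form the gradients.
y.mean().backward()
print(x.grad.shape)  # torch.Size([2, 8, 16, 16])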