forked from mrq/DL-Art-School

More work in support of training flow networks in tandem with generators

parent: c21088e238  commit: df47d6cbbb
@@ -194,6 +194,9 @@ class ExtensibleTrainer(BaseModel):
             net_enabled = name in nets_to_train
             if net_enabled:
                 enabled += 1
+            # Networks can opt out of training before a certain iteration by declaring 'after' in their definition.
+            if 'after' in self.opt['networks'][name].keys() and step < self.opt['networks'][name]['after']:
+                net_enabled = False
             for p in net.parameters():
                 if p.dtype != torch.int64 and p.dtype != torch.bool and not hasattr(p, "DO_NOT_TRAIN"):
                     p.requires_grad = net_enabled
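For illustration, a minimal sketch (not part of the commit) of the new opt-out in isolation; the network name 'flownet' and the step value are made up, and 'after' is the only key this hunk actually reads:

opt = {'networks': {'flownet': {'after': 10000}}}   # freeze 'flownet' until step 10000
name, step = 'flownet', 500
net_enabled = True
if 'after' in opt['networks'][name].keys() and step < opt['networks'][name]['after']:
    net_enabled = False
print(net_enabled)   # False -> its parameters get requires_grad = False this iteration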
@@ -225,7 +228,7 @@ class ExtensibleTrainer(BaseModel):
 
         # And finally perform optimization.
         [e.before_optimize(state) for e in self.experiments]
-        s.do_step()
+        s.do_step(step)
         [e.after_optimize(state) for e in self.experiments]
 
         # Record visual outputs for usage in debugging and testing.
codes/models/archs/pyramid_arch.py (new file, 98 lines)
@@ -0,0 +1,98 @@
+import torch
+from torch import nn
+
+from models.archs.arch_util import ConvGnLelu, UpconvBlock, ExpansionBlock
+from models.flownet2.networks.resample2d_package.resample2d import Resample2d
+from utils.util import checkpoint
+import torch.nn.functional as F
+
+
+class Pyramid(nn.Module):
+    def __init__(self, nf, depth, processing_convs_per_layer, processing_at_point, scale_per_level=2, block=ConvGnLelu,
+                 norm=True, return_outlevels=False):
+        super(Pyramid, self).__init__()
+        levels = []
+        current_filters = nf
+        self.return_outlevels = return_outlevels
+        for d in range(depth):
+            level = [block(current_filters, int(current_filters*scale_per_level), kernel_size=3, stride=2, activation=True, norm=False, bias=False)]
+            current_filters = int(current_filters*scale_per_level)
+            for pc in range(processing_convs_per_layer):
+                level.append(block(current_filters, current_filters, kernel_size=3, activation=True, norm=norm, bias=False))
+            levels.append(nn.Sequential(*level))
+        self.downsamples = nn.ModuleList(levels)
+        if processing_at_point > 0:
+            point_processor = []
+            for p in range(processing_at_point):
+                point_processor.append(block(current_filters, current_filters, kernel_size=3, activation=True, norm=norm, bias=False))
+            self.point_processor = nn.Sequential(*point_processor)
+        else:
+            self.point_processor = None
+        levels = []
+        for d in range(depth):
+            level = [ExpansionBlock(current_filters, int(current_filters / scale_per_level), block=block)]
+            current_filters = int(current_filters / scale_per_level)
+            for pc in range(processing_convs_per_layer):
+                level.append(block(current_filters, current_filters, kernel_size=3, activation=True, norm=norm, bias=False))
+            levels.append(nn.ModuleList(level))
+        self.upsamples = nn.ModuleList(levels)
+
+    def forward(self, x):
+        passthroughs = []
+        fea = x
+        for lvl in self.downsamples:
+            passthroughs.append(fea)
+            fea = lvl(fea)
+        out_levels = []
+        fea = self.point_processor(fea)
+        for i, lvl in enumerate(self.upsamples):
+            out_levels.append(fea)
+            for j, sublvl in enumerate(lvl):
+                if j == 0:
+                    fea = sublvl(fea, passthroughs[-1-i])
+                else:
+                    fea = sublvl(fea)
+
+        out_levels.append(fea)
+
+        if self.return_outlevels:
+            return tuple(out_levels)
+        else:
+            return fea
+
+
+class BasicResamplingFlowNet(nn.Module):
+    def create_termini(self, filters):
+        return nn.Sequential(ConvGnLelu(int(filters), 2, kernel_size=3, activation=False, norm=False, bias=True),
+                             nn.Tanh())
+
+    def __init__(self, nf, resample_scale=1):
+        super(BasicResamplingFlowNet, self).__init__()
+        self.initial_conv = ConvGnLelu(6, nf, kernel_size=7, activation=False, norm=False, bias=True)
+        self.pyramid = Pyramid(nf, 3, 0, 1, 1.5, return_outlevels=True)
+        self.termini = nn.ModuleList([self.create_termini(nf*1.5**3),
+                                      self.create_termini(nf*1.5**2),
+                                      self.create_termini(nf*1.5)])
+        self.terminus = nn.Sequential(ConvGnLelu(nf, nf, kernel_size=3, activation=True, norm=True, bias=True),
+                                      ConvGnLelu(nf, nf, kernel_size=3, activation=True, norm=True, bias=False),
+                                      ConvGnLelu(nf, nf//2, kernel_size=3, activation=False, norm=False, bias=True),
+                                      ConvGnLelu(nf//2, 2, kernel_size=3, activation=False, norm=False, bias=True),
+                                      nn.Tanh())
+        self.scale = resample_scale
+        self.resampler = Resample2d()
+
+    def forward(self, left, right):
+        fea = self.initial_conv(torch.cat([left, right], dim=1))
+        levels = checkpoint(self.pyramid, fea)
+        flos = []
+        compares = []
+        for i, level in enumerate(levels):
+            if i == 3:
+                flow = checkpoint(self.terminus, level) * self.scale
+            else:
+                flow = self.termini[i](level) * self.scale
+            img_scale = 1/2**(3-i)
+            flos.append(self.resampler(F.interpolate(left, scale_factor=img_scale, mode="area").float(), flow.float()))
+            compares.append(F.interpolate(right, scale_factor=img_scale, mode="area"))
+        flos_structural_var = torch.var(flos[-1], dim=[-1,-2])
+        return flos, compares, flos_structural_var
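A quick smoke-test sketch for the new module (not part of the commit); nf=32 and the 64x64 frames are arbitrary choices, and Resample2d needs the flownet2 CUDA extension, hence the .cuda() calls:

import torch
from models.archs.pyramid_arch import BasicResamplingFlowNet

net = BasicResamplingFlowNet(nf=32, resample_scale=1).cuda()
left = torch.randn(1, 3, 64, 64).cuda()    # frame t
right = torch.randn(1, 3, 64, 64).cuda()   # frame t+1
flos, compares, flo_var = net(left, right)
for warped, target in zip(flos, compares):
    print(warped.shape, target.shape)      # four scales: 1/8, 1/4, 1/2, full
print(flo_var.shape)                       # (1, 3): spatial variance of the full-res warp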
@@ -19,6 +19,7 @@ import models.archs.panet.panet as panet
 import models.archs.rcan as rcan
 import models.archs.ChainedEmbeddingGen as chained
 from models.archs import srg2_classic
+from models.archs.pyramid_arch import BasicResamplingFlowNet
 from models.archs.teco_resgen import TecoGen
 
 logger = logging.getLogger('base')
@@ -115,9 +116,10 @@ def define_G(opt, net_key='network_G', scale=None):
         netG = SwitchedGen_arch.BackboneResnet()
     elif which_model == "tecogen":
         netG = TecoGen(opt_net['nf'], opt_net['scale'])
+    elif which_model == "basic_resampling_flow_predictor":
+        netG = BasicResamplingFlowNet(opt_net['nf'], resample_scale=opt_net['resample_scale'])
     else:
         raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
 
     return netG
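For context, a sketch (not part of the commit) of the options fragment that would route define_G into the new branch; every key name except 'nf' and 'resample_scale' is an assumption about this repo's conventions:

opt = {'network_G': {'which_model_G': 'basic_resampling_flow_predictor',  # assumed key name
                     'nf': 32,
                     'resample_scale': 1}}
netG = define_G(opt)   # -> BasicResamplingFlowNet(32, resample_scale=1)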
@@ -66,8 +66,9 @@ class ConfigurableStep(Module):
         else:
             opt_configs = [self.step_opt['optimizer_params']]
             nets = [self.training_net]
+            training = [training]
         self.optimizers = []
-        for net, opt_config in zip(nets, opt_configs):
+        for net_name, net, opt_config in zip(training, nets, opt_configs):
             optim_params = []
             for k, v in net.named_parameters():  # can optimize for a part of the model
                 if v.requires_grad:
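The added `training = [training]` line matters because in this single-network branch `training` presumably still holds the bare network name as a string; a toy sketch (not part of the commit) of why the rewritten zip needs the wrap:

# Zipping a bare string would pair each *character* with a net; wrapping it
# in a one-element list pairs the whole name with its network and config.
training = 'generator'                    # assumed: the single net's name
nets = ['<generator module here>']
opt_configs = [{'lr': 1e-4}]
training = [training]
for net_name, net, opt_config in zip(training, nets, opt_configs):
    print(net_name, opt_config['lr'])     # generator 0.0001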
@@ -84,6 +85,7 @@ class ConfigurableStep(Module):
                 opt = NovoGrad(optim_params, lr=opt_config['lr'], weight_decay=opt_config['weight_decay'],
                                betas=(opt_config['beta1'], opt_config['beta2']))
             opt._config = opt_config  # This is a bit seedy, but we will need these configs later.
+            opt._config['network'] = net_name
             self.optimizers.append(opt)
 
     # Returns all optimizers used in this step.
@@ -189,13 +191,15 @@ class ConfigurableStep(Module):
 
     # Performs the optimizer step after all gradient accumulation is completed. Default implementation simply steps()
     # all self.optimizers.
-    def do_step(self):
+    def do_step(self, step):
         if not self.grads_generated:
             return
         self.grads_generated = False
         for opt in self.optimizers:
             # Optimizers can be opted out in the early stages of training.
             after = opt._config['after'] if 'after' in opt._config.keys() else 0
+            after_network = self.opt['networks'][opt._config['network']]['after'] if 'after' in self.opt['networks'][opt._config['network']].keys() else 0
+            after = max(after, after_network)
             if self.env['step'] < after:
                 continue
             before = opt._config['before'] if 'before' in opt._config.keys() else -1
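A worked example of the combined gating in the new do_step (all numbers made up): the optimizer's own 'after' and its network's 'after' are merged with max(), so the later of the two wins:

after_opt = 5000         # from opt._config['after']
after_network = 10000    # from the owning network's 'after'
after = max(after_opt, after_network)
for step in (4999, 9999, 10000):
    print(step, 'step' if step >= after else 'skip')
# 4999 skip / 9999 skip / 10000 step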
@@ -73,15 +73,21 @@ class FfmpegBackedVideoDataset(data.Dataset):
 
         if self.force_multiple > 1:
             assert self.vertical_splits <= 1  # This is not compatible with vertical splits for now.
-            _, h, w = img_LQ.shape
+            c, h, w = img_LQ.shape
+            h_, w_ = h, w
             height_removed = h % self.force_multiple
             width_removed = w % self.force_multiple
             if height_removed != 0:
-                img_LQ = img_LQ[:, :-height_removed, :]
-                ref = ref[:, :-height_removed, :]
+                h_ = self.force_multiple * ((h // self.force_multiple) + 1)
             if width_removed != 0:
-                img_LQ = img_LQ[:, :, :-width_removed]
-                ref = ref[:, :, :-width_removed]
+                w_ = self.force_multiple * ((w // self.force_multiple) + 1)
+            lq_template = torch.zeros(c, h_, w_)
+            lq_template[:, :h, :w] = img_LQ
+            ref_template = torch.zeros(c, h_, w_)
+            ref_template[:, :h, :w] = ref
+            img_LQ = lq_template
+            ref = ref_template
 
         return {'LQ': img_LQ, 'lq_fullsize_ref': ref,
                 'lq_center': torch.tensor([img_LQ.shape[1] // 2, img_LQ.shape[2] // 2], dtype=torch.long)}
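This hunk swaps cropping down for zero-padding up to the next multiple of force_multiple. A standalone sketch of the new sizing math (the value 8 and the frame size are assumptions):

import torch

force_multiple = 8
img = torch.randn(3, 270, 480)   # 270 % 8 != 0, 480 % 8 == 0
c, h, w = img.shape
h_ = force_multiple * ((h // force_multiple) + 1) if h % force_multiple != 0 else h
w_ = force_multiple * ((w // force_multiple) + 1) if w % force_multiple != 0 else w
padded = torch.zeros(c, h_, w_)
padded[:, :h, :w] = img          # original content sits in the top-left corner
print(padded.shape)              # torch.Size([3, 272, 480]): padded up, not cropped down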
@@ -278,7 +278,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_bypass_artificial_quality.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_bypass_with_flow.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)