More work in support of training flow networks in tandem with generators

This commit is contained in:
James Betker 2020-11-04 18:07:48 -07:00
parent c21088e238
commit df47d6cbbb
6 changed files with 123 additions and 10 deletions

View File

@ -194,6 +194,9 @@ class ExtensibleTrainer(BaseModel):
net_enabled = name in nets_to_train
if net_enabled:
enabled += 1
# Networks can opt out of training before a certain iteration by declaring 'after' in their definition.
if 'after' in self.opt['networks'][name].keys() and step < self.opt['networks'][name]['after']:
net_enabled = False
for p in net.parameters():
if p.dtype != torch.int64 and p.dtype != torch.bool and not hasattr(p, "DO_NOT_TRAIN"):
p.requires_grad = net_enabled
@ -225,7 +228,7 @@ class ExtensibleTrainer(BaseModel):
# And finally perform optimization.
[e.before_optimize(state) for e in self.experiments]
s.do_step()
s.do_step(step)
[e.after_optimize(state) for e in self.experiments]
# Record visual outputs for usage in debugging and testing.

View File

@ -0,0 +1,98 @@
import torch
from torch import nn
from models.archs.arch_util import ConvGnLelu, UpconvBlock, ExpansionBlock
from models.flownet2.networks.resample2d_package.resample2d import Resample2d
from utils.util import checkpoint
import torch.nn.functional as F
class Pyramid(nn.Module):
    """U-shaped feature pyramid: `depth` strided downsampling levels, optional
    processing convs at the bottleneck, then `depth` expansion levels that
    consume skip connections from the downsampling path.

    Args:
        nf: number of input feature channels.
        depth: number of downsample (and matching upsample) levels.
        processing_convs_per_layer: extra stride-1 convs appended to every level.
        processing_at_point: number of convs applied at the bottleneck (0 disables it).
        scale_per_level: channel multiplier applied per downsample level.
        block: conv block constructor (ConvGnLelu-compatible signature).
        norm: whether the processing convs use normalization.
        return_outlevels: if True, forward returns a tuple of per-level features
            (deepest first, finest last) instead of only the final feature map.
    """
    def __init__(self, nf, depth, processing_convs_per_layer, processing_at_point, scale_per_level=2, block=ConvGnLelu,
                 norm=True, return_outlevels=False):
        super(Pyramid, self).__init__()
        levels = []
        current_filters = nf
        self.return_outlevels = return_outlevels
        # Contracting path: each level halves spatial size (stride=2) and scales channels.
        for d in range(depth):
            level = [block(current_filters, int(current_filters * scale_per_level), kernel_size=3, stride=2,
                           activation=True, norm=False, bias=False)]
            current_filters = int(current_filters * scale_per_level)
            for pc in range(processing_convs_per_layer):
                level.append(block(current_filters, current_filters, kernel_size=3, activation=True, norm=norm,
                                   bias=False))
            levels.append(nn.Sequential(*level))
        self.downsamples = nn.ModuleList(levels)

        # Optional bottleneck processing; None when processing_at_point == 0.
        if processing_at_point > 0:
            point_processor = []
            for p in range(processing_at_point):
                point_processor.append(block(current_filters, current_filters, kernel_size=3, activation=True,
                                             norm=norm, bias=False))
            self.point_processor = nn.Sequential(*point_processor)
        else:
            self.point_processor = None

        # Expanding path: ExpansionBlock takes (features, skip) so each level is a
        # ModuleList (not Sequential) and forward drives it manually.
        levels = []
        for d in range(depth):
            level = [ExpansionBlock(current_filters, int(current_filters / scale_per_level), block=block)]
            current_filters = int(current_filters / scale_per_level)
            for pc in range(processing_convs_per_layer):
                level.append(block(current_filters, current_filters, kernel_size=3, activation=True, norm=norm,
                                   bias=False))
            levels.append(nn.ModuleList(level))
        self.upsamples = nn.ModuleList(levels)

    def forward(self, x):
        passthroughs = []
        fea = x
        for lvl in self.downsamples:
            passthroughs.append(fea)
            fea = lvl(fea)

        out_levels = []
        # BUGFIX: the original called self.point_processor unconditionally, which
        # raised TypeError when processing_at_point == 0 set it to None.
        if self.point_processor is not None:
            fea = self.point_processor(fea)

        for i, lvl in enumerate(self.upsamples):
            out_levels.append(fea)
            for j, sublvl in enumerate(lvl):
                if j == 0:
                    # First sub-block is the ExpansionBlock: fuse with the matching skip.
                    fea = sublvl(fea, passthroughs[-1 - i])
                else:
                    fea = sublvl(fea)
        out_levels.append(fea)

        if self.return_outlevels:
            return tuple(out_levels)
        else:
            return fea
class BasicResamplingFlowNet(nn.Module):
    """Predicts coarse-to-fine optical flow between two images and resamples the
    left image toward the right at each pyramid scale.

    forward(left, right) returns:
        flos: list of resampled left images, one per pyramid level (coarse->fine).
        compares: list of the right image downsampled to the matching scales.
        flos_structural_var: per-channel spatial variance of the finest resample.
    """
    def create_termini(self, filters):
        # Small head that maps features to a 2-channel flow field in (-1, 1).
        flow_conv = ConvGnLelu(int(filters), 2, kernel_size=3, activation=False, norm=False, bias=True)
        return nn.Sequential(flow_conv, nn.Tanh())

    def __init__(self, nf, resample_scale=1):
        super(BasicResamplingFlowNet, self).__init__()
        self.initial_conv = ConvGnLelu(6, nf, kernel_size=7, activation=False, norm=False, bias=True)
        # 3-level pyramid, 1.5x channel growth per level, returning all levels.
        self.pyramid = Pyramid(nf, 3, 0, 1, 1.5, return_outlevels=True)
        # One flow head per coarse level, deepest (most channels) first.
        self.termini = nn.ModuleList([self.create_termini(nf * 1.5 ** power) for power in (3, 2, 1)])
        # Heavier head for the full-resolution level.
        self.terminus = nn.Sequential(ConvGnLelu(nf, nf, kernel_size=3, activation=True, norm=True, bias=True),
                                      ConvGnLelu(nf, nf, kernel_size=3, activation=True, norm=True, bias=False),
                                      ConvGnLelu(nf, nf // 2, kernel_size=3, activation=False, norm=False, bias=True),
                                      ConvGnLelu(nf // 2, 2, kernel_size=3, activation=False, norm=False, bias=True),
                                      nn.Tanh())
        self.scale = resample_scale
        self.resampler = Resample2d()

    def forward(self, left, right):
        joined = torch.cat([left, right], dim=1)
        feature_levels = checkpoint(self.pyramid, self.initial_conv(joined))

        flos = []
        compares = []
        for lvl_index, lvl_fea in enumerate(feature_levels):
            # Levels arrive coarse-first; index 3 is the full-resolution feature map.
            if lvl_index == 3:
                flow = checkpoint(self.terminus, lvl_fea) * self.scale
            else:
                flow = self.termini[lvl_index](lvl_fea) * self.scale
            img_scale = 1 / 2 ** (3 - lvl_index)
            shrunk_left = F.interpolate(left, scale_factor=img_scale, mode="area")
            flos.append(self.resampler(shrunk_left.float(), flow.float()))
            compares.append(F.interpolate(right, scale_factor=img_scale, mode="area"))

        # Spatial variance of the finest warped output, used as a structural signal.
        flos_structural_var = torch.var(flos[-1], dim=[-1, -2])
        return flos, compares, flos_structural_var

View File

@ -19,6 +19,7 @@ import models.archs.panet.panet as panet
import models.archs.rcan as rcan
import models.archs.ChainedEmbeddingGen as chained
from models.archs import srg2_classic
from models.archs.pyramid_arch import BasicResamplingFlowNet
from models.archs.teco_resgen import TecoGen
logger = logging.getLogger('base')
@ -115,9 +116,10 @@ def define_G(opt, net_key='network_G', scale=None):
netG = SwitchedGen_arch.BackboneResnet()
elif which_model == "tecogen":
netG = TecoGen(opt_net['nf'], opt_net['scale'])
elif which_model == "basic_resampling_flow_predictor":
netG = BasicResamplingFlowNet(opt_net['nf'], resample_scale=opt_net['resample_scale'])
else:
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
return netG

View File

@ -66,8 +66,9 @@ class ConfigurableStep(Module):
else:
opt_configs = [self.step_opt['optimizer_params']]
nets = [self.training_net]
training = [training]
self.optimizers = []
for net, opt_config in zip(nets, opt_configs):
for net_name, net, opt_config in zip(training, nets, opt_configs):
optim_params = []
for k, v in net.named_parameters(): # can optimize for a part of the model
if v.requires_grad:
@ -84,6 +85,7 @@ class ConfigurableStep(Module):
opt = NovoGrad(optim_params, lr=opt_config['lr'], weight_decay=opt_config['weight_decay'],
betas=(opt_config['beta1'], opt_config['beta2']))
opt._config = opt_config # This is a bit seedy, but we will need these configs later.
opt._config['network'] = net_name
self.optimizers.append(opt)
# Returns all optimizers used in this step.
@ -189,13 +191,15 @@ class ConfigurableStep(Module):
# Performs the optimizer step after all gradient accumulation is completed. Default implementation simply steps()
# all self.optimizers.
def do_step(self):
def do_step(self, step):
if not self.grads_generated:
return
self.grads_generated = False
for opt in self.optimizers:
# Optimizers can be opted out in the early stages of training.
after = opt._config['after'] if 'after' in opt._config.keys() else 0
after_network = self.opt['networks'][opt._config['network']]['after'] if 'after' in self.opt['networks'][opt._config['network']].keys() else 0
after = max(after, after_network)
if self.env['step'] < after:
continue
before = opt._config['before'] if 'before' in opt._config.keys() else -1

View File

@ -73,15 +73,21 @@ class FfmpegBackedVideoDataset(data.Dataset):
if self.force_multiple > 1:
assert self.vertical_splits <= 1 # This is not compatible with vertical splits for now.
_, h, w = img_LQ.shape
c, h, w = img_LQ.shape
h_, w_ = h, w
height_removed = h % self.force_multiple
width_removed = w % self.force_multiple
if height_removed != 0:
img_LQ = img_LQ[:, :-height_removed, :]
ref = ref[:, :-height_removed, :]
h_ = self.force_multiple * ((h // self.force_multiple) + 1)
if width_removed != 0:
img_LQ = img_LQ[:, :, :-width_removed]
ref = ref[:, :, :-width_removed]
w_ = self.force_multiple * ((w // self.force_multiple) + 1)
lq_template = torch.zeros(c,h_,w_)
lq_template[:,:h,:w] = img_LQ
ref_template = torch.zeros(c,h_,w_)
ref_template[:,:h,:w] = img_LQ
img_LQ = lq_template
ref = ref_template
return {'LQ': img_LQ, 'lq_fullsize_ref': ref,
'lq_center': torch.tensor([img_LQ.shape[1] // 2, img_LQ.shape[2] // 2], dtype=torch.long) }

View File

@ -278,7 +278,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_bypass_artificial_quality.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_bypass_with_flow.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
args = parser.parse_args()
opt = option.parse(args.opt, is_train=True)