forked from mrq/DL-Art-School

commit df47d6cbbb: More work in support of training flow networks in tandem with generators
parent: c21088e238
@@ -194,6 +194,9 @@ class ExtensibleTrainer(BaseModel):
             net_enabled = name in nets_to_train
             if net_enabled:
                 enabled += 1
+            # Networks can opt out of training before a certain iteration by declaring 'after' in their definition.
+            if 'after' in self.opt['networks'][name].keys() and step < self.opt['networks'][name]['after']:
+                net_enabled = False
             for p in net.parameters():
                 if p.dtype != torch.int64 and p.dtype != torch.bool and not hasattr(p, "DO_NOT_TRAIN"):
                     p.requires_grad = net_enabled
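The gate above reads an optional 'after' key from each network definition in the options file. A minimal standalone sketch of the intended behavior follows; the dict layout mirrors what the hunk accesses, while the network names and step values are made up:

    # Sketch of the 'after' gate; 'generator', 'flow_net' and the step values are hypothetical.
    opt = {'networks': {'generator': {},                 # no 'after': trainable immediately
                        'flow_net': {'after': 10000}}}   # frozen until step 10000

    def net_is_enabled(opt, name, step, nets_to_train):
        enabled = name in nets_to_train
        # Networks can opt out of training before a certain iteration.
        if 'after' in opt['networks'][name].keys() and step < opt['networks'][name]['after']:
            enabled = False
        return enabled

    assert not net_is_enabled(opt, 'flow_net', step=500, nets_to_train=['flow_net'])
    assert net_is_enabled(opt, 'flow_net', step=20000, nets_to_train=['flow_net'])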
@@ -225,7 +228,7 @@ class ExtensibleTrainer(BaseModel):

             # And finally perform optimization.
             [e.before_optimize(state) for e in self.experiments]
-            s.do_step()
+            s.do_step(step)
             [e.after_optimize(state) for e in self.experiments]

             # Record visual outputs for usage in debugging and testing.
codes/models/archs/pyramid_arch.py (new file, 98 lines)
@@ -0,0 +1,98 @@
+import torch
+from torch import nn
+
+from models.archs.arch_util import ConvGnLelu, UpconvBlock, ExpansionBlock
+from models.flownet2.networks.resample2d_package.resample2d import Resample2d
+from utils.util import checkpoint
+import torch.nn.functional as F
+
+
+class Pyramid(nn.Module):
+    def __init__(self, nf, depth, processing_convs_per_layer, processing_at_point, scale_per_level=2, block=ConvGnLelu,
+                 norm=True, return_outlevels=False):
+        super(Pyramid, self).__init__()
+        levels = []
+        current_filters = nf
+        self.return_outlevels = return_outlevels
+        for d in range(depth):
+            level = [block(current_filters, int(current_filters*scale_per_level), kernel_size=3, stride=2, activation=True, norm=False, bias=False)]
+            current_filters = int(current_filters*scale_per_level)
+            for pc in range(processing_convs_per_layer):
+                level.append(block(current_filters, current_filters, kernel_size=3, activation=True, norm=norm, bias=False))
+            levels.append(nn.Sequential(*level))
+        self.downsamples = nn.ModuleList(levels)
+        if processing_at_point > 0:
+            point_processor = []
+            for p in range(processing_at_point):
+                point_processor.append(block(current_filters, current_filters, kernel_size=3, activation=True, norm=norm, bias=False))
+            self.point_processor = nn.Sequential(*point_processor)
+        else:
+            self.point_processor = None
+        levels = []
+        for d in range(depth):
+            level = [ExpansionBlock(current_filters, int(current_filters / scale_per_level), block=block)]
+            current_filters = int(current_filters / scale_per_level)
+            for pc in range(processing_convs_per_layer):
+                level.append(block(current_filters, current_filters, kernel_size=3, activation=True, norm=norm, bias=False))
+            levels.append(nn.ModuleList(level))
+        self.upsamples = nn.ModuleList(levels)
+
+    def forward(self, x):
+        passthroughs = []
+        fea = x
+        for lvl in self.downsamples:
+            passthroughs.append(fea)
+            fea = lvl(fea)
+        out_levels = []
+        fea = self.point_processor(fea)
+        for i, lvl in enumerate(self.upsamples):
+            out_levels.append(fea)
+            for j, sublvl in enumerate(lvl):
+                if j == 0:
+                    fea = sublvl(fea, passthroughs[-1-i])
+                else:
+                    fea = sublvl(fea)
+
+        out_levels.append(fea)
+
+        if self.return_outlevels:
+            return tuple(out_levels)
+        else:
+            return fea
+
+
+class BasicResamplingFlowNet(nn.Module):
+    def create_termini(self, filters):
+        return nn.Sequential(ConvGnLelu(int(filters), 2, kernel_size=3, activation=False, norm=False, bias=True),
+                             nn.Tanh())
+
+    def __init__(self, nf, resample_scale=1):
+        super(BasicResamplingFlowNet, self).__init__()
+        self.initial_conv = ConvGnLelu(6, nf, kernel_size=7, activation=False, norm=False, bias=True)
+        self.pyramid = Pyramid(nf, 3, 0, 1, 1.5, return_outlevels=True)
+        self.termini = nn.ModuleList([self.create_termini(nf*1.5**3),
+                                      self.create_termini(nf*1.5**2),
+                                      self.create_termini(nf*1.5)])
+        self.terminus = nn.Sequential(ConvGnLelu(nf, nf, kernel_size=3, activation=True, norm=True, bias=True),
+                                      ConvGnLelu(nf, nf, kernel_size=3, activation=True, norm=True, bias=False),
+                                      ConvGnLelu(nf, nf//2, kernel_size=3, activation=False, norm=False, bias=True),
+                                      ConvGnLelu(nf//2, 2, kernel_size=3, activation=False, norm=False, bias=True),
+                                      nn.Tanh())
+        self.scale = resample_scale
+        self.resampler = Resample2d()
+
+    def forward(self, left, right):
+        fea = self.initial_conv(torch.cat([left, right], dim=1))
+        levels = checkpoint(self.pyramid, fea)
+        flos = []
+        compares = []
+        for i, level in enumerate(levels):
+            if i == 3:
+                flow = checkpoint(self.terminus, level) * self.scale
+            else:
+                flow = self.termini[i](level) * self.scale
+            img_scale = 1/2**(3-i)
+            flos.append(self.resampler(F.interpolate(left, scale_factor=img_scale, mode="area").float(), flow.float()))
+            compares.append(F.interpolate(right, scale_factor=img_scale, mode="area"))
+        flos_structural_var = torch.var(flos[-1], dim=[-1,-2])
+        return flos, compares, flos_structural_var
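For orientation, here is a hypothetical smoke test of the Pyramid block on its own (BasicResamplingFlowNet additionally needs the CUDA Resample2d op from the bundled flownet2 package, so it is left out). The constructor arguments mirror the ones BasicResamplingFlowNet passes; note that forward applies point_processor unconditionally, so processing_at_point must be at least 1 as it is here. nf=64 and the input size are arbitrary choices:

    # Hypothetical usage sketch; assumes the DL-Art-School codes/ directory is on the import path.
    import torch
    from models.archs.pyramid_arch import Pyramid

    pyr = Pyramid(nf=64, depth=3, processing_convs_per_layer=0, processing_at_point=1,
                  scale_per_level=1.5, return_outlevels=True)
    x = torch.randn(1, 64, 64, 64)        # (batch, nf, H, W); H and W divisible by 2**depth
    levels = pyr(x)
    # Expect 4 feature maps, coarsest first: 8x8, 16x16, 32x32, 64x64,
    # with channels 216, 144, 96, 64 (filters scale by 1.5 per level, truncated to int).
    for lv in levels:
        print(lv.shape)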
@@ -19,6 +19,7 @@ import models.archs.panet.panet as panet
 import models.archs.rcan as rcan
 import models.archs.ChainedEmbeddingGen as chained
 from models.archs import srg2_classic
+from models.archs.pyramid_arch import BasicResamplingFlowNet
 from models.archs.teco_resgen import TecoGen

 logger = logging.getLogger('base')
@@ -115,9 +116,10 @@ def define_G(opt, net_key='network_G', scale=None):
         netG = SwitchedGen_arch.BackboneResnet()
     elif which_model == "tecogen":
         netG = TecoGen(opt_net['nf'], opt_net['scale'])
+    elif which_model == "basic_resampling_flow_predictor":
+        netG = BasicResamplingFlowNet(opt_net['nf'], resample_scale=opt_net['resample_scale'])
     else:
         raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))

     return netG
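The new branch pulls 'nf' and 'resample_scale' out of the network options. A hypothetical fragment of the parsed options that would route to it; the 'which_model_G' key name is assumed from the usual define_G convention, and the values are illustrative:

    # Hypothetical opt_net fragment; only 'nf' and 'resample_scale' are read by the new branch.
    opt_net = {
        'which_model_G': 'basic_resampling_flow_predictor',
        'nf': 64,               # base filter count for BasicResamplingFlowNet
        'resample_scale': 1,    # multiplier applied to the predicted flow before Resample2d warping
    }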
@@ -66,8 +66,9 @@ class ConfigurableStep(Module):
         else:
             opt_configs = [self.step_opt['optimizer_params']]
             nets = [self.training_net]
+            training = [training]
         self.optimizers = []
-        for net, opt_config in zip(nets, opt_configs):
+        for net_name, net, opt_config in zip(training, nets, opt_configs):
             optim_params = []
             for k, v in net.named_parameters():  # can optimize for a part of the model
                 if v.requires_grad:
@@ -84,6 +85,7 @@ class ConfigurableStep(Module):
                 opt = NovoGrad(optim_params, lr=opt_config['lr'], weight_decay=opt_config['weight_decay'],
                                betas=(opt_config['beta1'], opt_config['beta2']))
             opt._config = opt_config  # This is a bit seedy, but we will need these configs later.
+            opt._config['network'] = net_name
             self.optimizers.append(opt)

     # Returns all optimizers used in this step.
@@ -189,13 +191,15 @@ class ConfigurableStep(Module):

     # Performs the optimizer step after all gradient accumulation is completed. Default implementation simply steps()
     # all self.optimizers.
-    def do_step(self):
+    def do_step(self, step):
         if not self.grads_generated:
             return
         self.grads_generated = False
         for opt in self.optimizers:
             # Optimizers can be opted out in the early stages of training.
             after = opt._config['after'] if 'after' in opt._config.keys() else 0
+            after_network = self.opt['networks'][opt._config['network']]['after'] if 'after' in self.opt['networks'][opt._config['network']].keys() else 0
+            after = max(after, after_network)
             if self.env['step'] < after:
                 continue
             before = opt._config['before'] if 'before' in opt._config.keys() else -1
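With the 'network' tag stored on each optimizer config in the previous hunk, do_step can now honor the stricter of the optimizer-level and network-level 'after' gates. A standalone sketch of that resolution; the dict shapes mirror opt._config and self.opt['networks'], and the names and values are hypothetical:

    # Sketch of the merged gate: the later (larger) 'after' wins.
    def first_active_step(opt_config, networks_opt):
        after = opt_config['after'] if 'after' in opt_config.keys() else 0
        net_opt = networks_opt[opt_config['network']]
        after_network = net_opt['after'] if 'after' in net_opt.keys() else 0
        return max(after, after_network)

    networks_opt = {'flow_net': {'after': 10000}}
    assert first_active_step({'network': 'flow_net', 'after': 5000}, networks_opt) == 10000
    assert first_active_step({'network': 'flow_net'}, networks_opt) == 10000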
@@ -73,15 +73,21 @@ class FfmpegBackedVideoDataset(data.Dataset):

         if self.force_multiple > 1:
             assert self.vertical_splits <= 1  # This is not compatible with vertical splits for now.
-            _, h, w = img_LQ.shape
+            c, h, w = img_LQ.shape
+            h_, w_ = h, w
             height_removed = h % self.force_multiple
             width_removed = w % self.force_multiple
             if height_removed != 0:
-                img_LQ = img_LQ[:, :-height_removed, :]
-                ref = ref[:, :-height_removed, :]
+                h_ = self.force_multiple * ((h // self.force_multiple) + 1)
             if width_removed != 0:
-                img_LQ = img_LQ[:, :, :-width_removed]
-                ref = ref[:, :, :-width_removed]
+                w_ = self.force_multiple * ((w // self.force_multiple) + 1)
+            lq_template = torch.zeros(c,h_,w_)
+            lq_template[:,:h,:w] = img_LQ
+            ref_template = torch.zeros(c,h_,w_)
+            ref_template[:,:h,:w] = ref
+            img_LQ = lq_template
+            ref = ref_template

         return {'LQ': img_LQ, 'lq_fullsize_ref': ref,
                 'lq_center': torch.tensor([img_LQ.shape[1] // 2, img_LQ.shape[2] // 2], dtype=torch.long) }
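The hunk replaces cropping down to the previous multiple of force_multiple with zero-padding up to the next one, so no pixels are discarded. The same idea as a standalone helper; this is a sketch for illustration, not the dataset's API:

    # Standalone sketch of the pad-to-multiple logic above.
    import torch

    def pad_to_multiple(img, multiple):
        c, h, w = img.shape
        h_ = h if h % multiple == 0 else multiple * ((h // multiple) + 1)
        w_ = w if w % multiple == 0 else multiple * ((w // multiple) + 1)
        template = torch.zeros(c, h_, w_)
        template[:, :h, :w] = img   # content sits top-left; zeros pad the bottom/right edges
        return template

    assert pad_to_multiple(torch.ones(3, 270, 480), 16).shape == (3, 272, 480)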
@@ -278,7 +278,7 @@ class Trainer:

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_bypass_artificial_quality.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_bypass_with_flow.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)