diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py
index c28cfde7..951f06c5 100644
--- a/codes/models/ExtensibleTrainer.py
+++ b/codes/models/ExtensibleTrainer.py
@@ -194,6 +194,9 @@ class ExtensibleTrainer(BaseModel):
             net_enabled = name in nets_to_train
             if net_enabled:
                 enabled += 1
+            # Networks can opt out of training before a certain iteration by declaring 'after' in their definition.
+            if 'after' in self.opt['networks'][name].keys() and step < self.opt['networks'][name]['after']:
+                net_enabled = False
             for p in net.parameters():
                 if p.dtype != torch.int64 and p.dtype != torch.bool and not hasattr(p, "DO_NOT_TRAIN"):
                     p.requires_grad = net_enabled
@@ -225,7 +228,7 @@ # And finally perform optimization.
 
             # And finally perform optimization.
             [e.before_optimize(state) for e in self.experiments]
-            s.do_step()
+            s.do_step(step)
             [e.after_optimize(state) for e in self.experiments]
 
             # Record visual outputs for usage in debugging and testing.
diff --git a/codes/models/archs/pyramid_arch.py b/codes/models/archs/pyramid_arch.py
new file mode 100644
index 00000000..9e0ce268
--- /dev/null
+++ b/codes/models/archs/pyramid_arch.py
@@ -0,0 +1,98 @@
+import torch
+from torch import nn
+
+from models.archs.arch_util import ConvGnLelu, UpconvBlock, ExpansionBlock
+from models.flownet2.networks.resample2d_package.resample2d import Resample2d
+from utils.util import checkpoint
+import torch.nn.functional as F
+
+
+class Pyramid(nn.Module):
+    def __init__(self, nf, depth, processing_convs_per_layer, processing_at_point, scale_per_level=2, block=ConvGnLelu,
+                 norm=True, return_outlevels=False):
+        super(Pyramid, self).__init__()
+        levels = []
+        current_filters = nf
+        self.return_outlevels = return_outlevels
+        for d in range(depth):
+            level = [block(current_filters, int(current_filters*scale_per_level), kernel_size=3, stride=2, activation=True, norm=False, bias=False)]
+            current_filters = int(current_filters*scale_per_level)
+            for pc in range(processing_convs_per_layer):
+                level.append(block(current_filters, current_filters, kernel_size=3, activation=True, norm=norm, bias=False))
+            levels.append(nn.Sequential(*level))
+        self.downsamples = nn.ModuleList(levels)
+        if processing_at_point > 0:
+            point_processor = []
+            for p in range(processing_at_point):
+                point_processor.append(block(current_filters, current_filters, kernel_size=3, activation=True, norm=norm, bias=False))
+            self.point_processor = nn.Sequential(*point_processor)
+        else:
+            self.point_processor = None
+        levels = []
+        for d in range(depth):
+            level = [ExpansionBlock(current_filters, int(current_filters / scale_per_level), block=block)]
+            current_filters = int(current_filters / scale_per_level)
+            for pc in range(processing_convs_per_layer):
+                level.append(block(current_filters, current_filters, kernel_size=3, activation=True, norm=norm, bias=False))
+            levels.append(nn.ModuleList(level))
+        self.upsamples = nn.ModuleList(levels)
+
+    def forward(self, x):
+        passthroughs = []
+        fea = x
+        for lvl in self.downsamples:
+            passthroughs.append(fea)
+            fea = lvl(fea)
+        out_levels = []
+        fea = self.point_processor(fea) if self.point_processor is not None else fea
+        for i, lvl in enumerate(self.upsamples):
+            out_levels.append(fea)
+            for j, sublvl in enumerate(lvl):
+                if j == 0:
+                    fea = sublvl(fea, passthroughs[-1-i])
+                else:
+                    fea = sublvl(fea)
+
+        out_levels.append(fea)
+
+        if self.return_outlevels:
+            return tuple(out_levels)
+        else:
+            return fea
+
+
+class BasicResamplingFlowNet(nn.Module):
+    def create_termini(self, filters):
+        return nn.Sequential(ConvGnLelu(int(filters), 2,
+                                        kernel_size=3, activation=False, norm=False, bias=True),
+                             nn.Tanh())
+
+    def __init__(self, nf, resample_scale=1):
+        super(BasicResamplingFlowNet, self).__init__()
+        self.initial_conv = ConvGnLelu(6, nf, kernel_size=7, activation=False, norm=False, bias=True)
+        self.pyramid = Pyramid(nf, 3, 0, 1, 1.5, return_outlevels=True)
+        self.termini = nn.ModuleList([self.create_termini(nf*1.5**3),
+                                      self.create_termini(nf*1.5**2),
+                                      self.create_termini(nf*1.5)])
+        self.terminus = nn.Sequential(ConvGnLelu(nf, nf, kernel_size=3, activation=True, norm=True, bias=True),
+                                      ConvGnLelu(nf, nf, kernel_size=3, activation=True, norm=True, bias=False),
+                                      ConvGnLelu(nf, nf//2, kernel_size=3, activation=False, norm=False, bias=True),
+                                      ConvGnLelu(nf//2, 2, kernel_size=3, activation=False, norm=False, bias=True),
+                                      nn.Tanh())
+        self.scale = resample_scale
+        self.resampler = Resample2d()
+
+    def forward(self, left, right):
+        fea = self.initial_conv(torch.cat([left, right], dim=1))
+        levels = checkpoint(self.pyramid, fea)
+        flos = []
+        compares = []
+        for i, level in enumerate(levels):
+            if i == 3:
+                flow = checkpoint(self.terminus, level) * self.scale
+            else:
+                flow = self.termini[i](level) * self.scale
+            img_scale = 1/2**(3-i)
+            flos.append(self.resampler(F.interpolate(left, scale_factor=img_scale, mode="area").float(), flow.float()))
+            compares.append(F.interpolate(right, scale_factor=img_scale, mode="area"))
+        flos_structural_var = torch.var(flos[-1], dim=[-1,-2])
+        return flos, compares, flos_structural_var
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 32280ab2..dc103b74 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -19,6 +19,7 @@ import models.archs.panet.panet as panet
 import models.archs.rcan as rcan
 import models.archs.ChainedEmbeddingGen as chained
 from models.archs import srg2_classic
+from models.archs.pyramid_arch import BasicResamplingFlowNet
 from models.archs.teco_resgen import TecoGen
 
 logger = logging.getLogger('base')
@@ -115,9 +116,10 @@ def define_G(opt, net_key='network_G', scale=None):
         netG = SwitchedGen_arch.BackboneResnet()
     elif which_model == "tecogen":
         netG = TecoGen(opt_net['nf'], opt_net['scale'])
+    elif which_model == "basic_resampling_flow_predictor":
+        netG = BasicResamplingFlowNet(opt_net['nf'], resample_scale=opt_net['resample_scale'])
     else:
         raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
-
     return netG
 
 
diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py
index e3faadbe..3767b89b 100644
--- a/codes/models/steps/steps.py
+++ b/codes/models/steps/steps.py
@@ -66,8 +66,9 @@ class ConfigurableStep(Module):
         else:
             opt_configs = [self.step_opt['optimizer_params']]
             nets = [self.training_net]
+            training = [training]
         self.optimizers = []
-        for net, opt_config in zip(nets, opt_configs):
+        for net_name, net, opt_config in zip(training, nets, opt_configs):
             optim_params = []
             for k, v in net.named_parameters():  # can optimize for a part of the model
                 if v.requires_grad:
@@ -84,6 +85,7 @@
                 opt = NovoGrad(optim_params, lr=opt_config['lr'], weight_decay=opt_config['weight_decay'],
                                betas=(opt_config['beta1'], opt_config['beta2']))
             opt._config = opt_config  # This is a bit seedy, but we will need these configs later.
+            opt._config['network'] = net_name
             self.optimizers.append(opt)
 
     # Returns all optimizers used in this step.
@@ -189,13 +191,15 @@
 
     # Performs the optimizer step after all gradient accumulation is completed. Default implementation simply steps()
     # all self.optimizers.
-    def do_step(self):
+    def do_step(self, step):
         if not self.grads_generated:
             return
         self.grads_generated = False
         for opt in self.optimizers:
             # Optimizers can be opted out in the early stages of training.
             after = opt._config['after'] if 'after' in opt._config.keys() else 0
+            after_network = self.opt['networks'][opt._config['network']]['after'] if 'after' in self.opt['networks'][opt._config['network']].keys() else 0
+            after = max(after, after_network)
             if self.env['step'] < after:
                 continue
             before = opt._config['before'] if 'before' in opt._config.keys() else -1
diff --git a/codes/process_video.py b/codes/process_video.py
index 71bc25c5..fd376966 100644
--- a/codes/process_video.py
+++ b/codes/process_video.py
@@ -73,15 +73,21 @@ class FfmpegBackedVideoDataset(data.Dataset):
 
         if self.force_multiple > 1:
             assert self.vertical_splits <= 1  # This is not compatible with vertical splits for now.
-            _, h, w = img_LQ.shape
+            c, h, w = img_LQ.shape
+            h_, w_ = h, w
             height_removed = h % self.force_multiple
             width_removed = w % self.force_multiple
             if height_removed != 0:
-                img_LQ = img_LQ[:, :-height_removed, :]
-                ref = ref[:, :-height_removed, :]
+                h_ = self.force_multiple * ((h // self.force_multiple) + 1)
             if width_removed != 0:
-                img_LQ = img_LQ[:, :, :-width_removed]
-                ref = ref[:, :, :-width_removed]
+                w_ = self.force_multiple * ((w // self.force_multiple) + 1)
+            lq_template = torch.zeros(c,h_,w_)
+            lq_template[:,:h,:w] = img_LQ
+            ref_template = torch.zeros(c,h_,w_)
+            ref_template[:,:h,:w] = ref
+            img_LQ = lq_template
+            ref = ref_template
+
 
         return {'LQ': img_LQ, 'lq_fullsize_ref': ref,
                 'lq_center': torch.tensor([img_LQ.shape[1] // 2, img_LQ.shape[2] // 2], dtype=torch.long) }
diff --git a/codes/train2.py b/codes/train2.py
index c67a20d5..d068c125 100644
--- a/codes/train2.py
+++ b/codes/train2.py
@@ -278,7 +278,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_bypass_artificial_quality.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_bypass_with_flow.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
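
Taken together, the ExtensibleTrainer and ConfigurableStep changes let a network definition carry an `after` key: the network's parameters stay frozen and `do_step` skips its optimizers until the global iteration reaches that value. Below is a minimal sketch of the gating logic, assuming an already-parsed options dict (the `opt` layout mirrors `self.opt['networks']`; the helper name `apply_after_gating` and the toy modules are illustrative, not repo code):

```python
import torch
from torch import nn

opt = {'networks': {'generator': {'after': 0},       # trains from the start
                    'flow_net':  {'after': 10000}}}  # frozen until step 10k

nets = {'generator': nn.Linear(4, 4), 'flow_net': nn.Linear(4, 4)}
nets_to_train = ['generator', 'flow_net']

def apply_after_gating(step):
    for name, net in nets.items():
        net_enabled = name in nets_to_train
        # Mirrors the ExtensibleTrainer change: disable training before 'after'.
        if 'after' in opt['networks'][name] and step < opt['networks'][name]['after']:
            net_enabled = False
        for p in net.parameters():
            p.requires_grad = net_enabled

apply_after_gating(step=500)
assert nets['generator'].weight.requires_grad
assert not nets['flow_net'].weight.requires_grad
```

`do_step` applies the same rule on the optimizer side, taking `max(after, after_network)` so either the step's optimizer config or the network definition can delay stepping.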
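BasicResamplingFlowNet predicts a 2-channel flow field at four scales, warps the correspondingly downsampled `left` frame with it, and returns the warped frames (`flos`) alongside matching downsamples of `right` (`compares`) so a loss can penalize their difference per level. `Resample2d` is flownet2's CUDA warping op; the sketch below substitutes an `F.grid_sample`-based warp purely for illustration, assuming absolute pixel-offset flow with channel 0 as the x-displacement:

```python
import torch
import torch.nn.functional as F

def warp(img, flow):
    b, _, h, w = img.shape
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    # Base pixel grid plus flow offsets, channel-last for grid_sample.
    grid = torch.stack([xs, ys], dim=-1).float().unsqueeze(0) + flow.permute(0, 2, 3, 1)
    # Normalize pixel coordinates to [-1, 1] as grid_sample expects.
    grid[..., 0] = 2.0 * grid[..., 0] / max(w - 1, 1) - 1.0
    grid[..., 1] = 2.0 * grid[..., 1] / max(h - 1, 1) - 1.0
    return F.grid_sample(img, grid, padding_mode='border', align_corners=True)

left, right = torch.rand(1, 3, 64, 64), torch.rand(1, 3, 64, 64)
flow = torch.zeros(1, 2, 64, 64)   # zero flow -> warp is the identity
warped = warp(left, flow)
loss = F.l1_loss(warped, right)    # the kind of per-level comparison enabled here
assert torch.allclose(warped, left, atol=1e-5)
```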
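On the options side, the new generator is selected through the usual `which_model_G` dispatch in networks.py. A hypothetical network definition, written as the parsed Python dict that `define_G` receives (only `nf` and `resample_scale` are read by the new branch; the values here are made up):

```python
# Hypothetical parsed network definition for the new arch; values are examples.
opt_net = {
    'which_model_G': 'basic_resampling_flow_predictor',
    'nf': 32,             # base filter count for the pyramid
    'resample_scale': 1,  # scales the predicted flow before resampling
}
# define_G would then construct:
#   netG = BasicResamplingFlowNet(opt_net['nf'], resample_scale=opt_net['resample_scale'])
```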
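Finally, the process_video.py change switches `force_multiple` handling from cropping pixels off the bottom/right to zero-padding up to the next multiple, so no image content is lost and output dimensions still land on the required multiple. A self-contained sketch of the same logic (`pad_to_multiple` is an illustrative name, not a function in the repo):

```python
import torch

def pad_to_multiple(img, multiple):
    c, h, w = img.shape
    # Round each dimension up to the next multiple, as the dataset now does.
    h_ = h if h % multiple == 0 else multiple * (h // multiple + 1)
    w_ = w if w % multiple == 0 else multiple * (w // multiple + 1)
    template = torch.zeros(c, h_, w_)
    template[:, :h, :w] = img  # original content sits in the top-left corner
    return template

x = torch.rand(3, 270, 480)
y = pad_to_multiple(x, 32)
assert y.shape == (3, 288, 480)  # 270 -> 288; 480 is already a multiple of 32
```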