diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py index cddd1c4f..29c1718f 100644 --- a/codes/models/archs/RRDBNet_arch.py +++ b/codes/models/archs/RRDBNet_arch.py @@ -6,7 +6,7 @@ import torch.nn.functional as F import torchvision from torch.utils.checkpoint import checkpoint_sequential -from models.archs.arch_util import make_layer, default_init_weights, ConvGnSilu +from models.archs.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu class ResidualDenseBlock(nn.Module): @@ -60,11 +60,16 @@ class RRDB(nn.Module): growth_channels (int): Channels for each growth. """ - def __init__(self, mid_channels, growth_channels=32): + def __init__(self, mid_channels, growth_channels=32, reduce_to=None): super(RRDB, self).__init__() self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels) self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels) self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels) + if reduce_to is not None: + self.reducer = ConvGnLelu(mid_channels, reduce_to, kernel_size=3, activation=False, norm=False, bias=True) + self.recover_ch = mid_channels - reduce_to + else: + self.reducer = None def forward(self, x): """Forward function. @@ -78,6 +83,10 @@ class RRDB(nn.Module): out = self.rdb1(x) out = self.rdb2(out) out = self.rdb3(out) + if self.reducer is not None: + out = self.reducer(out) + b, f, h, w = out.shape + out = torch.cat([out, torch.zeros((b, self.recover_ch, h, w), device=out.device)], dim=1) # Emperically, we use 0.2 to scale the residual for better performance return out * 0.2 + x @@ -92,12 +101,19 @@ class RRDBWithBypass(nn.Module): growth_channels (int): Channels for each growth. """ - def __init__(self, mid_channels, growth_channels=32): + def __init__(self, mid_channels, growth_channels=32, reduce_to=None): super(RRDBWithBypass, self).__init__() self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels) self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels) self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels) - self.bypass = nn.Sequential(ConvGnSilu(mid_channels*2, mid_channels, kernel_size=3, bias=True, activation=True, norm=True), + if reduce_to is not None: + self.reducer = ConvGnLelu(mid_channels, reduce_to, kernel_size=3, activation=False, norm=False, bias=True) + self.recover_ch = mid_channels - reduce_to + bypass_channels = mid_channels + reduce_to + else: + self.reducer = None + bypass_channels = mid_channels * 2 + self.bypass = nn.Sequential(ConvGnSilu(bypass_channels, mid_channels, kernel_size=3, bias=True, activation=True, norm=True), ConvGnSilu(mid_channels, mid_channels//2, kernel_size=3, bias=False, activation=True, norm=False), ConvGnSilu(mid_channels//2, 1, kernel_size=3, bias=False, activation=False, norm=False), nn.Sigmoid()) @@ -114,8 +130,15 @@ class RRDBWithBypass(nn.Module): out = self.rdb1(x) out = self.rdb2(out) out = self.rdb3(out) + + if self.reducer is not None: + out = self.reducer(out) + b, f, h, w = out.shape + out = torch.cat([out, torch.zeros((b, self.recover_ch, h, w), device=out.device)], dim=1) + bypass = self.bypass(torch.cat([x, out], dim=1)) self.bypass_map = bypass.detach().clone() + # Empirically, we use 0.2 to scale the residual for better performance return out * 0.2 * bypass + x @@ -143,30 +166,45 @@ class RRDBNet(nn.Module): num_blocks=23, growth_channels=32, body_block=RRDB, - blocks_per_checkpoint=4, + blocks_per_checkpoint=1, scale=4, - additive_mode="not_additive" # Options: "not", "additive", "additive_enforced" + additive_mode="not", # Options: "not", "additive", "additive_enforced" + headless=False, + feature_channels=64, # Only applicable when headless=True. How many channels are used at the trunk level. + output_mode="hq_only", # Options: "hq_only", "hq+features", "features_only" ): super(RRDBNet, self).__init__() + assert output_mode in ['hq_only', 'hq+features', 'features_only'] + assert additive_mode in ['not', 'additive', 'additive_enforced'] self.num_blocks = num_blocks self.blocks_per_checkpoint = blocks_per_checkpoint self.scale = scale self.in_channels = in_channels + self.output_mode = output_mode first_conv_stride = 1 if in_channels <= 4 else scale first_conv_ksize = 3 if first_conv_stride == 1 else 7 first_conv_padding = 1 if first_conv_stride == 1 else 3 - self.conv_first = nn.Conv2d(in_channels, mid_channels, first_conv_ksize, first_conv_stride, first_conv_padding) + if headless: + self.conv_first = None + self.reduce_ch = feature_channels + reduce_to = feature_channels + self.conv_ref_first = ConvGnLelu(3, feature_channels, 7, stride=2, norm=False, activation=False, bias=True) + else: + self.conv_first = nn.Conv2d(in_channels, mid_channels, first_conv_ksize, first_conv_stride, first_conv_padding) + self.reduce_ch = mid_channels + reduce_to = None self.body = make_layer( body_block, num_blocks, mid_channels=mid_channels, - growth_channels=growth_channels) - self.conv_body = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1) + growth_channels=growth_channels, + reduce_to=reduce_to) + self.conv_body = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1) # upsample - self.conv_up1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1) - self.conv_up2 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1) - self.conv_hr = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1) - self.conv_last = nn.Conv2d(mid_channels, out_channels, 3, 1, 1) + self.conv_up1 = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1) + self.conv_up2 = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1) + self.conv_hr = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1) + self.conv_last = nn.Conv2d(self.reduce_ch, out_channels, 3, 1, 1) self.additive_mode = additive_mode if additive_mode == "additive_enforced": @@ -178,7 +216,8 @@ class RRDBNet(nn.Module): self.conv_first, self.conv_body, self.conv_up1, self.conv_up2, self.conv_hr, self.conv_last ]: - default_init_weights(m, 0.1) + if m is not None: + default_init_weights(m, 0.1) def forward(self, x, ref=None): """Forward function. @@ -189,25 +228,39 @@ class RRDBNet(nn.Module): Returns: Tensor: Forward results. """ - if self.in_channels > 4: - x_lg = F.interpolate(x, scale_factor=self.scale, mode="bicubic") - if ref is None: - ref = torch.zeros_like(x_lg) - x_lg = torch.cat([x_lg, ref], dim=1) + if self.conv_first is None: + # Headless mode -> embedding inputs. + if ref is not None: + ref = self.conv_ref_first(ref) + feat = torch.cat([x, ref], dim=1) + else: + feat = x else: - x_lg = x - feat = self.conv_first(x_lg) - body_feat = self.conv_body(checkpoint_sequential(self.body, self.num_blocks // self.blocks_per_checkpoint, feat)) + # "Normal" mode -> image input. + if self.in_channels > 4: + x_lg = F.interpolate(x, scale_factor=self.scale, mode="bicubic") + if ref is None: + ref = torch.zeros_like(x_lg) + x_lg = torch.cat([x_lg, ref], dim=1) + else: + x_lg = x + feat = self.conv_first(x_lg) + feat = checkpoint_sequential(self.body, self.num_blocks // self.blocks_per_checkpoint, feat) + feat = feat[:, :self.reduce_ch] + body_feat = self.conv_body(feat) feat = feat + body_feat + if self.output_mode == "features_only": + return feat + # upsample - feat = self.lrelu( + out = self.lrelu( self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) if self.scale == 4: - feat = self.lrelu( - self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) + out = self.lrelu( + self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest'))) else: - feat = self.lrelu(self.conv_up2(feat)) - out = self.conv_last(self.lrelu(self.conv_hr(feat))) + out = self.lrelu(self.conv_up2(out)) + out = self.conv_last(self.lrelu(self.conv_hr(out))) if "additive" in self.additive_mode: x_interp = F.interpolate(x, scale_factor=self.scale, mode='bilinear') if self.additive_mode == 'additive': @@ -216,9 +269,14 @@ class RRDBNet(nn.Module): out_pooled = self.add_enforced_pool(out) out = out - F.interpolate(out_pooled, scale_factor=self.scale, mode='nearest') out = out + x_interp + + if self.output_mode == "hq+features": + return out, feat return out def visual_dbg(self, step, path): for i, bm in enumerate(self.body): if hasattr(bm, 'bypass_map'): torchvision.utils.save_image(bm.bypass_map.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1))) + + diff --git a/codes/models/networks.py b/codes/models/networks.py index 41aa7ab2..26a35392 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -39,14 +39,17 @@ def define_G(opt, opt_net, scale=None): nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale']) elif which_model == 'RRDBNet': additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not_additive' + output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only' netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'], - mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode) + mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode, + output_mode=output_mode) elif which_model == 'RRDBNetBypass': additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not' + output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only' netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'], mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], body_block=RRDBNet_arch.RRDBWithBypass, blocks_per_checkpoint=opt_net['blocks_per_checkpoint'], scale=opt_net['scale'], - additive_mode=additive_mode) + additive_mode=additive_mode, output_mode=output_mode) elif which_model == 'rcan': #args: n_resgroups, n_resblocks, res_scale, reduction, scale, n_feats opt_net['rgb_range'] = 255 @@ -110,8 +113,6 @@ def define_G(opt, opt_net, scale=None): netG = SwitchedGen_arch.BackboneResnet() elif which_model == "tecogen": netG = TecoGen(opt_net['nf'], opt_net['scale']) - elif which_model == "basic_resampling_flow_predictor": - netG = BasicResamplingFlowNet(opt_net['nf'], resample_scale=opt_net['resample_scale']) elif which_model == "rrdb_with_latent": netG = RRDBNetWithLatent(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'], mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], @@ -153,6 +154,11 @@ def define_G(opt, opt_net, scale=None): netG = RRDBLatentWrapper(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'], with_bypass=opt_net['with_bypass'], blocks=opt_net['blocks_for_latent'], scale=opt_net['scale'], pretrain_rrdb_path=opt_net['pretrain_path']) + elif which_model == 'rrdb_centipede': + output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only' + netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'], + mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], scale=opt_net['scale'], + headless=True, output_mode=output_mode) else: raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model)) return netG diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py index 06165aba..6279958f 100644 --- a/codes/models/steps/injectors.py +++ b/codes/models/steps/injectors.py @@ -285,6 +285,7 @@ class ForEachInjector(Injector): o['in'] = '_in' o['out'] = '_out' self.injector = create_injector(o, self.env) + self.aslist = opt['aslist'] if 'aslist' in opt.keys() else False def forward(self, state): injs = [] @@ -293,7 +294,10 @@ class ForEachInjector(Injector): for i in range(inputs.shape[1]): st['_in'] = inputs[:, i] injs.append(self.injector(st)['_out']) - return {self.output: torch.stack(injs, dim=1)} + if self.aslist: + return {self.output: injs} + else: + return {self.output: torch.stack(injs, dim=1)} class ConstantInjector(Injector): diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py index 448d5c09..860bbbe7 100644 --- a/codes/models/steps/steps.py +++ b/codes/models/steps/steps.py @@ -140,7 +140,7 @@ class ConfigurableStep(Module): # Don't do injections tagged with 'after' or 'before' when we are out of spec. if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \ 'before' in inj.opt.keys() and self.env['step'] > inj.opt['before'] or \ - 'every' in inj.opt.keys() and self.env['step'] % inj.opt['every'] != 0: + 'every' in inj.opt.keys() and self.env['step'] % inj.opt['every'] != 0: continue injected = inj(local_state) local_state.update(injected) diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/steps/tecogan_losses.py index 16e82413..70748328 100644 --- a/codes/models/steps/tecogan_losses.py +++ b/codes/models/steps/tecogan_losses.py @@ -44,10 +44,12 @@ class RecurrentImageGeneratorSequenceInjector(Injector): super(RecurrentImageGeneratorSequenceInjector, self).__init__(opt, env) self.flow = opt['flow_network'] self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0 - self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0 self.recurrent_index = opt['recurrent_index'] + self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0 + self.output_recurrent_index = opt['output_recurrent_index'] if 'output_recurrent_index' in opt.keys() else self.output_hq_index self.scale = opt['scale'] self.resample = Resample2d() + self.flow_key = opt['flow_input_key'] if 'flow_input_key' in opt.keys() else None self.first_inputs = opt['first_inputs'] if 'first_inputs' in opt.keys() else opt['in'] # Use this to specify inputs that will be used in the first teco iteration, the rest will use 'in'. self.do_backwards = opt['do_backwards'] if 'do_backwards' in opt.keys() else True self.hq_recurrent = opt['hq_recurrent'] if 'hq_recurrent' in opt.keys() else False # When True, recurrent_index is not touched for the first iteration, allowing you to specify what is fed in. When False, zeros are fed into the recurrent index. @@ -82,20 +84,21 @@ class RecurrentImageGeneratorSequenceInjector(Injector): else: input = extract_inputs_index(inputs, i) with torch.no_grad() and autocast(enabled=False): - # This is a hack to workaround the fact that flownet2 cannot operate at resolutions < 64px. An assumption is - # made here that if you are operating at 4x scale, your inputs are 32px x 32px - if self.scale >= 4: - flow_input = F.interpolate(input[self.input_lq_index], scale_factor=self.scale//2, mode='bicubic') + if self.flow_key is not None: + flow_input = state[self.flow_key][:, i] else: flow_input = input[self.input_lq_index] - reduced_recurrent = F.interpolate(recurrent_input, scale_factor=.5, mode='bicubic') + reduced_recurrent = F.interpolate(hq_recurrent, scale_factor=1/self.scale, mode='bicubic') flow_input = torch.stack([flow_input, reduced_recurrent], dim=2).float() - flowfield = F.interpolate(flow(flow_input), scale_factor=2, mode='bicubic') + flowfield = flow(flow_input) + if recurrent_input.shape[-1] != flow_input.shape[-1]: + flowfield = F.interpolate(flowfield, scale_factor=self.scale, mode='bicubic') recurrent_input = self.resample(recurrent_input.float(), flowfield) input[self.recurrent_index] = recurrent_input if self.env['step'] % 50 == 0: - self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index) - debug_index += 1 + if input[self.input_lq_index].shape[1] == 3: # Only debug this if we're dealing with images. + self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.hq_recurrent], debug_index) + debug_index += 1 with autocast(enabled=self.env['opt']['fp16']): gen_out = gen(*input) @@ -104,7 +107,8 @@ class RecurrentImageGeneratorSequenceInjector(Injector): gen_out = [gen_out] for i, out_key in enumerate(self.output): results[out_key].append(gen_out[i]) - recurrent_input = gen_out[self.output_hq_index] + hq_recurrent = gen_out[self.output_hq_index] + recurrent_input = gen_out[self.output_recurrent_index] # Now go backwards, skipping the last element (it's already stored in recurrent_input) if self.do_backwards: @@ -113,20 +117,21 @@ class RecurrentImageGeneratorSequenceInjector(Injector): input = extract_inputs_index(inputs, i) with torch.no_grad(): with autocast(enabled=False): - # This is a hack to workaround the fact that flownet2 cannot operate at resolutions < 64px. An assumption is - # made here that if you are operating at 4x scale, your inputs are 32px x 32px - if self.scale >= 4: - flow_input = F.interpolate(input[self.input_lq_index], scale_factor=self.scale//2, mode='bicubic') + if self.flow_key is not None: + flow_input = state[self.flow_key][:, i] else: flow_input = input[self.input_lq_index] - reduced_recurrent = F.interpolate(recurrent_input, scale_factor=.5, mode='bicubic') + reduced_recurrent = F.interpolate(hq_recurrent, scale_factor=1/self.scale, mode='bicubic') flow_input = torch.stack([flow_input, reduced_recurrent], dim=2).float() - flowfield = F.interpolate(flow(flow_input), scale_factor=2, mode='bicubic') + flowfield = flow(flow_input) + if recurrent_input.shape[-1] != flow_input.shape[-1]: + flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic') recurrent_input = self.resample(recurrent_input.float(), flowfield) input[self.recurrent_index] = recurrent_input if self.env['step'] % 50 == 0: - self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index) - debug_index += 1 + if input[self.input_lq_index].shape[1] == 3: # Only debug this if we're dealing with images. + self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index) + debug_index += 1 with autocast(enabled=self.env['opt']['fp16']): gen_out = gen(*input) @@ -135,7 +140,8 @@ class RecurrentImageGeneratorSequenceInjector(Injector): gen_out = [gen_out] for i, out_key in enumerate(self.output): results[out_key].append(gen_out[i]) - recurrent_input = gen_out[self.output_hq_index] + hq_recurrent = gen_out[self.output_hq_index] + recurrent_input = gen_out[self.output_recurrent_index] final_results = {} # Include 'hq_batched' here - because why not... Don't really need a separate injector for this. diff --git a/codes/train.py b/codes/train.py index 38c1bf2b..251a079d 100644 --- a/codes/train.py +++ b/codes/train.py @@ -291,7 +291,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_for_sr_v2.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_using_rrdb_features.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args()