From 1e0f69e34b0f8c42c7b0a827186c2f6bb16ea06e Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 29 Nov 2020 15:39:50 -0700 Subject: [PATCH] extra_conv in gn discriminator, multiframe support in rrdb. --- codes/models/archs/RRDBNet_arch.py | 6 ++- codes/models/archs/discriminator_vgg_arch.py | 14 +++++- codes/models/networks.py | 4 +- codes/models/steps/injectors.py | 51 ++++++++++++++++++++ codes/train2.py | 2 +- 5 files changed, 72 insertions(+), 5 deletions(-) diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py index 7c17f17a..f0ee73ed 100644 --- a/codes/models/archs/RRDBNet_arch.py +++ b/codes/models/archs/RRDBNet_arch.py @@ -177,6 +177,7 @@ class RRDBNet(nn.Module): feature_channels=64, # Only applicable when headless=True. How many channels are used at the trunk level. output_mode="hq_only", # Options: "hq_only", "hq+features", "features_only" initial_stride=1, + use_ref=False, # When set, a reference image is expected as input and synthesized if not found. Useful for video SR. ): super(RRDBNet, self).__init__() assert output_mode in ['hq_only', 'hq+features', 'features_only'] @@ -186,7 +187,8 @@ class RRDBNet(nn.Module): self.scale = scale self.in_channels = in_channels self.output_mode = output_mode - first_conv_stride = initial_stride if in_channels <= 4 else scale + self.use_ref = use_ref + first_conv_stride = initial_stride if not self.use_ref else scale first_conv_ksize = 3 if first_conv_stride == 1 else 7 first_conv_padding = 1 if first_conv_stride == 1 else 3 if headless: @@ -242,7 +244,7 @@ class RRDBNet(nn.Module): feat = x else: # "Normal" mode -> image input. - if self.in_channels > 4: + if self.use_ref: x_lg = F.interpolate(x, scale_factor=self.scale, mode="bicubic") if ref is None: ref = torch.zeros_like(x_lg) diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py index eadcf559..67755d9a 100644 --- a/codes/models/archs/discriminator_vgg_arch.py +++ b/codes/models/archs/discriminator_vgg_arch.py @@ -83,7 +83,7 @@ class Discriminator_VGG_128(nn.Module): class Discriminator_VGG_128_GN(nn.Module): # input_img_factor = multiplier to support images over 128x128. Only certain factors are supported. - def __init__(self, in_nc, nf, input_img_factor=1, do_checkpointing=False): + def __init__(self, in_nc, nf, input_img_factor=1, do_checkpointing=False, extra_conv=False): super(Discriminator_VGG_128_GN, self).__init__() self.do_checkpointing = do_checkpointing @@ -111,6 +111,14 @@ class Discriminator_VGG_128_GN(nn.Module): self.bn4_0 = nn.GroupNorm(8, nf * 8, affine=True) self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) self.bn4_1 = nn.GroupNorm(8, nf * 8, affine=True) + + self.extra_conv = extra_conv + if extra_conv: + self.conv5_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) + self.bn5_0 = nn.GroupNorm(8, nf * 8, affine=True) + self.conv5_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) + self.bn5_1 = nn.GroupNorm(8, nf * 8, affine=True) + input_img_factor = input_img_factor // 2 final_nf = nf * 8 # activation function @@ -136,6 +144,10 @@ class Discriminator_VGG_128_GN(nn.Module): fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) fea = self.lrelu(self.bn4_1(self.conv4_1(fea))) + + if self.extra_conv: + fea = self.lrelu(self.bn5_0(self.conv5_0(fea))) + fea = self.lrelu(self.bn5_1(self.conv5_1(fea))) return fea def forward(self, x): diff --git a/codes/models/networks.py b/codes/models/networks.py index 9ce6d909..c51a0336 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -196,7 +196,9 @@ def define_D_net(opt_net, img_sz=None, wrap=False): if which_model == 'discriminator_vgg_128': netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128, extra_conv=opt_net['extra_conv']) elif which_model == 'discriminator_vgg_128_gn': - netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128) + extra_conv = opt_net['extra_conv'] if 'extra_conv' in opt_net.keys() else False + netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'], + input_img_factor=img_sz / 128, extra_conv=extra_conv) if wrap: netD = GradDiscWrapper(netD) elif which_model == 'discriminator_vgg_128_gn_checkpointed': diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py index 6279958f..cc9f6e54 100644 --- a/codes/models/steps/injectors.py +++ b/codes/models/steps/injectors.py @@ -56,6 +56,8 @@ def create_injector(opt_inject, env): return BatchRotateInjector(opt_inject, env) elif type == 'sr_diffs': return SrDiffsInjector(opt_inject, env) + elif type == 'multiframe_combiner': + return MultiFrameCombiner(opt_inject, env) else: raise NotImplementedError @@ -419,3 +421,52 @@ class SrDiffsInjector(Injector): elif self.mode == 'recombine': combined = resampled_lq + hq return {self.output: combined} + + +class MultiFrameCombiner(Injector): + def __init__(self, opt, env): + super().__init__(opt, env) + self.mode = opt['mode'] + self.dim = opt['dim'] if 'dim' in opt.keys() else None + self.flow = opt['flow'] + self.in_lq_key = opt['in'] + self.in_hq_key = opt['in_hq'] + self.out_lq_key = opt['out'] + self.out_hq_key = opt['out_hq'] + from models.flownet2.networks.resample2d_package.resample2d import Resample2d + self.resampler = Resample2d() + + def combine(self, state): + flow = self.env['generators'][self.flow] + lq = state[self.in_lq_key] + hq = state[self.in_hq_key] + b, f, c, h, w = lq.shape + center = f // 2 + center_img = lq[:,center,:,:,:] + imgs = [center_img] + with torch.no_grad(): + for i in range(f): + if i == center: + continue + nimg = lq[:,i,:,:,:] + flowfield = flow(torch.stack([center_img, nimg], dim=2).float()) + nimg = self.resampler(nimg, flowfield) + imgs.append(nimg) + hq_out = hq[:,center,:,:,:] + return {self.out_lq_key: torch.cat(imgs, dim=1), + self.out_hq_key: hq_out, + self.out_lq_key + "_flow_sample": torch.cat(imgs, dim=0)} + + def synthesize(self, state): + lq = state[self.in_lq_key] + return { + self.out_lq_key: lq.repeat(1, self.dim, 1, 1) + } + + def forward(self, state): + if self.mode == "synthesize": + return self.synthesize(state) + elif self.mode == "combine": + return self.combine(state) + else: + raise NotImplementedError diff --git a/codes/train2.py b/codes/train2.py index 10a79d7e..c8ffb242 100644 --- a/codes/train2.py +++ b/codes/train2.py @@ -291,7 +291,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_exd_imgsetext_rrdb4x_6bl_2stride/train_exd_imgsetext_rrdb4x_6bl_2stride.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_rrdb4x_2stride_multiframe.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args()