From f406a5dd4c820e20a0f750467c504892bc252385 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 13 Nov 2020 20:11:50 -0700
Subject: [PATCH] Mods to support stylegan2 in SR mode

---
 codes/models/archs/stylegan2.py | 38 ++++++++++++++++++++++-----------
 codes/models/networks.py        |  5 +++--
 codes/models/steps/injectors.py | 24 +++++++++++++++++++++
 codes/models/steps/losses.py    | 11 ++++++++--
 codes/train2.py                 |  2 +-
 5 files changed, 62 insertions(+), 18 deletions(-)

diff --git a/codes/models/archs/stylegan2.py b/codes/models/archs/stylegan2.py
index 30d36bb3..be6004f2 100644
--- a/codes/models/archs/stylegan2.py
+++ b/codes/models/archs/stylegan2.py
@@ -403,10 +403,15 @@ class Conv2DMod(nn.Module):
 
 
 class GeneratorBlock(nn.Module):
-    def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False):
+    def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False, structure_input=False):
         super().__init__()
         self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None
 
+        self.structure_input = structure_input
+        if self.structure_input:
+            self.structure_conv = nn.Conv2d(3, input_channels, 3, padding=1)
+            input_channels = input_channels * 2
+
         self.to_style1 = nn.Linear(latent_dim, input_channels)
         self.to_noise1 = nn.Linear(1, filters)
         self.conv1 = Conv2DMod(input_channels, filters, 3)
@@ -418,10 +423,15 @@ class GeneratorBlock(nn.Module):
         self.activation = leaky_relu()
         self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba)
 
-    def forward(self, x, prev_rgb, istyle, inoise):
+    def forward(self, x, prev_rgb, istyle, inoise, structure_input=None):
         if exists(self.upsample):
             x = self.upsample(x)
 
+        if self.structure_input:
+            s = torch.nn.functional.interpolate(structure_input, size=x.shape[2:], mode="nearest")
+            s = self.structure_conv(s)
+            x = torch.cat([x, s], dim=1)
+
         inoise = inoise[:, :x.shape[2], :x.shape[3], :]
         noise1 = self.to_noise1(inoise).permute((0, 3, 2, 1))
         noise2 = self.to_noise2(inoise).permute((0, 3, 2, 1))
@@ -466,7 +476,7 @@ class DiscriminatorBlock(nn.Module):
 
 class Generator(nn.Module):
     def __init__(self, image_size, latent_dim, network_capacity=16, transparent=False, attn_layers=[], no_const=False,
-                 fmap_max=512):
+                 fmap_max=512, structure_input=False):
         super().__init__()
         self.image_size = image_size
         self.latent_dim = latent_dim
@@ -506,11 +516,12 @@ class Generator(nn.Module):
                 out_chan,
                 upsample=not_first,
                 upsample_rgb=not_last,
-                rgba=transparent
+                rgba=transparent,
+                structure_input=structure_input
             )
             self.blocks.append(block)
 
-    def forward(self, styles, input_noise):
+    def forward(self, styles, input_noise, structure_input=None):
         batch_size = styles.shape[0]
         image_size = self.image_size
 
@@ -527,17 +538,19 @@ class Generator(nn.Module):
         for style, block, attn in zip(styles, self.blocks, self.attns):
             if exists(attn):
                 x = attn(x)
-            x, rgb = checkpoint(block, x, rgb, style, input_noise)
+            x, rgb = checkpoint(block, x, rgb, style, input_noise, structure_input)
 
         return rgb
 
 
 # Wrapper that combines style vectorizer with the actual generator.
 class StyleGan2GeneratorWithLatent(nn.Module):
-    def __init__(self, image_size, latent_dim=512, style_depth=8, lr_mlp=.1, network_capacity=16, transparent=False, attn_layers=[], no_const=False, fmap_max=512):
+    def __init__(self, image_size, latent_dim=512, style_depth=8, lr_mlp=.1, network_capacity=16, transparent=False,
+                 attn_layers=[], no_const=False, fmap_max=512, structure_input=False):
         super().__init__()
         self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp)
-        self.gen = Generator(image_size, latent_dim, network_capacity, transparent, attn_layers, no_const, fmap_max)
+        self.gen = Generator(image_size, latent_dim, network_capacity, transparent, attn_layers, no_const, fmap_max,
+                             structure_input=structure_input)
         self.mixed_prob = .9
 
         self._init_weights()
@@ -559,7 +572,7 @@ class StyleGan2GeneratorWithLatent(nn.Module):
 
     # To use per the stylegan paper, input should be uniform noise. This gen takes it in as a normal "image" format:
     # b,f,h,w.
-    def forward(self, x):
+    def forward(self, x, structure_input=None):
         b, f, h, w = x.shape
 
         full_random_latents = True
@@ -583,7 +596,7 @@ class StyleGan2GeneratorWithLatent(nn.Module):
         w_styles = self.styles_def_to_tensor(w_space)
 
         # The underlying model expects the noise as b,h,w,1. Make it so.
-        return self.gen(w_styles, x[:,0,:,:].unsqueeze(dim=3)), w_styles
+        return self.gen(w_styles, x[:,0,:,:].unsqueeze(dim=3), structure_input), w_styles
 
     def _init_weights(self):
         for m in self.modules():
@@ -599,13 +612,12 @@ class StyleGan2GeneratorWithLatent(nn.Module):
 
 
 class StyleGan2Discriminator(nn.Module):
     def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[],
-                 transparent=False, fmap_max=512):
+                 transparent=False, fmap_max=512, input_filters=3):
         super().__init__()
         num_layers = int(log2(image_size) - 1)
-        num_init_filters = 3 if not transparent else 4
 
         blocks = []
-        filters = [num_init_filters] + [(64) * (2 ** i) for i in range(num_layers + 1)]
+        filters = [input_filters] + [(64) * (2 ** i) for i in range(num_layers + 1)]
         set_fmap_max = partial(min, fmap_max)
         filters = list(map(set_fmap_max, filters))
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 4c6a26ea..ed015203 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -133,8 +133,9 @@ def define_G(opt, net_key='network_G', scale=None):
     elif which_model == "linear_latent_estimator":
         netG = LinearLatentEstimator(in_nc=3, nf=opt_net['nf'])
     elif which_model == 'stylegan2':
+        is_structured = opt_net['structured'] if 'structured' in opt_net.keys() else False
         netG = StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'],
-                                            style_depth=opt_net['style_depth'])
+                                            style_depth=opt_net['style_depth'], structure_input=is_structured)
     else:
         raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
     return netG
@@ -194,7 +195,7 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
     elif which_model == "pyramid_disc":
         netD = SRGAN_arch.PyramidDiscriminator(in_nc=3, nf=opt_net['nf'])
     elif which_model == "stylegan2_discriminator":
-        disc = StyleGan2Discriminator(image_size=opt_net['image_size'])
+        disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'])
         netD = StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
     else:
         raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py
index 272e7ec9..5d584980 100644
--- a/codes/models/steps/injectors.py
+++ b/codes/models/steps/injectors.py
@@ -54,6 +54,8 @@ def create_injector(opt_inject, env):
         return PsnrInjector(opt_inject, env)
     elif type == 'batch_rotate':
         return BatchRotateInjector(opt_inject, env)
+    elif type == 'sr_diffs':
+        return SrDiffsInjector(opt_inject, env)
     else:
         raise NotImplementedError
 
@@ -379,3 +381,25 @@ class BatchRotateInjector(Injector):
         img = state[self.input]
         return {self.output: torch.roll(img, 1, 0)}
 
+
+# Injector used to work with image deltas used in diff-SR
+class SrDiffsInjector(Injector):
+    def __init__(self, opt, env):
+        super(SrDiffsInjector, self).__init__(opt, env)
+        self.mode = opt['mode']
+        assert self.mode in ['recombine', 'produce_diff']
+        self.lq = opt['lq']
+        self.hq = opt['hq']
+        if self.mode == 'produce_diff':
+            self.diff_key = opt['diff']
+
+    def forward(self, state):
+        resampled_lq = state[self.lq]
+        hq = state[self.hq]
+        if self.mode == 'produce_diff':
+            diff = hq - resampled_lq
+            return {self.output: torch.cat([resampled_lq, diff], dim=1),
+                    self.diff_key: diff}
+        elif self.mode == 'recombine':
+            combined = resampled_lq + hq
+            return {self.output: combined}
diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py
index 2b696834..878a60f5 100644
--- a/codes/models/steps/losses.py
+++ b/codes/models/steps/losses.py
@@ -497,14 +497,21 @@ class StyleGan2DivergenceLoss(ConfigurableLoss):
         self.discriminator = opt['discriminator']
         self.for_gen = opt['gen_loss']
         self.gp_frequency = opt['gradient_penalty_frequency']
+        self.noise = opt['noise'] if 'noise' in opt.keys() else 0
 
     def forward(self, net, state):
+        real_input = state[self.real]
+        fake_input = state[self.fake]
+        if self.noise != 0:
+            fake_input = fake_input + torch.rand_like(fake_input) * self.noise
+            real_input = real_input + torch.rand_like(real_input) * self.noise
+
         D = self.env['discriminators'][self.discriminator]
-        fake = D(state[self.fake])
+        fake = D(fake_input)
         if self.for_gen:
             return fake.mean()
         else:
-            real_input = state[self.real].requires_grad_()  # <-- Needed to compute gradients on the input.
+            real_input.requires_grad_()  # <-- Needed to compute gradients on the input.
             real = D(real_input)
             divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()
 
diff --git a/codes/train2.py b/codes/train2.py
index 27d34b47..bcc65a45 100644
--- a/codes/train2.py
+++ b/codes/train2.py
@@ -291,7 +291,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_faster.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_for_sr.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
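
For reference, a minimal sketch of the image-delta round trip the new SrDiffsInjector performs. The tensor names and shapes below are illustrative assumptions, not part of the patch; only the produce_diff/recombine arithmetic mirrors the injector's forward().

    import torch

    # Stand-ins for the state tensors the injector reads through its 'lq' and 'hq' keys.
    hq = torch.rand(4, 3, 64, 64)            # ground-truth high-res batch (shape assumed)
    resampled_lq = torch.rand(4, 3, 64, 64)  # low-res input resampled up to HQ size

    # mode == 'produce_diff': stack the resampled LQ with the HQ-LQ delta and
    # also expose the raw delta under the configured 'diff' key.
    diff = hq - resampled_lq
    stacked = torch.cat([resampled_lq, diff], dim=1)  # 6-channel result

    # mode == 'recombine': adding a (possibly generated) delta back onto the
    # resampled LQ recovers a high-res estimate.
    recombined = resampled_lq + diff
    assert torch.allclose(recombined, hq)

The stacked output has twice the channels of the source images, which is presumably why StyleGan2Discriminator now accepts input_filters (wired to opt_net['in_nc'] in networks.py) instead of deriving 3 or 4 channels from the transparent flag.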