From 2225fe6ac222071f9dc3edbae3cc881209a8f7f9 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Mon, 4 Jan 2021 10:57:09 -0700
Subject: [PATCH] Undo lucidrains changes for new discriminator

This "new" code will live in the styledsr directory from now on.
---
 codes/models/stylegan/stylegan2_lucidrains.py | 50 ++-----
 1 file changed, 3 insertions(+), 47 deletions(-)

diff --git a/codes/models/stylegan/stylegan2_lucidrains.py b/codes/models/stylegan/stylegan2_lucidrains.py
index 06517add..f61f2b0f 100644
--- a/codes/models/stylegan/stylegan2_lucidrains.py
+++ b/codes/models/stylegan/stylegan2_lucidrains.py
@@ -317,9 +317,6 @@ class StyleGan2Augmentor(nn.Module):
 
         return self.D(images)
 
-    def network_loaded(self):
-        self.D.network_loaded()
-
 
 # stylegan2 classes
 
@@ -741,7 +738,6 @@ class StyleGan2GeneratorWithLatent(nn.Module):
 class DiscriminatorBlock(nn.Module):
     def __init__(self, input_channels, filters, downsample=True):
         super().__init__()
-        self.filters = filters
         self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))
 
         self.net = nn.Sequential(
@@ -767,7 +763,7 @@ class DiscriminatorBlock(nn.Module):
 
 class StyleGan2Discriminator(nn.Module):
     def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[],
-                 transparent=False, fmap_max=512, input_filters=3, quantize=False, do_checkpointing=False, mlp=False):
+                 transparent=False, fmap_max=512, input_filters=3, quantize=False, do_checkpointing=False):
         super().__init__()
         num_layers = int(log2(image_size) - 1)
 
@@ -809,12 +805,7 @@ class StyleGan2Discriminator(nn.Module):
 
         self.final_conv = nn.Conv2d(chan_last, chan_last, 3, padding=1)
         self.flatten = Flatten()
-        if mlp:
-            self.to_logit = nn.Sequential(nn.Linear(latent_dim, 100),
-                                          leaky_relu(),
-                                          nn.Linear(100, 1))
-        else:
-            self.to_logit = nn.Linear(latent_dim, 1)
+        self.to_logit = nn.Linear(latent_dim, 1)
 
         self._init_weights()
 
@@ -849,38 +840,6 @@ class StyleGan2Discriminator(nn.Module):
             if type(m) in {nn.Conv2d, nn.Linear}:
                 nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
 
-    # Configures the network as partially pre-trained. This means:
-    # 1) The top (high-resolution) `num_blocks` will have their weights re-initialized.
-    # 2) The haed (linear layers) will also have their weights re-initialized
-    # 3) All intermediate blocks will be frozen until step `frozen_until_step`
-    # These settings will be applied after the weights have been loaded (network_loaded())
-    def configure_partial_training(self, bypass_blocks=0, num_blocks=2, frozen_until_step=0):
-        self.bypass_blocks = bypass_blocks
-        self.num_blocks = num_blocks
-        self.frozen_until_step = frozen_until_step
-
-    # Called after the network weights are loaded.
-    def network_loaded(self):
-        if not hasattr(self, 'frozen_until_step'):
-            return
-
-        if self.bypass_blocks > 0:
-            self.blocks = self.blocks[self.bypass_blocks:]
-            self.blocks[0] = DiscriminatorBlock(3, self.blocks[0].filters, downsample=True).to(next(self.parameters()).device)
-
-        reset_blocks = [self.to_logit]
-        for i in range(self.num_blocks):
-            reset_blocks.append(self.blocks[i])
-        for bl in reset_blocks:
-            for m in bl.modules():
-                if type(m) in {nn.Conv2d, nn.Linear}:
-                    nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
-                    for p in m.parameters(recurse=True):
-                        p._NEW_BLOCK = True
-        for p in self.parameters():
-            if not hasattr(p, '_NEW_BLOCK'):
-                p.DO_NOT_TRAIN_UNTIL = self.frozen_until_step
-
 
 class StyleGan2DivergenceLoss(L.ConfigurableLoss):
     def __init__(self, opt, env):
@@ -957,8 +916,5 @@ def register_stylegan2_discriminator(opt_net, opt):
     attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
     disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn,
                                   do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False),
-                                  quantize=opt_get(opt_net, ['quantize'], False),
-                                  mlp=opt_get(opt_net, ['mlp_head'], True))
-    if 'use_partial_pretrained' in opt_net.keys():
-        disc.configure_partial_training(opt_net['bypass_blocks'], opt_net['partial_training_blocks'], opt_net['intermediate_blocks_frozen_until'])
+                                  quantize=opt_get(opt_net, ['quantize'], False))
     return StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
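Note: the removed network_loaded() hook communicated its freeze schedule to the trainer purely through parameter attributes: freshly re-initialized parameters were tagged _NEW_BLOCK, and every other parameter received a DO_NOT_TRAIN_UNTIL step. The trainer-side consumer of those tags is not part of this patch; the sketch below is a minimal guess at how a training loop could honor DO_NOT_TRAIN_UNTIL, and the apply_freeze_schedule helper is hypothetical, not code from this repository.

    import torch.nn as nn

    def apply_freeze_schedule(model: nn.Module, step: int) -> None:
        # Parameters tagged by network_loaded() carry a DO_NOT_TRAIN_UNTIL
        # step; untagged ones (including the re-initialized _NEW_BLOCK
        # parameters) default to 0 and are therefore always trainable.
        for p in model.parameters():
            p.requires_grad = step >= getattr(p, 'DO_NOT_TRAIN_UNTIL', 0)

    # Hypothetical usage at the top of each training iteration:
    disc = nn.Linear(4, 1)                    # stand-in for the discriminator
    disc.weight.DO_NOT_TRAIN_UNTIL = 10000    # as network_loaded() would set it
    apply_freeze_schedule(disc, step=0)       # weight frozen, bias trainable
    apply_freeze_schedule(disc, step=10000)   # schedule expired; all trainable

Tagging the Parameter objects themselves, rather than the modules that hold them, means the schedule would survive the self.blocks re-slicing done for bypass_blocks, since the attributes travel with the tensors.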