diff --git a/codes/models/stylegan/stylegan2_lucidrains.py b/codes/models/stylegan/stylegan2_lucidrains.py index f61f2b0f..a10f890a 100644 --- a/codes/models/stylegan/stylegan2_lucidrains.py +++ b/codes/models/stylegan/stylegan2_lucidrains.py @@ -317,6 +317,9 @@ class StyleGan2Augmentor(nn.Module): return self.D(images) + def network_loaded(self): + self.D.network_loaded() + # stylegan2 classes @@ -738,6 +741,7 @@ class StyleGan2GeneratorWithLatent(nn.Module): class DiscriminatorBlock(nn.Module): def __init__(self, input_channels, filters, downsample=True): super().__init__() + self.filters = filters self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1)) self.net = nn.Sequential( @@ -763,7 +767,7 @@ class DiscriminatorBlock(nn.Module): class StyleGan2Discriminator(nn.Module): def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[], - transparent=False, fmap_max=512, input_filters=3, quantize=False, do_checkpointing=False): + transparent=False, fmap_max=512, input_filters=3, quantize=False, do_checkpointing=False, mlp=False): super().__init__() num_layers = int(log2(image_size) - 1) @@ -805,7 +809,11 @@ class StyleGan2Discriminator(nn.Module): self.final_conv = nn.Conv2d(chan_last, chan_last, 3, padding=1) self.flatten = Flatten() - self.to_logit = nn.Linear(latent_dim, 1) + if mlp: + self.to_logit = nn.Sequential(nn.Linear(latent_dim, 100), + nn.Linear(100, 1)) + else: + self.to_logit = nn.Linear(latent_dim, 1) self._init_weights() @@ -840,6 +848,38 @@ class StyleGan2Discriminator(nn.Module): if type(m) in {nn.Conv2d, nn.Linear}: nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') + # Configures the network as partially pre-trained. This means: + # 1) The top (high-resolution) `num_blocks` will have their weights re-initialized. 
+ # 2) The head (linear layers) will also have their weights re-initialized + # 3) All intermediate blocks will be frozen until step `frozen_until_step` + # These settings will be applied after the weights have been loaded (network_loaded()) + def configure_partial_training(self, bypass_blocks=0, num_blocks=2, frozen_until_step=0): + self.bypass_blocks = bypass_blocks + self.num_blocks = num_blocks + self.frozen_until_step = frozen_until_step + + # Called after the network weights are loaded. + def network_loaded(self): + if not hasattr(self, 'frozen_until_step'): + return + + if self.bypass_blocks > 0: + self.blocks = self.blocks[self.bypass_blocks:] + self.blocks[0] = DiscriminatorBlock(3, self.blocks[0].filters, downsample=True).to(next(self.parameters()).device) + + reset_blocks = [self.to_logit] + for i in range(self.num_blocks): + reset_blocks.append(self.blocks[i]) + for bl in reset_blocks: + for m in bl.modules(): + if type(m) in {nn.Conv2d, nn.Linear}: + nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') + for p in m.parameters(recurse=True): + p._NEW_BLOCK = True + for p in self.parameters(): + if not hasattr(p, '_NEW_BLOCK'): + p.DO_NOT_TRAIN_UNTIL = self.frozen_until_step + class StyleGan2DivergenceLoss(L.ConfigurableLoss): def __init__(self, opt, env): @@ -916,5 +956,8 @@ def register_stylegan2_discriminator(opt_net, opt): attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else [] disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn, do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False), - quantize=opt_get(opt_net, ['quantize'], False)) + quantize=opt_get(opt_net, ['quantize'], False), + mlp=opt_get(opt_net, ['mlp_head'], True)) + if 'use_partial_pretrained' in opt_net.keys(): + disc.configure_partial_training(opt_net['bypass_blocks'], opt_net['partial_training_blocks'], opt_net['intermediate_blocks_frozen_until']) return 
StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability']) diff --git a/codes/scripts/byol_extract_wrapped_model.py b/codes/scripts/byol_extract_wrapped_model.py index 12818d31..0b5147c4 100644 --- a/codes/scripts/byol_extract_wrapped_model.py +++ b/codes/scripts/byol_extract_wrapped_model.py @@ -3,8 +3,8 @@ import torch from models.spinenet_arch import SpineNet if __name__ == '__main__': - pretrained_path = '../../experiments/train_sbyol_512unsupervised_restart/models/48000_generator.pth' - output_path = '../../experiments/spinenet49_imgset_sbyol.pth' + pretrained_path = '../../experiments/byol_discriminator.pth' + output_path = '../../experiments/byol_discriminator_extracted.pth' wrap_key = 'online_encoder.net.' sd = torch.load(pretrained_path) @@ -13,8 +13,8 @@ if __name__ == '__main__': if wrap_key in k: sdo[k.replace(wrap_key, '')] = v - model = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda') - model.load_state_dict(sdo, strict=True) + #model = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda') + #model.load_state_dict(sdo, strict=True) print("Validation succeeded, dumping state dict to output path.") torch.save(sdo, output_path) \ No newline at end of file diff --git a/codes/scripts/byol_resnet_playground.py b/codes/scripts/byol_resnet_playground.py index a5491fd5..c21115e9 100644 --- a/codes/scripts/byol_resnet_playground.py +++ b/codes/scripts/byol_resnet_playground.py @@ -106,7 +106,7 @@ def get_latent_for_img(model, img): def find_similar_latents(model, compare_fn=structural_euc_dist): global layer_hooked_value - img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\poon.jpg' + img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\adrianna_xx.jpg' #img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\nicky_xx.jpg' output_path = '../../results/byol_resnet_similars' os.makedirs(output_path, exist_ok=True) @@ -141,7 
+141,7 @@ def find_similar_latents(model, compare_fn=structural_euc_dist): if __name__ == '__main__': - pretrained_path = '../../experiments/resnet_byol_diffframe_85k.pth' + pretrained_path = '../../experiments/resnet_byol_diffframe_115k.pth' model = resnet50(pretrained=False).to('cuda') sd = torch.load(pretrained_path) resnet_sd = {} diff --git a/codes/train.py b/codes/train.py index 9a8746fa..e57e4152 100644 --- a/codes/train.py +++ b/codes/train.py @@ -295,7 +295,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_discriminator_diffimage.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_stylesr.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py index db25b686..33db1894 100644 --- a/codes/trainer/ExtensibleTrainer.py +++ b/codes/trainer/ExtensibleTrainer.py @@ -209,7 +209,8 @@ class ExtensibleTrainer(BaseModel): if 'after' in self.opt['networks'][name].keys() and step < self.opt['networks'][name]['after']: net_enabled = False for p in net.parameters(): - if p.dtype != torch.int64 and p.dtype != torch.bool and not hasattr(p, "DO_NOT_TRAIN"): + do_not_train_flag = hasattr(p, "DO_NOT_TRAIN") or (hasattr(p, "DO_NOT_TRAIN_UNTIL") and step < p.DO_NOT_TRAIN_UNTIL) + if p.dtype != torch.int64 and p.dtype != torch.bool and not do_not_train_flag: p.requires_grad = net_enabled else: p.requires_grad = False @@ -357,6 +358,8 @@ class ExtensibleTrainer(BaseModel): if self.rank <= 0: logger.info('Loading model for [%s]' % (load_path,)) self.load_network(load_path, net, self.opt['path']['strict_load']) + if hasattr(net.module, 'network_loaded'): + 
net.module.network_loaded() def save(self, iter_step): for name, net in self.networks.items(): diff --git a/codes/trainer/base_model.py b/codes/trainer/base_model.py index 75a5d44d..4d0a6ede 100644 --- a/codes/trainer/base_model.py +++ b/codes/trainer/base_model.py @@ -58,6 +58,7 @@ class BaseModel(): def update_learning_rate(self, cur_iter, warmup_iter=-1): for scheduler in self.schedulers: + scheduler.last_epoch = cur_iter scheduler.step() # set up warm-up learning rate if cur_iter < warmup_iter: