From 8a9f21565335e43e94261a554b352d452fff4665 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sat, 18 Jul 2020 14:18:48 -0600
Subject: [PATCH] Huge set of mods to support progressive generator growth

---
 codes/models/SRGAN_model.py               | 30 ++++++--
 codes/models/archs/ProgressiveSrg_arch.py | 93 ++++++++++++++---
 codes/models/archs/arch_util.py           |  5 +-
 codes/models/lr_scheduler.py              | 20 +++++
 codes/models/networks.py                  | 18 ++---
 codes/train.py                            |  4 +-
 6 files changed, 116 insertions(+), 54 deletions(-)

diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index c73a70a6..697c5575 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -103,12 +103,15 @@ class SRGANModel(BaseModel):
             # G
             wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
             optim_params = []
-            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
-                if v.requires_grad:
-                    optim_params.append(v)
-                else:
-                    if self.rank <= 0:
-                        logger.warning('Params [{:s}] will not optimize.'.format(k))
+            if train_opt['lr_scheme'] == 'ProgressiveMultiStepLR':
+                optim_params = self.netG.get_param_groups()
+            else:
+                for k, v in self.netG.named_parameters():  # can optimize for a part of the model
+                    if v.requires_grad:
+                        optim_params.append(v)
+                    else:
+                        if self.rank <= 0:
+                            logger.warning('Params [{:s}] will not optimize.'.format(k))
             self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                                 weight_decay=wd_G,
                                                 betas=(train_opt['beta1_G'], train_opt['beta2_G']))
@@ -148,6 +151,15 @@ class SRGANModel(BaseModel):
                                                          gamma=train_opt['lr_gamma'],
                                                          clear_state=train_opt['clear_state'],
                                                          force_lr=train_opt['force_lr']))
+            elif train_opt['lr_scheme'] == 'ProgressiveMultiStepLR':
+                # Only supported when there are two optimizers: G and D.
+                assert len(self.optimizers) == 2
+                self.schedulers.append(lr_scheduler.ProgressiveMultiStepLR(self.optimizer_G, train_opt['gen_lr_steps'],
+                                                                           self.netG.module.get_progressive_starts(),
+                                                                           train_opt['lr_gamma']))
+                self.schedulers.append(lr_scheduler.ProgressiveMultiStepLR(self.optimizer_D, train_opt['disc_lr_steps'],
+                                                                           [0],
+                                                                           train_opt['lr_gamma']))
             elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                 for optimizer in self.optimizers:
                     self.schedulers.append(
@@ -492,6 +504,12 @@ class SRGANModel(BaseModel):
             self.add_log_entry('l_d_fake_total', l_d_fake.item() * self.mega_batch_factor)
             self.add_log_entry('l_d_real_total', l_d_real.item() * self.mega_batch_factor)
 
+        # Log learning rates.
+        for i, pg in enumerate(self.optimizer_G.param_groups):
+            self.add_log_entry('gen_lr_%i' % (i,), pg['lr'])
+        for i, pg in enumerate(self.optimizer_D.param_groups):
+            self.add_log_entry('disc_lr_%i' % (i,), pg['lr'])
+
         if step % self.corruptor_swapout_steps == 0 and step > 0:
             self.load_random_corruptor()
 
diff --git a/codes/models/archs/ProgressiveSrg_arch.py b/codes/models/archs/ProgressiveSrg_arch.py
index f7a0c6a0..df09cdce 100644
--- a/codes/models/archs/ProgressiveSrg_arch.py
+++ b/codes/models/archs/ProgressiveSrg_arch.py
@@ -14,7 +14,7 @@ import torch.nn.functional as F
 # 4) start_step will need to get set properly when constructing these models, even when resuming - OR another method needs to be added to resume properly.
 
 class GrowingSRGBase(nn.Module):
-    def __init__(self, progressive_schedule, growth_fade_in_steps, switch_filters, switch_processing_layers, trans_counts,
+    def __init__(self, progressive_step_schedule, switch_reductions, growth_fade_in_steps, switch_filters, switch_processing_layers, trans_counts,
                  trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, upsample_factor=1,
                  add_scalable_noise_to_transforms=False, start_step=0):
         super(GrowingSRGBase, self).__init__()
@@ -29,9 +29,8 @@ class GrowingSRGBase(nn.Module):
         self.switch_processing_layers = switch_processing_layers
         self.trans_layers = trans_layers
         self.transformation_filters = transformation_filters
-
-        self.switches = nn.ModuleList([])
-        self.progressive_schedule = progressive_schedule
+        self.progressive_schedule = progressive_step_schedule
+        self.switch_reductions = switch_reductions  # This lists the reductions for all switches (even ones not activated yet).
         self.growth_fade_in_per_step = 1 / growth_fade_in_steps
         self.transformation_counts = trans_counts
         self.init_temperature = initial_temp
@@ -39,39 +38,64 @@ class GrowingSRGBase(nn.Module):
         self.attentions = None
         self.upsample_factor = upsample_factor
         self.add_noise_to_transform = add_scalable_noise_to_transforms
-        self.latest_step = 0
+        self.start_step = start_step
+        self.latest_step = start_step
+        self.fades = []
 
         assert self.upsample_factor == 2 or self.upsample_factor == 4
 
-        for i, step in enumerate(progressive_schedule):
-            if step >= start_step:
-                self.add_layer(i + 1)
-
-    def add_layer(self, reductions):
-        multiplx_fn = functools.partial(srg.ConvBasisMultiplexer, self.transformation_filters, self.switch_filters,
+        switches = []
+        for i, (step, reductions) in enumerate(zip(progressive_step_schedule, switch_reductions)):
+            multiplx_fn = functools.partial(srg.ConvBasisMultiplexer, self.transformation_filters, self.switch_filters,
                                                 reductions, self.switch_processing_layers, self.transformation_counts)
-        pretransform_fn = functools.partial(ConvBnLelu, self.transformation_filters, self.transformation_filters, norm=False,
-                                            bias=False, weight_init_factor=.1)
-        transform_fn = functools.partial(srg.MultiConvBlock, self.transformation_filters, int(self.transformation_filters * 1.5),
-                                         self.transformation_filters, kernel_size=3, depth=self.trans_layers,
-                                         weight_init_factor=.1)
-        self.switches.append(srg.ConfigurableSwitchComputer(self.transformation_filters, multiplx_fn,
-                                                            pre_transform_block=pretransform_fn,
-                                                            transform_block=transform_fn,
-                                                            transform_count=self.transformation_counts, init_temp=self.init_temperature,
-                                                            add_scalable_noise_to_transforms=self.add_noise_to_transform))
+            pretransform_fn = functools.partial(ConvGnLelu, self.transformation_filters, self.transformation_filters, norm=False,
+                                                bias=False, weight_init_factor=.1)
+            transform_fn = functools.partial(srg.MultiConvBlock, self.transformation_filters, int(self.transformation_filters * 1.5),
+                                             self.transformation_filters, kernel_size=3, depth=self.trans_layers,
+                                             weight_init_factor=.1)
+            switches.append(srg.ConfigurableSwitchComputer(self.transformation_filters, multiplx_fn,
+                                                           pre_transform_block=pretransform_fn,
+                                                           transform_block=transform_fn,
+                                                           transform_count=self.transformation_counts, init_temp=self.init_temperature,
+                                                           add_scalable_noise_to_transforms=self.add_noise_to_transform,
+                                                           attention_norm=False))
+        self.progressive_switches = nn.ModuleList(switches)
+
+    def get_param_groups(self):
+        param_groups = []
+        base_param_group = []
+        for k, v in self.named_parameters():
+            if "progressive_switches" not in k and v.requires_grad:
+                base_param_group.append(v)
+        param_groups.append({'params': base_param_group})
+        for i, sw in enumerate(self.progressive_switches):
+            sw_param_group = []
+            for k, v in sw.named_parameters():
+                if v.requires_grad:
+                    sw_param_group.append(v)
+            param_groups.append({'params': sw_param_group})
+        return param_groups
+
+    def get_progressive_starts(self):
+        # The base param group starts at step 0, the rest are defined via progressive_switches.
+        return [0] + self.progressive_schedule
 
     def forward(self, x):
         x = self.initial_conv(x)
 
         self.attentions = []
-        for i, sw in enumerate(self.switches):
-            fade_in = 1
+        self.fades = []
+        self.enabled_switches = 0
+        for i, sw in enumerate(self.progressive_switches):
+            fade_in = 1 if self.progressive_schedule[i] == 0 else 0
             if self.latest_step > 0 and self.progressive_schedule[i] != 0:
                 switch_age = self.latest_step - self.progressive_schedule[i]
                 fade_in = min(1, switch_age * self.growth_fade_in_per_step)
-            x, att = sw.forward(x, True, fixed_scale=fade_in)
-            self.attentions.append(att)
+            if fade_in > 0:
+                self.enabled_switches += 1
+                x, att = sw.forward(x, True, fixed_scale=fade_in)
+                self.attentions.append(att)
+            self.fades.append(fade_in)
 
         x = self.upconv1(F.interpolate(x, scale_factor=2, mode="nearest"))
         if self.upsample_factor > 2:
@@ -81,18 +105,13 @@ class GrowingSRGBase(nn.Module):
         return x, x
 
     def update_for_step(self, step, experiments_path='.'):
-        self.latest_step = step
-
-        # Add any new layers as spelled out by the schedule.
-        if step != 0:
-            for i, s in enumerate(self.progressive_schedule):
-                if s == step:
-                    self.add_layer(i + 1)
+        self.latest_step = step + self.start_step
 
         # Set the temperature of the switches, per-layer.
-        for i, (first_step, sw) in enumerate(zip(self.progressive_schedule, self.switches)):
+        for i, (first_step, sw) in enumerate(zip(self.progressive_schedule, self.progressive_switches)):
             temp_loss_per_step = (self.init_temperature - 1) / self.final_temperature_step
-            sw.set_temperature(self.init_temperature - temp_loss_per_step * (step - first_step))
+            sw.set_temperature(min(self.init_temperature,
+                                   max(self.init_temperature - temp_loss_per_step * (step - first_step), 1)))
 
         # Save attention images.
         if self.attentions is not None and step % 50 == 0:
@@ -106,10 +125,12 @@ class GrowingSRGBase(nn.Module):
         for i in range(len(means)):
             val["switch_%i_specificity" % (i,)] = means[i]
             val["switch_%i_histogram" % (i,)] = hists[i]
-            val["switch_%i_temperature" % (i,)] = self.switches[i].switch.temperature
+            val["switch_%i_temperature" % (i,)] = self.progressive_switches[i].switch.temperature
+        for i, f in enumerate(self.fades):
+            val["switch_%i_fade" % (i,)] = f
+        val["enabled_switches"] = self.enabled_switches
         return val
 
-
 class DiscriminatorDownsample(nn.Module):
     def __init__(self, base_filters, end_filters):
         self.conv0 = ConvGnLelu(base_filters, end_filters, kernel_size=3, bias=False)
diff --git a/codes/models/archs/arch_util.py b/codes/models/archs/arch_util.py
index c239bc60..fba27c98 100644
--- a/codes/models/archs/arch_util.py
+++ b/codes/models/archs/arch_util.py
@@ -296,7 +296,7 @@ class ConvBnLelu(nn.Module):
 ''' Convenience class with Conv->GroupNorm->LeakyReLU. Includes weight initialization and auto-padding for standard kernel sizes.
 '''
 class ConvGnLelu(nn.Module):
-    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8):
+    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1):
         super(ConvGnLelu, self).__init__()
         padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
         assert kernel_size in padding_map.keys()
@@ -315,6 +315,9 @@ class ConvGnLelu(nn.Module):
             if isinstance(m, nn.Conv2d):
                 nn.init.kaiming_normal_(m.weight, a=.1, mode='fan_out',
                                         nonlinearity='leaky_relu' if self.lelu else 'linear')
+                m.weight.data *= weight_init_factor
+                if m.bias is not None:
+                    m.bias.data.zero_()
             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
diff --git a/codes/models/lr_scheduler.py b/codes/models/lr_scheduler.py
index 6ef0aef3..63503cff 100644
--- a/codes/models/lr_scheduler.py
+++ b/codes/models/lr_scheduler.py
@@ -5,6 +5,26 @@ import torch
 from torch.optim.lr_scheduler import _LRScheduler
 
 
+# This scheduler is specifically designed to modulate the learning rate of several different param groups configured
+# by a generator or discriminator that slowly adds new stages one at a time, e.g. like progressive growing of GANs.
+class ProgressiveMultiStepLR(_LRScheduler):
+    def __init__(self, optimizer, milestones, group_starts, gamma=0.1):
+        self.milestones = Counter(milestones)
+        self.gamma = gamma
+        self.group_starts = group_starts
+        super(ProgressiveMultiStepLR, self).__init__(optimizer)
+
+    def get_lr(self):
+        group_lrs = []
+        assert len(self.optimizer.param_groups) == len(self.group_starts)
+        for group, group_start in zip(self.optimizer.param_groups, self.group_starts):
+            if self.last_epoch - group_start not in self.milestones:
+                group_lrs.append(group['lr'])
+            else:
+                group_lrs.append(group['lr'] * self.gamma)
+        return group_lrs
+
+
 class MultiStepLR_Restart(_LRScheduler):
     def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1, clear_state=False,
                  force_lr=False, last_epoch=-1):
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 432cbe1c..1bca11e3 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -10,6 +10,7 @@ import models.archs.NestedSwitchGenerator as ng
 import models.archs.feature_arch as feature_arch
 import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
 import models.archs.SRG1_arch as srg1
+import models.archs.ProgressiveSrg_arch as psrg
 import functools
 
 # Generator
@@ -78,15 +79,14 @@ def define_G(opt, net_key='network_G'):
                                               initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
                                               heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
                                               upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
-    elif which_model == "DualOutputSRG":
-        netG = SwitchedGen_arch.DualOutputSRG(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
-                                              switch_reductions=opt_net['switch_reductions'],
-                                              switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
-                                              trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
-                                              transformation_filters=opt_net['transformation_filters'],
-                                              initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
-                                              heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
-                                              upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
+    elif which_model == "ProgressiveSRG2":
+        netG = psrg.GrowingSRGBase(progressive_step_schedule=opt_net['schedule'], switch_reductions=opt_net['reductions'],
+                                   growth_fade_in_steps=opt_net['fade_in_steps'], switch_filters=opt_net['switch_filters'],
+                                   switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
+                                   trans_layers=opt_net['trans_layers'], transformation_filters=opt_net['transformation_filters'],
+                                   initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
+                                   upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'],
+                                   start_step=opt_net['start_step'])
 
     # image corruption
     elif which_model == 'HighToLowResNet':
diff --git a/codes/train.py b/codes/train.py
index 99fc7c14..f879a0af 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_imgset_pixgan_srg2/train_imgset_pixgan_srg2.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_progressive_srg2.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
@@ -161,7 +161,7 @@ def main():
         current_step = resume_state['iter']
         model.resume_training(resume_state)  # handle optimizers and schedulers
     else:
-        current_step = -1
+        current_step = 0
         start_epoch = 0
 
     #### training
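Not part of the patch itself: below is a minimal usage sketch of the new ProgressiveMultiStepLR, showing how group_starts is meant to line up with the per-stage param groups produced by GrowingSRGBase.get_param_groups() and get_progressive_starts(). The toy modules, the milestone/start values, and running the snippet from inside codes/ (so that models.lr_scheduler resolves, as in train.py) are illustrative assumptions, not taken from the diff.

import torch
from torch import nn

import models.lr_scheduler as lr_scheduler  # assumed to resolve when run from codes/, as train.py does

# Toy stand-ins: a "base" generator block trained from step 0 and one
# progressively-added switch whose stage is scheduled to come online at step 40000.
base_block = nn.Conv2d(3, 64, kernel_size=3, padding=1)
late_switch = nn.Conv2d(64, 64, kernel_size=3, padding=1)

# One param group per stage, mirroring GrowingSRGBase.get_param_groups().
optimizer = torch.optim.Adam([{'params': base_block.parameters()},
                              {'params': late_switch.parameters()}], lr=1e-4)

# group_starts mirrors get_progressive_starts(): [0] + progressive_schedule.
scheduler = lr_scheduler.ProgressiveMultiStepLR(optimizer,
                                                milestones=[20000, 40000],  # e.g. gen_lr_steps
                                                group_starts=[0, 40000],
                                                gamma=0.5)

for step in range(60000):
    optimizer.step()
    scheduler.step()

# Each group's LR decays when (step - its group_start) hits a milestone, so a
# late-added switch follows the same decay schedule relative to its own start step.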