Huge set of mods to support progressive generator growth

James Betker 2020-07-18 14:18:48 -06:00
parent 47a525241f
commit 8a9f215653
6 changed files with 116 additions and 54 deletions


@@ -103,12 +103,15 @@ class SRGANModel(BaseModel):
# G
wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
optim_params = []
for k, v in self.netG.named_parameters(): # can optimize for a part of the model
if v.requires_grad:
optim_params.append(v)
else:
if self.rank <= 0:
logger.warning('Params [{:s}] will not optimize.'.format(k))
if train_opt['lr_scheme'] == 'ProgressiveMultiStepLR':
optim_params = self.netG.get_param_groups()
else:
for k, v in self.netG.named_parameters(): # can optimize for a part of the model
if v.requires_grad:
optim_params.append(v)
else:
if self.rank <= 0:
logger.warning('Params [{:s}] will not optimize.'.format(k))
self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
weight_decay=wd_G,
betas=(train_opt['beta1_G'], train_opt['beta2_G']))
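
Depending on the LR scheme, the optimizer above is now built either from a flat list of parameters or from the per-stage param groups returned by get_param_groups(); torch.optim.Adam accepts both forms. A minimal sketch of the two paths, using an invented TinyNet module and illustrative hyperparameters rather than the project's configuration:

import torch
import torch.nn as nn

# TinyNet is a stand-in for the real generator: a base trunk plus a list of
# progressively added modules, mirroring the split used by get_param_groups().
class TinyNet(nn.Module):
    def __init__(self):
        super(TinyNet, self).__init__()
        self.base = nn.Conv2d(3, 8, 3, padding=1)
        self.progressive_switches = nn.ModuleList([nn.Conv2d(8, 8, 3, padding=1) for _ in range(2)])

    def get_param_groups(self):
        # One group for everything outside the switches, then one group per switch.
        base = [p for n, p in self.named_parameters()
                if 'progressive_switches' not in n and p.requires_grad]
        groups = [{'params': base}]
        for sw in self.progressive_switches:
            groups.append({'params': [p for p in sw.parameters() if p.requires_grad]})
        return groups

net = TinyNet()

# Flat list: every trainable parameter shares a single param group and LR.
flat = torch.optim.Adam([p for p in net.parameters() if p.requires_grad],
                        lr=1e-4, weight_decay=0, betas=(0.9, 0.99))

# List of dicts: each group keeps its own LR, so a scheduler can decay the
# trunk and each progressive switch independently.
grouped = torch.optim.Adam(net.get_param_groups(),
                           lr=1e-4, weight_decay=0, betas=(0.9, 0.99))

print(len(flat.param_groups), len(grouped.param_groups))  # 1 vs. 3
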
@@ -148,6 +151,15 @@ class SRGANModel(BaseModel):
gamma=train_opt['lr_gamma'],
clear_state=train_opt['clear_state'],
force_lr=train_opt['force_lr']))
elif train_opt['lr_scheme'] == 'ProgressiveMultiStepLR':
# Only supported when there are two optimizers: G and D.
assert len(self.optimizers) == 2
self.schedulers.append(lr_scheduler.ProgressiveMultiStepLR(self.optimizer_G, train_opt['gen_lr_steps'],
self.netG.module.get_progressive_starts(),
train_opt['lr_gamma']))
self.schedulers.append(lr_scheduler.ProgressiveMultiStepLR(self.optimizer_D, train_opt['disc_lr_steps'],
[0],
train_opt['lr_gamma']))
elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
for optimizer in self.optimizers:
self.schedulers.append(
@@ -492,6 +504,12 @@ class SRGANModel(BaseModel):
self.add_log_entry('l_d_fake_total', l_d_fake.item() * self.mega_batch_factor)
self.add_log_entry('l_d_real_total', l_d_real.item() * self.mega_batch_factor)
# Log learning rates.
for i, pg in enumerate(self.optimizer_G.param_groups):
self.add_log_entry('gen_lr_%i' % (i,), pg['lr'])
for i, pg in enumerate(self.optimizer_D.param_groups):
self.add_log_entry('disc_lr_%i' % (i,), pg['lr'])
if step % self.corruptor_swapout_steps == 0 and step > 0:
self.load_random_corruptor()


@@ -14,7 +14,7 @@ import torch.nn.functional as F
# 4) start_step will need to get set properly when constructing these models, even when resuming - OR another method needs to be added to resume properly.
class GrowingSRGBase(nn.Module):
def __init__(self, progressive_schedule, growth_fade_in_steps, switch_filters, switch_processing_layers, trans_counts,
def __init__(self, progressive_step_schedule, switch_reductions, growth_fade_in_steps, switch_filters, switch_processing_layers, trans_counts,
trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, upsample_factor=1,
add_scalable_noise_to_transforms=False, start_step=0):
super(GrowingSRGBase, self).__init__()
@@ -29,9 +29,8 @@ class GrowingSRGBase(nn.Module):
self.switch_processing_layers = switch_processing_layers
self.trans_layers = trans_layers
self.transformation_filters = transformation_filters
self.switches = nn.ModuleList([])
self.progressive_schedule = progressive_schedule
self.progressive_schedule = progressive_step_schedule
self.switch_reductions = switch_reductions # This lists the reductions for all switches (even ones not activated yet).
self.growth_fade_in_per_step = 1 / growth_fade_in_steps
self.transformation_counts = trans_counts
self.init_temperature = initial_temp
@@ -39,39 +38,64 @@ class GrowingSRGBase(nn.Module):
self.attentions = None
self.upsample_factor = upsample_factor
self.add_noise_to_transform = add_scalable_noise_to_transforms
self.latest_step = 0
self.start_step = start_step
self.latest_step = start_step
self.fades = []
assert self.upsample_factor == 2 or self.upsample_factor == 4
for i, step in enumerate(progressive_schedule):
if step >= start_step:
self.add_layer(i + 1)
def add_layer(self, reductions):
multiplx_fn = functools.partial(srg.ConvBasisMultiplexer, self.transformation_filters, self.switch_filters,
switches = []
for i, (step, reductions) in enumerate(zip(progressive_step_schedule, switch_reductions)):
multiplx_fn = functools.partial(srg.ConvBasisMultiplexer, self.transformation_filters, self.switch_filters,
reductions, self.switch_processing_layers, self.transformation_counts)
pretransform_fn = functools.partial(ConvBnLelu, self.transformation_filters, self.transformation_filters, norm=False,
bias=False, weight_init_factor=.1)
transform_fn = functools.partial(srg.MultiConvBlock, self.transformation_filters, int(self.transformation_filters * 1.5),
self.transformation_filters, kernel_size=3, depth=self.trans_layers,
weight_init_factor=.1)
self.switches.append(srg.ConfigurableSwitchComputer(self.transformation_filters, multiplx_fn,
pre_transform_block=pretransform_fn,
transform_block=transform_fn,
transform_count=self.transformation_counts, init_temp=self.init_temperature,
add_scalable_noise_to_transforms=self.add_noise_to_transform))
pretransform_fn = functools.partial(ConvGnLelu, self.transformation_filters, self.transformation_filters, norm=False,
bias=False, weight_init_factor=.1)
transform_fn = functools.partial(srg.MultiConvBlock, self.transformation_filters, int(self.transformation_filters * 1.5),
self.transformation_filters, kernel_size=3, depth=self.trans_layers,
weight_init_factor=.1)
switches.append(srg.ConfigurableSwitchComputer(self.transformation_filters, multiplx_fn,
pre_transform_block=pretransform_fn,
transform_block=transform_fn,
transform_count=self.transformation_counts, init_temp=self.init_temperature,
add_scalable_noise_to_transforms=self.add_noise_to_transform,
attention_norm=False))
self.progressive_switches = nn.ModuleList(switches)
def get_param_groups(self):
param_groups = []
base_param_group = []
for k, v in self.named_parameters():
if "progressive_switches" not in k and v.requires_grad:
base_param_group.append(v)
param_groups.append({'params': base_param_group})
for i, sw in enumerate(self.progressive_switches):
sw_param_group = []
for k, v in sw.named_parameters():
if v.requires_grad:
sw_param_group.append(v)
param_groups.append({'params': sw_param_group})
return param_groups
def get_progressive_starts(self):
# The base param group starts at step 0, the rest are defined via progressive_switches.
return [0] + self.progressive_schedule
def forward(self, x):
x = self.initial_conv(x)
self.attentions = []
for i, sw in enumerate(self.switches):
fade_in = 1
self.fades = []
self.enabled_switches = 0
for i, sw in enumerate(self.progressive_switches):
fade_in = 1 if self.progressive_schedule[i] == 0 else 0
if self.latest_step > 0 and self.progressive_schedule[i] != 0:
switch_age = self.latest_step - self.progressive_schedule[i]
fade_in = min(1, switch_age * self.growth_fade_in_per_step)
x, att = sw.forward(x, True, fixed_scale=fade_in)
self.attentions.append(att)
if fade_in > 0:
self.enabled_switches += 1
x, att = sw.forward(x, True, fixed_scale=fade_in)
self.attentions.append(att)
self.fades.append(fade_in)
x = self.upconv1(F.interpolate(x, scale_factor=2, mode="nearest"))
if self.upsample_factor > 2:
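
Each progressive switch's output is scaled by a fade-in factor derived from how long that switch has been active. The sketch below reproduces that schedule in isolation; the schedule values and fade-in length are illustrative, and the factor is additionally clamped at 0 here for readability (the forward pass above instead skips any switch whose factor is not yet positive):

# Fade-in schedule in isolation. progressive_schedule and growth_fade_in_steps
# are illustrative values; the clamp at 0 is added here for readability only.
progressive_schedule = [0, 10000, 30000]   # step at which each switch comes online
growth_fade_in_steps = 5000
growth_fade_in_per_step = 1 / growth_fade_in_steps

def fade_factors(latest_step):
    fades = []
    for first_step in progressive_schedule:
        if first_step == 0:
            fades.append(1)        # the initial switch is always fully active
        elif latest_step > 0:
            switch_age = latest_step - first_step
            fades.append(min(1, max(0, switch_age) * growth_fade_in_per_step))
        else:
            fades.append(0)
    return fades

print(fade_factors(0))        # [1, 0, 0]     only the first switch contributes
print(fade_factors(12500))    # [1, 0.5, 0.0] the second switch is halfway faded in
print(fade_factors(40000))    # [1, 1, 1]     every switch fully enabled
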
@@ -81,18 +105,13 @@ class GrowingSRGBase(nn.Module):
return x, x
def update_for_step(self, step, experiments_path='.'):
self.latest_step = step
# Add any new layers as spelled out by the schedule.
if step != 0:
for i, s in enumerate(self.progressive_schedule):
if s == step:
self.add_layer(i + 1)
self.latest_step = step + self.start_step
# Set the temperature of the switches, per-layer.
for i, (first_step, sw) in enumerate(zip(self.progressive_schedule, self.switches)):
for i, (first_step, sw) in enumerate(zip(self.progressive_schedule, self.progressive_switches)):
temp_loss_per_step = (self.init_temperature - 1) / self.final_temperature_step
sw.set_temperature(self.init_temperature - temp_loss_per_step * (step - first_step))
sw.set_temperature(min(self.init_temperature,
max(self.init_temperature - temp_loss_per_step * (step - first_step), 1)))
# Save attention images.
if self.attentions is not None and step % 50 == 0:
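
The updated annealing clamps each switch's temperature to the range [1, initial_temp], measured from that switch's own start step. A worked example with illustrative numbers, not the values used for training:

# Worked example of the per-switch temperature clamp above.
initial_temp = 20
final_temperature_step = 50000
temp_loss_per_step = (initial_temp - 1) / final_temperature_step   # 0.00038 per step

def switch_temperature(step, first_step):
    raw = initial_temp - temp_loss_per_step * (step - first_step)
    return min(initial_temp, max(raw, 1))

print(switch_temperature(0, 10000))      # 20     a switch that has not started stays at the initial temperature
print(switch_temperature(35000, 10000))  # ~10.5  halfway through its own 50k-step anneal
print(switch_temperature(90000, 10000))  # 1      floors at 1 instead of decaying past it
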
@@ -106,10 +125,12 @@ class GrowingSRGBase(nn.Module):
for i in range(len(means)):
val["switch_%i_specificity" % (i,)] = means[i]
val["switch_%i_histogram" % (i,)] = hists[i]
val["switch_%i_temperature" % (i,)] = self.switches[i].switch.temperature
val["switch_%i_temperature" % (i,)] = self.progressive_switches[i].switch.temperature
for i, f in enumerate(self.fades):
val["switch_%i_fade" % (i,)] = f
val["enabled_switches"] = self.enabled_switches
return val
class DiscriminatorDownsample(nn.Module):
def __init__(self, base_filters, end_filters):
self.conv0 = ConvGnLelu(base_filters, end_filters, kernel_size=3, bias=False)


@@ -296,7 +296,7 @@ class ConvBnLelu(nn.Module):
''' Convenience class with Conv->GroupNorm->LeakyReLU. Includes weight initialization and auto-padding for standard
kernel sizes. '''
class ConvGnLelu(nn.Module):
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8):
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1):
super(ConvGnLelu, self).__init__()
padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
assert kernel_size in padding_map.keys()
@@ -315,6 +315,9 @@ class ConvGnLelu(nn.Module):
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, a=.1, mode='fan_out',
nonlinearity='leaky_relu' if self.lelu else 'linear')
m.weight.data *= weight_init_factor
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
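
ConvGnLelu now accepts a weight_init_factor, which the progressive switches set to 0.1 so that newly added blocks start nearly transparent. A standalone sketch of the same scaled Kaiming initialization applied to a plain Conv2d; the layer sizes and the 0.1 factor are illustrative:

import torch.nn as nn

# Scaled Kaiming initialization in the style of the weight_init_factor change above.
def init_scaled(conv, weight_init_factor=1.0):
    nn.init.kaiming_normal_(conv.weight, a=.1, mode='fan_out', nonlinearity='leaky_relu')
    conv.weight.data *= weight_init_factor      # shrink the starting weights
    if conv.bias is not None:
        conv.bias.data.zero_()

full = nn.Conv2d(64, 64, 3, padding=1)
damped = nn.Conv2d(64, 64, 3, padding=1)
init_scaled(full, 1.0)
init_scaled(damped, 0.1)

# The damped layer starts with roughly 10x smaller weights, so a newly grown
# block barely perturbs the trained trunk at the start of its fade-in.
print(full.weight.std().item(), damped.weight.std().item())
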


@@ -5,6 +5,26 @@ import torch
from torch.optim.lr_scheduler import _LRScheduler
# This scheduler is specifically designed to modulate the learning rate of several different param groups configured
# by a generator or discriminator that slowly adds new stages one at a time, e.g. like progressive growing of GANs.
class ProgressiveMultiStepLR(_LRScheduler):
def __init__(self, optimizer, milestones, group_starts, gamma=0.1):
self.milestones = Counter(milestones)
self.gamma = gamma
self.group_starts = group_starts
super(ProgressiveMultiStepLR, self).__init__(optimizer)
def get_lr(self):
group_lrs = []
assert len(self.optimizer.param_groups) == len(self.group_starts)
for group, group_start in zip(self.optimizer.param_groups, self.group_starts):
if self.last_epoch - group_start not in self.milestones:
group_lrs.append(group['lr'])
else:
group_lrs.append(group['lr'] * self.gamma)
return group_lrs
class MultiStepLR_Restart(_LRScheduler):
def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
clear_state=False, force_lr=False, last_epoch=-1):
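
ProgressiveMultiStepLR interprets each group's milestones relative to that group's start step, which is why it is paired with GrowingSRGBase.get_param_groups() (one group for the trunk plus one per switch) and get_progressive_starts() ([0] plus the growth schedule). A toy demonstration, assuming the ProgressiveMultiStepLR class above is in scope; the module sizes, milestones and learning rates are invented:

import torch
import torch.nn as nn
# Assumes ProgressiveMultiStepLR (defined above) has been imported from this module.

base = nn.Linear(4, 4)
late_switch = nn.Linear(4, 4)
optim = torch.optim.Adam([{'params': base.parameters()},
                          {'params': late_switch.parameters()}], lr=1e-4)

# The trunk group starts at step 0, the switch group at step 5; both decay by
# gamma three steps after they come online.
sched = ProgressiveMultiStepLR(optim, milestones=[3], group_starts=[0, 5], gamma=0.1)

for _ in range(10):
    optim.step()
    sched.step()
    print(sched.last_epoch, [pg['lr'] for pg in optim.param_groups])
# Group 0 drops to 1e-5 when last_epoch reaches 3; group 1 keeps 1e-4 until
# last_epoch reaches 8 (its start step 5 plus the milestone 3).
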


@@ -10,6 +10,7 @@ import models.archs.NestedSwitchGenerator as ng
import models.archs.feature_arch as feature_arch
import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
import models.archs.SRG1_arch as srg1
import models.archs.ProgressiveSrg_arch as psrg
import functools
# Generator
@@ -78,15 +79,14 @@ def define_G(opt, net_key='network_G'):
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
elif which_model == "DualOutputSRG":
netG = SwitchedGen_arch.DualOutputSRG(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
switch_reductions=opt_net['switch_reductions'],
switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
transformation_filters=opt_net['transformation_filters'],
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
elif which_model == "ProgressiveSRG2":
netG = psrg.GrowingSRGBase(progressive_step_schedule=opt_net['schedule'], switch_reductions=opt_net['reductions'],
growth_fade_in_steps=opt_net['fade_in_steps'], switch_filters=opt_net['switch_filters'],
switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
trans_layers=opt_net['trans_layers'], transformation_filters=opt_net['transformation_filters'],
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'],
start_step=opt_net['start_step'])
# image corruption
elif which_model == 'HighToLowResNet':
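
The new ProgressiveSRG2 branch reads the keys shown above from the network_G options block (the upsample factor comes from the separate scale option). A hypothetical example of those options written as the dict define_G receives; every value is an illustrative guess, not a configuration shipped with the repository:

opt_net = {
    # Hypothetical values for illustration only; keys match what the
    # ProgressiveSRG2 branch above reads.
    'schedule': [0, 20000, 40000],        # step at which each progressive switch activates
    'reductions': [3, 3, 2],              # multiplexer reductions per switch
    'fade_in_steps': 10000,
    'switch_filters': 32,
    'switch_processing_layers': 2,
    'trans_counts': 8,
    'trans_layers': 3,
    'transformation_filters': 64,
    'temperature': 10,
    'temperature_final_step': 50000,
    'add_noise': False,
    'start_step': 0,
}
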


@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
def main():
#### options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_imgset_pixgan_srg2/train_imgset_pixgan_srg2.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_progressive_srg2.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
@@ -161,7 +161,7 @@ def main():
current_step = resume_state['iter']
model.resume_training(resume_state) # handle optimizers and schedulers
else:
current_step = -1
current_step = 0
start_epoch = 0
#### training