forked from mrq/DL-Art-School
Huge set of mods to support progressive generator growth
This commit is contained in:
parent
47a525241f
commit
8a9f215653
|
@ -103,12 +103,15 @@ class SRGANModel(BaseModel):
|
|||
# G
|
||||
wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
|
||||
optim_params = []
|
||||
for k, v in self.netG.named_parameters(): # can optimize for a part of the model
|
||||
if v.requires_grad:
|
||||
optim_params.append(v)
|
||||
else:
|
||||
if self.rank <= 0:
|
||||
logger.warning('Params [{:s}] will not optimize.'.format(k))
|
||||
if train_opt['lr_scheme'] == 'ProgressiveMultiStepLR':
|
||||
optim_params = self.netG.get_param_groups()
|
||||
else:
|
||||
for k, v in self.netG.named_parameters(): # can optimize for a part of the model
|
||||
if v.requires_grad:
|
||||
optim_params.append(v)
|
||||
else:
|
||||
if self.rank <= 0:
|
||||
logger.warning('Params [{:s}] will not optimize.'.format(k))
|
||||
self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
|
||||
weight_decay=wd_G,
|
||||
betas=(train_opt['beta1_G'], train_opt['beta2_G']))
|
||||
|
@ -148,6 +151,15 @@ class SRGANModel(BaseModel):
|
|||
gamma=train_opt['lr_gamma'],
|
||||
clear_state=train_opt['clear_state'],
|
||||
force_lr=train_opt['force_lr']))
|
||||
elif train_opt['lr_scheme'] == 'ProgressiveMultiStepLR':
|
||||
# Only supported when there are two optimizers: G and D.
|
||||
assert len(self.optimizers) == 2
|
||||
self.schedulers.append(lr_scheduler.ProgressiveMultiStepLR(self.optimizer_G, train_opt['gen_lr_steps'],
|
||||
self.netG.module.get_progressive_starts(),
|
||||
train_opt['lr_gamma']))
|
||||
self.schedulers.append(lr_scheduler.ProgressiveMultiStepLR(self.optimizer_D, train_opt['disc_lr_steps'],
|
||||
[0],
|
||||
train_opt['lr_gamma']))
|
||||
elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
|
||||
for optimizer in self.optimizers:
|
||||
self.schedulers.append(
|
||||
|
@ -492,6 +504,12 @@ class SRGANModel(BaseModel):
|
|||
self.add_log_entry('l_d_fake_total', l_d_fake.item() * self.mega_batch_factor)
|
||||
self.add_log_entry('l_d_real_total', l_d_real.item() * self.mega_batch_factor)
|
||||
|
||||
# Log learning rates.
|
||||
for i, pg in enumerate(self.optimizer_G.param_groups):
|
||||
self.add_log_entry('gen_lr_%i' % (i,), pg['lr'])
|
||||
for i, pg in enumerate(self.optimizer_D.param_groups):
|
||||
self.add_log_entry('disc_lr_%i' % (i,), pg['lr'])
|
||||
|
||||
if step % self.corruptor_swapout_steps == 0 and step > 0:
|
||||
self.load_random_corruptor()
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ import torch.nn.functional as F
|
|||
# 4) start_step will need to get set properly when constructing these models, even when resuming - OR another method needs to be added to resume properly.
|
||||
|
||||
class GrowingSRGBase(nn.Module):
|
||||
def __init__(self, progressive_schedule, growth_fade_in_steps, switch_filters, switch_processing_layers, trans_counts,
|
||||
def __init__(self, progressive_step_schedule, switch_reductions, growth_fade_in_steps, switch_filters, switch_processing_layers, trans_counts,
|
||||
trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, upsample_factor=1,
|
||||
add_scalable_noise_to_transforms=False, start_step=0):
|
||||
super(GrowingSRGBase, self).__init__()
|
||||
|
@ -29,9 +29,8 @@ class GrowingSRGBase(nn.Module):
|
|||
self.switch_processing_layers = switch_processing_layers
|
||||
self.trans_layers = trans_layers
|
||||
self.transformation_filters = transformation_filters
|
||||
|
||||
self.switches = nn.ModuleList([])
|
||||
self.progressive_schedule = progressive_schedule
|
||||
self.progressive_schedule = progressive_step_schedule
|
||||
self.switch_reductions = switch_reductions # This lists the reductions for all switches (even ones not activated yet).
|
||||
self.growth_fade_in_per_step = 1 / growth_fade_in_steps
|
||||
self.transformation_counts = trans_counts
|
||||
self.init_temperature = initial_temp
|
||||
|
@ -39,39 +38,64 @@ class GrowingSRGBase(nn.Module):
|
|||
self.attentions = None
|
||||
self.upsample_factor = upsample_factor
|
||||
self.add_noise_to_transform = add_scalable_noise_to_transforms
|
||||
self.latest_step = 0
|
||||
self.start_step = start_step
|
||||
self.latest_step = start_step
|
||||
self.fades = []
|
||||
assert self.upsample_factor == 2 or self.upsample_factor == 4
|
||||
|
||||
for i, step in enumerate(progressive_schedule):
|
||||
if step >= start_step:
|
||||
self.add_layer(i + 1)
|
||||
|
||||
def add_layer(self, reductions):
|
||||
multiplx_fn = functools.partial(srg.ConvBasisMultiplexer, self.transformation_filters, self.switch_filters,
|
||||
switches = []
|
||||
for i, (step, reductions) in enumerate(zip(progressive_step_schedule, switch_reductions)):
|
||||
multiplx_fn = functools.partial(srg.ConvBasisMultiplexer, self.transformation_filters, self.switch_filters,
|
||||
reductions, self.switch_processing_layers, self.transformation_counts)
|
||||
pretransform_fn = functools.partial(ConvBnLelu, self.transformation_filters, self.transformation_filters, norm=False,
|
||||
bias=False, weight_init_factor=.1)
|
||||
transform_fn = functools.partial(srg.MultiConvBlock, self.transformation_filters, int(self.transformation_filters * 1.5),
|
||||
self.transformation_filters, kernel_size=3, depth=self.trans_layers,
|
||||
weight_init_factor=.1)
|
||||
self.switches.append(srg.ConfigurableSwitchComputer(self.transformation_filters, multiplx_fn,
|
||||
pre_transform_block=pretransform_fn,
|
||||
transform_block=transform_fn,
|
||||
transform_count=self.transformation_counts, init_temp=self.init_temperature,
|
||||
add_scalable_noise_to_transforms=self.add_noise_to_transform))
|
||||
pretransform_fn = functools.partial(ConvGnLelu, self.transformation_filters, self.transformation_filters, norm=False,
|
||||
bias=False, weight_init_factor=.1)
|
||||
transform_fn = functools.partial(srg.MultiConvBlock, self.transformation_filters, int(self.transformation_filters * 1.5),
|
||||
self.transformation_filters, kernel_size=3, depth=self.trans_layers,
|
||||
weight_init_factor=.1)
|
||||
switches.append(srg.ConfigurableSwitchComputer(self.transformation_filters, multiplx_fn,
|
||||
pre_transform_block=pretransform_fn,
|
||||
transform_block=transform_fn,
|
||||
transform_count=self.transformation_counts, init_temp=self.init_temperature,
|
||||
add_scalable_noise_to_transforms=self.add_noise_to_transform,
|
||||
attention_norm=False))
|
||||
self.progressive_switches = nn.ModuleList(switches)
|
||||
|
||||
def get_param_groups(self):
|
||||
param_groups = []
|
||||
base_param_group = []
|
||||
for k, v in self.named_parameters():
|
||||
if "progressive_switches" not in k and v.requires_grad:
|
||||
base_param_group.append(v)
|
||||
param_groups.append({'params': base_param_group})
|
||||
for i, sw in enumerate(self.progressive_switches):
|
||||
sw_param_group = []
|
||||
for k, v in sw.named_parameters():
|
||||
if v.requires_grad:
|
||||
sw_param_group.append(v)
|
||||
param_groups.append({'params': sw_param_group})
|
||||
return param_groups
|
||||
|
||||
def get_progressive_starts(self):
|
||||
# The base param group starts at step 0, the rest are defined via progressive_switches.
|
||||
return [0] + self.progressive_schedule
|
||||
|
||||
def forward(self, x):
|
||||
x = self.initial_conv(x)
|
||||
|
||||
self.attentions = []
|
||||
for i, sw in enumerate(self.switches):
|
||||
fade_in = 1
|
||||
self.fades = []
|
||||
self.enabled_switches = 0
|
||||
for i, sw in enumerate(self.progressive_switches):
|
||||
fade_in = 1 if self.progressive_schedule[i] == 0 else 0
|
||||
if self.latest_step > 0 and self.progressive_schedule[i] != 0:
|
||||
switch_age = self.latest_step - self.progressive_schedule[i]
|
||||
fade_in = min(1, switch_age * self.growth_fade_in_per_step)
|
||||
|
||||
x, att = sw.forward(x, True, fixed_scale=fade_in)
|
||||
self.attentions.append(att)
|
||||
if fade_in > 0:
|
||||
self.enabled_switches += 1
|
||||
x, att = sw.forward(x, True, fixed_scale=fade_in)
|
||||
self.attentions.append(att)
|
||||
self.fades.append(fade_in)
|
||||
|
||||
x = self.upconv1(F.interpolate(x, scale_factor=2, mode="nearest"))
|
||||
if self.upsample_factor > 2:
|
||||
|
@ -81,18 +105,13 @@ class GrowingSRGBase(nn.Module):
|
|||
return x, x
|
||||
|
||||
def update_for_step(self, step, experiments_path='.'):
|
||||
self.latest_step = step
|
||||
|
||||
# Add any new layers as spelled out by the schedule.
|
||||
if step != 0:
|
||||
for i, s in enumerate(self.progressive_schedule):
|
||||
if s == step:
|
||||
self.add_layer(i + 1)
|
||||
self.latest_step = step + self.start_step
|
||||
|
||||
# Set the temperature of the switches, per-layer.
|
||||
for i, (first_step, sw) in enumerate(zip(self.progressive_schedule, self.switches)):
|
||||
for i, (first_step, sw) in enumerate(zip(self.progressive_schedule, self.progressive_switches)):
|
||||
temp_loss_per_step = (self.init_temperature - 1) / self.final_temperature_step
|
||||
sw.set_temperature(self.init_temperature - temp_loss_per_step * (step - first_step))
|
||||
sw.set_temperature(min(self.init_temperature,
|
||||
max(self.init_temperature - temp_loss_per_step * (step - first_step), 1)))
|
||||
|
||||
# Save attention images.
|
||||
if self.attentions is not None and step % 50 == 0:
|
||||
|
@ -106,10 +125,12 @@ class GrowingSRGBase(nn.Module):
|
|||
for i in range(len(means)):
|
||||
val["switch_%i_specificity" % (i,)] = means[i]
|
||||
val["switch_%i_histogram" % (i,)] = hists[i]
|
||||
val["switch_%i_temperature" % (i,)] = self.switches[i].switch.temperature
|
||||
val["switch_%i_temperature" % (i,)] = self.progressive_switches[i].switch.temperature
|
||||
for i, f in enumerate(self.fades):
|
||||
val["switch_%i_fade" % (i,)] = f
|
||||
val["enabled_switches"] = self.enabled_switches
|
||||
return val
|
||||
|
||||
|
||||
class DiscriminatorDownsample(nn.Module):
|
||||
def __init__(self, base_filters, end_filters):
|
||||
self.conv0 = ConvGnLelu(base_filters, end_filters, kernel_size=3, bias=False)
|
||||
|
|
|
@ -296,7 +296,7 @@ class ConvBnLelu(nn.Module):
|
|||
''' Convenience class with Conv->GroupNorm->LeakyReLU. Includes weight initialization and auto-padding for standard
|
||||
kernel sizes. '''
|
||||
class ConvGnLelu(nn.Module):
|
||||
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8):
|
||||
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1):
|
||||
super(ConvGnLelu, self).__init__()
|
||||
padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
|
||||
assert kernel_size in padding_map.keys()
|
||||
|
@ -315,6 +315,9 @@ class ConvGnLelu(nn.Module):
|
|||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, a=.1, mode='fan_out',
|
||||
nonlinearity='leaky_relu' if self.lelu else 'linear')
|
||||
m.weight.data *= weight_init_factor
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
|
|
@ -5,6 +5,26 @@ import torch
|
|||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
|
||||
|
||||
# This scheduler is specifically designed to modulate the learning rate of several different param groups configured
|
||||
# by a generator or discriminator that slowly adds new stages one at a time, e.g. like progressive growing of GANs.
|
||||
class ProgressiveMultiStepLR(_LRScheduler):
|
||||
def __init__(self, optimizer, milestones, group_starts, gamma=0.1):
|
||||
self.milestones = Counter(milestones)
|
||||
self.gamma = gamma
|
||||
self.group_starts = group_starts
|
||||
super(ProgressiveMultiStepLR, self).__init__(optimizer)
|
||||
|
||||
def get_lr(self):
|
||||
group_lrs = []
|
||||
assert len(self.optimizer.param_groups) == len(self.group_starts)
|
||||
for group, group_start in zip(self.optimizer.param_groups, self.group_starts):
|
||||
if self.last_epoch - group_start not in self.milestones:
|
||||
group_lrs.append(group['lr'])
|
||||
else:
|
||||
group_lrs.append(group['lr'] * self.gamma)
|
||||
return group_lrs
|
||||
|
||||
|
||||
class MultiStepLR_Restart(_LRScheduler):
|
||||
def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
|
||||
clear_state=False, force_lr=False, last_epoch=-1):
|
||||
|
|
|
@ -10,6 +10,7 @@ import models.archs.NestedSwitchGenerator as ng
|
|||
import models.archs.feature_arch as feature_arch
|
||||
import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
|
||||
import models.archs.SRG1_arch as srg1
|
||||
import models.archs.ProgressiveSrg_arch as psrg
|
||||
import functools
|
||||
|
||||
# Generator
|
||||
|
@ -78,15 +79,14 @@ def define_G(opt, net_key='network_G'):
|
|||
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
|
||||
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
|
||||
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
|
||||
elif which_model == "DualOutputSRG":
|
||||
netG = SwitchedGen_arch.DualOutputSRG(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
|
||||
switch_reductions=opt_net['switch_reductions'],
|
||||
switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
|
||||
trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
|
||||
transformation_filters=opt_net['transformation_filters'],
|
||||
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
|
||||
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
|
||||
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
|
||||
elif which_model == "ProgressiveSRG2":
|
||||
netG = psrg.GrowingSRGBase(progressive_step_schedule=opt_net['schedule'], switch_reductions=opt_net['reductions'],
|
||||
growth_fade_in_steps=opt_net['fade_in_steps'], switch_filters=opt_net['switch_filters'],
|
||||
switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
|
||||
trans_layers=opt_net['trans_layers'], transformation_filters=opt_net['transformation_filters'],
|
||||
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
|
||||
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'],
|
||||
start_step=opt_net['start_step'])
|
||||
|
||||
# image corruption
|
||||
elif which_model == 'HighToLowResNet':
|
||||
|
|
|
@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
|
|||
def main():
|
||||
#### options
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_imgset_pixgan_srg2/train_imgset_pixgan_srg2.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_progressive_srg2.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||
help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
|
@ -161,7 +161,7 @@ def main():
|
|||
current_step = resume_state['iter']
|
||||
model.resume_training(resume_state) # handle optimizers and schedulers
|
||||
else:
|
||||
current_step = -1
|
||||
current_step = 0
|
||||
start_epoch = 0
|
||||
|
||||
#### training
|
||||
|
|
Loading…
Reference in New Issue
Block a user