forked from mrq/DL-Art-School
Huge set of mods to support progressive generator growth
parent 47a525241f
commit 8a9f215653
@@ -103,12 +103,15 @@ class SRGANModel(BaseModel):
             # G
             wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
             optim_params = []
-            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
-                if v.requires_grad:
-                    optim_params.append(v)
-                else:
-                    if self.rank <= 0:
-                        logger.warning('Params [{:s}] will not optimize.'.format(k))
+            if train_opt['lr_scheme'] == 'ProgressiveMultiStepLR':
+                optim_params = self.netG.get_param_groups()
+            else:
+                for k, v in self.netG.named_parameters():  # can optimize for a part of the model
+                    if v.requires_grad:
+                        optim_params.append(v)
+                    else:
+                        if self.rank <= 0:
+                            logger.warning('Params [{:s}] will not optimize.'.format(k))
             self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                                 weight_decay=wd_G,
                                                 betas=(train_opt['beta1_G'], train_opt['beta2_G']))
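For context, a minimal sketch (not part of this commit; the model and values below are made up) of how a list of {'params': [...]} dicts like the one returned by get_param_groups() is consumed by torch.optim.Adam. The optimizer keeps one entry in optimizer.param_groups per dict, which is what lets a scheduler later scale each group's learning rate independently.

import torch
import torch.nn as nn

# Stand-in model with a "base" part and two progressively added stages (illustrative only).
model = nn.ModuleDict({
    'base': nn.Conv2d(3, 16, 3, padding=1),
    'stage0': nn.Conv2d(16, 16, 3, padding=1),
    'stage1': nn.Conv2d(16, 16, 3, padding=1),
})

# One dict per param group, mirroring the structure get_param_groups() builds.
optim_params = [
    {'params': list(model['base'].parameters())},
    {'params': list(model['stage0'].parameters())},
    {'params': list(model['stage1'].parameters())},
]

# Adam records one param group per dict, each with its own 'lr' that can be adjusted later.
optimizer = torch.optim.Adam(optim_params, lr=1e-4, weight_decay=0, betas=(0.9, 0.99))
print([g['lr'] for g in optimizer.param_groups])  # [0.0001, 0.0001, 0.0001]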
@@ -148,6 +151,15 @@ class SRGANModel(BaseModel):
                                                         gamma=train_opt['lr_gamma'],
                                                         clear_state=train_opt['clear_state'],
                                                         force_lr=train_opt['force_lr']))
+            elif train_opt['lr_scheme'] == 'ProgressiveMultiStepLR':
+                # Only supported when there are two optimizers: G and D.
+                assert len(self.optimizers) == 2
+                self.schedulers.append(lr_scheduler.ProgressiveMultiStepLR(self.optimizer_G, train_opt['gen_lr_steps'],
+                                                                           self.netG.module.get_progressive_starts(),
+                                                                           train_opt['lr_gamma']))
+                self.schedulers.append(lr_scheduler.ProgressiveMultiStepLR(self.optimizer_D, train_opt['disc_lr_steps'],
+                                                                           [0],
+                                                                           train_opt['lr_gamma']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
@@ -492,6 +504,12 @@ class SRGANModel(BaseModel):
             self.add_log_entry('l_d_fake_total', l_d_fake.item() * self.mega_batch_factor)
             self.add_log_entry('l_d_real_total', l_d_real.item() * self.mega_batch_factor)

+        # Log learning rates.
+        for i, pg in enumerate(self.optimizer_G.param_groups):
+            self.add_log_entry('gen_lr_%i' % (i,), pg['lr'])
+        for i, pg in enumerate(self.optimizer_D.param_groups):
+            self.add_log_entry('disc_lr_%i' % (i,), pg['lr'])
+
         if step % self.corruptor_swapout_steps == 0 and step > 0:
             self.load_random_corruptor()
@@ -14,7 +14,7 @@ import torch.nn.functional as F
 # 4) start_step will need to get set properly when constructing these models, even when resuming - OR another method needs to be added to resume properly.

 class GrowingSRGBase(nn.Module):
-    def __init__(self, progressive_schedule, growth_fade_in_steps, switch_filters, switch_processing_layers, trans_counts,
+    def __init__(self, progressive_step_schedule, switch_reductions, growth_fade_in_steps, switch_filters, switch_processing_layers, trans_counts,
                  trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, upsample_factor=1,
                  add_scalable_noise_to_transforms=False, start_step=0):
         super(GrowingSRGBase, self).__init__()
@@ -29,9 +29,8 @@ class GrowingSRGBase(nn.Module):
         self.switch_processing_layers = switch_processing_layers
         self.trans_layers = trans_layers
         self.transformation_filters = transformation_filters
-        self.switches = nn.ModuleList([])
-        self.progressive_schedule = progressive_schedule
+        self.progressive_schedule = progressive_step_schedule
+        self.switch_reductions = switch_reductions  # This lists the reductions for all switches (even ones not activated yet).
         self.growth_fade_in_per_step = 1 / growth_fade_in_steps
         self.transformation_counts = trans_counts
         self.init_temperature = initial_temp
@@ -39,39 +38,64 @@ class GrowingSRGBase(nn.Module):
         self.attentions = None
         self.upsample_factor = upsample_factor
         self.add_noise_to_transform = add_scalable_noise_to_transforms
-        self.latest_step = 0
+        self.start_step = start_step
+        self.latest_step = start_step
+        self.fades = []
         assert self.upsample_factor == 2 or self.upsample_factor == 4

-        for i, step in enumerate(progressive_schedule):
-            if step >= start_step:
-                self.add_layer(i + 1)
-
-    def add_layer(self, reductions):
-        multiplx_fn = functools.partial(srg.ConvBasisMultiplexer, self.transformation_filters, self.switch_filters,
-                                        reductions, self.switch_processing_layers, self.transformation_counts)
-        pretransform_fn = functools.partial(ConvBnLelu, self.transformation_filters, self.transformation_filters, norm=False,
-                                            bias=False, weight_init_factor=.1)
-        transform_fn = functools.partial(srg.MultiConvBlock, self.transformation_filters, int(self.transformation_filters * 1.5),
-                                         self.transformation_filters, kernel_size=3, depth=self.trans_layers,
-                                         weight_init_factor=.1)
-        self.switches.append(srg.ConfigurableSwitchComputer(self.transformation_filters, multiplx_fn,
-                                                            pre_transform_block=pretransform_fn,
-                                                            transform_block=transform_fn,
-                                                            transform_count=self.transformation_counts, init_temp=self.init_temperature,
-                                                            add_scalable_noise_to_transforms=self.add_noise_to_transform))
+        switches = []
+        for i, (step, reductions) in enumerate(zip(progressive_step_schedule, switch_reductions)):
+            multiplx_fn = functools.partial(srg.ConvBasisMultiplexer, self.transformation_filters, self.switch_filters,
+                                            reductions, self.switch_processing_layers, self.transformation_counts)
+            pretransform_fn = functools.partial(ConvGnLelu, self.transformation_filters, self.transformation_filters, norm=False,
+                                                bias=False, weight_init_factor=.1)
+            transform_fn = functools.partial(srg.MultiConvBlock, self.transformation_filters, int(self.transformation_filters * 1.5),
+                                             self.transformation_filters, kernel_size=3, depth=self.trans_layers,
+                                             weight_init_factor=.1)
+            switches.append(srg.ConfigurableSwitchComputer(self.transformation_filters, multiplx_fn,
+                                                           pre_transform_block=pretransform_fn,
+                                                           transform_block=transform_fn,
+                                                           transform_count=self.transformation_counts, init_temp=self.init_temperature,
+                                                           add_scalable_noise_to_transforms=self.add_noise_to_transform,
+                                                           attention_norm=False))
+        self.progressive_switches = nn.ModuleList(switches)
+
+    def get_param_groups(self):
+        param_groups = []
+        base_param_group = []
+        for k, v in self.named_parameters():
+            if "progressive_switches" not in k and v.requires_grad:
+                base_param_group.append(v)
+        param_groups.append({'params': base_param_group})
+        for i, sw in enumerate(self.progressive_switches):
+            sw_param_group = []
+            for k, v in sw.named_parameters():
+                if v.requires_grad:
+                    sw_param_group.append(v)
+            param_groups.append({'params': sw_param_group})
+        return param_groups
+
+    def get_progressive_starts(self):
+        # The base param group starts at step 0, the rest are defined via progressive_switches.
+        return [0] + self.progressive_schedule
+
     def forward(self, x):
         x = self.initial_conv(x)

         self.attentions = []
-        for i, sw in enumerate(self.switches):
-            fade_in = 1
+        self.fades = []
+        self.enabled_switches = 0
+        for i, sw in enumerate(self.progressive_switches):
+            fade_in = 1 if self.progressive_schedule[i] == 0 else 0
             if self.latest_step > 0 and self.progressive_schedule[i] != 0:
                 switch_age = self.latest_step - self.progressive_schedule[i]
                 fade_in = min(1, switch_age * self.growth_fade_in_per_step)

-            x, att = sw.forward(x, True, fixed_scale=fade_in)
-            self.attentions.append(att)
+            if fade_in > 0:
+                self.enabled_switches += 1
+                x, att = sw.forward(x, True, fixed_scale=fade_in)
+                self.attentions.append(att)
+            self.fades.append(fade_in)

         x = self.upconv1(F.interpolate(x, scale_factor=2, mode="nearest"))
         if self.upsample_factor > 2:
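As an aside, a small sketch (hypothetical numbers, not repo defaults) of the fade-in computation in forward(): each switch ramps linearly from 0 to 1 over growth_fade_in_steps once latest_step passes its scheduled start, and switches whose fade is not yet positive are skipped entirely.

# Illustrative fade-in schedule; the values below are assumptions.
progressive_schedule = [0, 20000, 40000]  # step at which each switch is scheduled to come online
growth_fade_in_steps = 10000
growth_fade_in_per_step = 1 / growth_fade_in_steps

def fade_for(switch_idx, latest_step):
    # Switches scheduled at step 0 are always fully on.
    fade_in = 1 if progressive_schedule[switch_idx] == 0 else 0
    if latest_step > 0 and progressive_schedule[switch_idx] != 0:
        switch_age = latest_step - progressive_schedule[switch_idx]
        fade_in = min(1, switch_age * growth_fade_in_per_step)
    return fade_in

# At step 25000: switch 0 is fully on, switch 1 is halfway faded in, and switch 2's fade is
# still negative, so forward() skips it (the fade_in > 0 check gates execution).
print([fade_for(i, 25000) for i in range(3)])  # [1, 0.5, -1.5]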
@@ -81,18 +105,13 @@ class GrowingSRGBase(nn.Module):
         return x, x

     def update_for_step(self, step, experiments_path='.'):
-        self.latest_step = step
+        self.latest_step = step + self.start_step

-        # Add any new layers as spelled out by the schedule.
-        if step != 0:
-            for i, s in enumerate(self.progressive_schedule):
-                if s == step:
-                    self.add_layer(i + 1)
-
         # Set the temperature of the switches, per-layer.
-        for i, (first_step, sw) in enumerate(zip(self.progressive_schedule, self.switches)):
+        for i, (first_step, sw) in enumerate(zip(self.progressive_schedule, self.progressive_switches)):
             temp_loss_per_step = (self.init_temperature - 1) / self.final_temperature_step
-            sw.set_temperature(self.init_temperature - temp_loss_per_step * (step - first_step))
+            sw.set_temperature(min(self.init_temperature,
+                                   max(self.init_temperature - temp_loss_per_step * (step - first_step), 1)))

         # Save attention images.
         if self.attentions is not None and step % 50 == 0:
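For illustration, a worked sketch (values assumed, mirroring the constructor defaults) of the clamped temperature schedule now used in update_for_step(): temperature decays linearly from initial_temp toward 1 over final_temperature_step steps, measured from each switch's own first_step, and the min/max clamp keeps switches that have not started yet at the initial temperature instead of overshooting.

# Hypothetical values; init_temperature=20 and final_temperature_step=50000 match the __init__ defaults above.
init_temperature = 20
final_temperature_step = 50000
temp_loss_per_step = (init_temperature - 1) / final_temperature_step  # 0.00038

def switch_temperature(step, first_step):
    return min(init_temperature,
               max(init_temperature - temp_loss_per_step * (step - first_step), 1))

print(switch_temperature(0, 0))          # 20 - starts at the initial temperature
print(switch_temperature(60000, 0))      # 1  - clamped at the floor after full annealing
print(switch_temperature(10000, 20000))  # 20 - a switch that has not started yet stays at init_temperature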
@@ -106,10 +125,12 @@ class GrowingSRGBase(nn.Module):
         for i in range(len(means)):
             val["switch_%i_specificity" % (i,)] = means[i]
             val["switch_%i_histogram" % (i,)] = hists[i]
-            val["switch_%i_temperature" % (i,)] = self.switches[i].switch.temperature
+            val["switch_%i_temperature" % (i,)] = self.progressive_switches[i].switch.temperature
+        for i, f in enumerate(self.fades):
+            val["switch_%i_fade" % (i,)] = f
+        val["enabled_switches"] = self.enabled_switches
         return val


 class DiscriminatorDownsample(nn.Module):
     def __init__(self, base_filters, end_filters):
         self.conv0 = ConvGnLelu(base_filters, end_filters, kernel_size=3, bias=False)
@@ -296,7 +296,7 @@ class ConvBnLelu(nn.Module):
 ''' Convenience class with Conv->GroupNorm->LeakyReLU. Includes weight initialization and auto-padding for standard
     kernel sizes. '''
 class ConvGnLelu(nn.Module):
-    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8):
+    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1):
         super(ConvGnLelu, self).__init__()
         padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
         assert kernel_size in padding_map.keys()
@@ -315,6 +315,9 @@ class ConvGnLelu(nn.Module):
             if isinstance(m, nn.Conv2d):
                 nn.init.kaiming_normal_(m.weight, a=.1, mode='fan_out',
                                         nonlinearity='leaky_relu' if self.lelu else 'linear')
+                m.weight.data *= weight_init_factor
+                if m.bias is not None:
+                    m.bias.data.zero_()
             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
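As a side note, a tiny sketch (illustrative, not part of the commit) of what the new weight_init_factor does: it scales the kaiming-initialized convolution weights down, so a freshly added block starts out producing small outputs.

import torch.nn as nn

conv = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False)
nn.init.kaiming_normal_(conv.weight, a=.1, mode='fan_out', nonlinearity='leaky_relu')
conv.weight.data *= 0.1  # weight_init_factor=.1: initial weights, and hence outputs, are ~10x smaller
print(conv.weight.std().item())  # roughly one tenth of the usual kaiming std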
@@ -5,6 +5,26 @@ import torch
 from torch.optim.lr_scheduler import _LRScheduler


+# This scheduler is specifically designed to modulate the learning rate of several different param groups configured
+# by a generator or discriminator that slowly adds new stages one at a time, e.g. like progressive growing of GANs.
+class ProgressiveMultiStepLR(_LRScheduler):
+    def __init__(self, optimizer, milestones, group_starts, gamma=0.1):
+        self.milestones = Counter(milestones)
+        self.gamma = gamma
+        self.group_starts = group_starts
+        super(ProgressiveMultiStepLR, self).__init__(optimizer)
+
+    def get_lr(self):
+        group_lrs = []
+        assert len(self.optimizer.param_groups) == len(self.group_starts)
+        for group, group_start in zip(self.optimizer.param_groups, self.group_starts):
+            if self.last_epoch - group_start not in self.milestones:
+                group_lrs.append(group['lr'])
+            else:
+                group_lrs.append(group['lr'] * self.gamma)
+        return group_lrs
+
+
 class MultiStepLR_Restart(_LRScheduler):
     def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
                  clear_state=False, force_lr=False, last_epoch=-1):
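A hypothetical usage sketch (not from the repo) of how the group_starts offsets shift the milestones for each param group; it assumes ProgressiveMultiStepLR is importable from this module and that torch is available.

import torch
import torch.nn as nn

base, stage = nn.Linear(8, 8), nn.Linear(8, 8)
optimizer = torch.optim.Adam([{'params': base.parameters()},
                              {'params': stage.parameters()}], lr=1e-4)

# One milestone at 20000; the second group "started" at step 10000, so its lr decays when
# last_epoch - 10000 == 20000, i.e. at absolute step 30000 rather than 20000.
scheduler = ProgressiveMultiStepLR(optimizer, milestones=[20000], group_starts=[0, 10000], gamma=0.5)

for step in range(30001):
    optimizer.step()
    scheduler.step()
print([g['lr'] for g in optimizer.param_groups])  # both groups have decayed once by gamma at this point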
@@ -10,6 +10,7 @@ import models.archs.NestedSwitchGenerator as ng
 import models.archs.feature_arch as feature_arch
 import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
 import models.archs.SRG1_arch as srg1
+import models.archs.ProgressiveSrg_arch as psrg
 import functools

 # Generator
@@ -78,15 +79,14 @@ def define_G(opt, net_key='network_G'):
                                               initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
                                               heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
                                               upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
-    elif which_model == "DualOutputSRG":
-        netG = SwitchedGen_arch.DualOutputSRG(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
-                                              switch_reductions=opt_net['switch_reductions'],
-                                              switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
-                                              trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
-                                              transformation_filters=opt_net['transformation_filters'],
-                                              initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
-                                              heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
-                                              upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
+    elif which_model == "ProgressiveSRG2":
+        netG = psrg.GrowingSRGBase(progressive_step_schedule=opt_net['schedule'], switch_reductions=opt_net['reductions'],
+                                   growth_fade_in_steps=opt_net['fade_in_steps'], switch_filters=opt_net['switch_filters'],
+                                   switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
+                                   trans_layers=opt_net['trans_layers'], transformation_filters=opt_net['transformation_filters'],
+                                   initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
+                                   upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'],
+                                   start_step=opt_net['start_step'])

     # image corruption
     elif which_model == 'HighToLowResNet':
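For orientation, a hypothetical opt_net dict for the new ProgressiveSRG2 branch; the keys are the ones read in the diff above, while the values (and anything else the options file may need) are assumptions.

opt_net = {
    'schedule': [0, 20000, 40000],   # progressive_step_schedule: when each switch comes online (assumed)
    'reductions': [4, 3, 2],         # switch_reductions, one entry per eventual switch (assumed)
    'fade_in_steps': 10000,          # growth_fade_in_steps
    'switch_filters': 32,
    'switch_processing_layers': 2,
    'trans_counts': 8,
    'trans_layers': 4,
    'transformation_filters': 64,
    'temperature': 10,               # initial_temp
    'temperature_final_step': 50000,
    'add_noise': False,
    'start_step': 0,
}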
@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_imgset_pixgan_srg2/train_imgset_pixgan_srg2.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_progressive_srg2.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
@@ -161,7 +161,7 @@ def main():
         current_step = resume_state['iter']
         model.resume_training(resume_state)  # handle optimizers and schedulers
     else:
-        current_step = -1
+        current_step = 0
         start_epoch = 0

     #### training