Migrate generators to dynamic model registration

This commit is contained in:
James Betker 2020-12-24 22:50:14 -07:00
parent a947f064cc
commit 10fdfa1563
21 changed files with 243 additions and 551 deletions

View File

@ -3,6 +3,7 @@
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
<mapping directory="$PROJECT_DIR$/codes/models/flownet2" vcs="Git" />
<mapping directory="$PROJECT_DIR$/codes/models/switched_conv" vcs="Git" />
<mapping directory="$PROJECT_DIR$/codes/switched_conv" vcs="Git" />
</component>
</project>

View File

@ -1,275 +0,0 @@
import models.SwitchedResidualGenerator_arch as srg
import torch
import torch.nn as nn
from switched_conv.switched_conv_util import save_attention_to_image
from switched_conv.switched_conv import compute_attention_specificity
from models.arch_util import ConvGnLelu, ExpansionBlock, MultiConvBlock
import functools
import torch.nn.functional as F
# Some notes about this new architecture:
# 1) Discriminator is going to need to get update_for_step() called.
# 2) Not sure if pixgan part of discriminator is going to work properly, make sure to test at multiple add levels.
# 3) Also not sure if growth modules will be properly saved/trained, be sure to test this.
# 4) start_step will need to get set properly when constructing these models, even when resuming - OR another method needs to be added to resume properly.
class GrowingSRGBase(nn.Module):
def __init__(self, progressive_step_schedule, switch_reductions, growth_fade_in_steps, switch_filters, switch_processing_layers, trans_counts,
trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, upsample_factor=1,
add_scalable_noise_to_transforms=False, start_step=0):
super(GrowingSRGBase, self).__init__()
switches = []
self.initial_conv = ConvGnLelu(3, transformation_filters, norm=False, activation=False, bias=True)
self.upconv1 = ConvGnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
self.upconv2 = ConvGnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
self.hr_conv = ConvGnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
self.final_conv = ConvGnLelu(transformation_filters, 3, norm=False, activation=False, bias=True)
self.switch_filters = switch_filters
self.switch_processing_layers = switch_processing_layers
self.trans_layers = trans_layers
self.transformation_filters = transformation_filters
self.progressive_schedule = progressive_step_schedule
self.switch_reductions = switch_reductions # This lists the reductions for all switches (even ones not activated yet).
self.growth_fade_in_per_step = 1 / growth_fade_in_steps
self.transformation_counts = trans_counts
self.init_temperature = initial_temp
self.final_temperature_step = final_temperature_step
self.attentions = None
self.upsample_factor = upsample_factor
self.add_noise_to_transform = add_scalable_noise_to_transforms
self.start_step = start_step
self.latest_step = start_step
self.fades = []
self.counter = 0
assert self.upsample_factor == 2 or self.upsample_factor == 4
switches = []
for i, (step, reductions) in enumerate(zip(progressive_step_schedule, switch_reductions)):
multiplx_fn = functools.partial(srg.ConvBasisMultiplexer, self.transformation_filters, self.switch_filters,
reductions, self.switch_processing_layers, self.transformation_counts)
pretransform_fn = functools.partial(ConvGnLelu, self.transformation_filters, self.transformation_filters, norm=False,
bias=False, weight_init_factor=.1)
transform_fn = functools.partial(srg.MultiConvBlock, self.transformation_filters, int(self.transformation_filters * 1.5),
self.transformation_filters, kernel_size=3, depth=self.trans_layers,
weight_init_factor=.1)
switches.append(srg.ConfigurableSwitchComputer(self.transformation_filters, multiplx_fn,
pre_transform_block=pretransform_fn,
transform_block=transform_fn,
transform_count=self.transformation_counts, init_temp=self.init_temperature,
add_scalable_noise_to_transforms=self.add_noise_to_transform,
attention_norm=False))
self.progressive_switches = nn.ModuleList(switches)
def get_param_groups(self):
param_groups = []
base_param_group = []
for k, v in self.named_parameters():
if "progressive_switches" not in k and v.requires_grad:
base_param_group.append(v)
param_groups.append({'params': base_param_group})
for i, sw in enumerate(self.progressive_switches):
sw_param_group = []
for k, v in sw.named_parameters():
if v.requires_grad:
sw_param_group.append(v)
param_groups.append({'params': sw_param_group})
return param_groups
# This is a hacky way of modifying the underlying model while training. Since changing the model means changing
# the optimizer and the scheduler, these things are fed in. For ProgressiveSrg, this function adds an additional
# switch to the end of the chain with depth=3 and an online time set at the end fo the function.
def update_model(self, opt, sched):
multiplx_fn = functools.partial(srg.ConvBasisMultiplexer, self.transformation_filters, self.switch_filters,
3, self.switch_processing_layers, self.transformation_counts)
pretransform_fn = functools.partial(ConvGnLelu, self.transformation_filters, self.transformation_filters, norm=False,
bias=False, weight_init_factor=.1)
transform_fn = functools.partial(srg.MultiConvBlock, self.transformation_filters, int(self.transformation_filters * 1.5),
self.transformation_filters, kernel_size=3, depth=self.trans_layers,
weight_init_factor=.1)
new_sw = srg.ConfigurableSwitchComputer(self.transformation_filters, multiplx_fn,
pre_transform_block=pretransform_fn,
transform_block=transform_fn,
transform_count=self.transformation_counts, init_temp=self.init_temperature,
add_scalable_noise_to_transforms=self.add_noise_to_transform,
attention_norm=False).to('cuda')
self.progressive_switches.append(new_sw)
new_sw_param_group = []
for k, v in new_sw.named_parameters():
if v.requires_grad:
new_sw_param_group.append(v)
opt.add_param_group({'params': new_sw_param_group})
self.progressive_schedule.append(150000)
sched.group_starts.append(150000)
def get_progressive_starts(self):
# The base param group starts at step 0, the rest are defined via progressive_switches.
return [0] + self.progressive_schedule
# This method turns requires_grad on and off for different switches, allowing very large models to be trained while
# using less memory. When used in conjunction with gradient accumulation, it becomes a form of model parallelism.
# <groups> controls the proportion of switches that are enabled. 1/groups will be enabled.
# Switches that are younger than 40000 steps are not eligible to be turned off.
def do_switched_grad(self, groups=1):
# If requires_grad is already disabled, don't bother.
if not self.initial_conv.conv.weight.requires_grad or groups == 1:
return
self.counter = (self.counter + 1) % groups
enabled = []
for i, sw in enumerate(self.progressive_switches):
if self.latest_step - self.progressive_schedule[i] > 40000 and i % groups != self.counter:
for p in sw.parameters():
p.requires_grad = False
else:
enabled.append(i)
for p in sw.parameters():
p.requires_grad = True
def forward(self, x):
self.do_switched_grad(2)
x = self.initial_conv(x)
self.attentions = []
self.fades = []
self.enabled_switches = 0
for i, sw in enumerate(self.progressive_switches):
fade_in = 1 if self.progressive_schedule[i] == 0 else 0
if self.latest_step > 0 and self.progressive_schedule[i] != 0:
switch_age = self.latest_step - self.progressive_schedule[i]
fade_in = min(1, switch_age * self.growth_fade_in_per_step)
if fade_in > 0:
self.enabled_switches += 1
x, att = sw.forward(x, True, fixed_scale=fade_in)
self.attentions.append(att)
self.fades.append(fade_in)
x = self.upconv1(F.interpolate(x, scale_factor=2, mode="nearest"))
if self.upsample_factor > 2:
x = F.interpolate(x, scale_factor=2, mode="nearest")
x = self.upconv2(x)
x = self.final_conv(self.hr_conv(x))
return x, x
def update_for_step(self, step, experiments_path='.'):
self.latest_step = step + self.start_step
# Set the temperature of the switches, per-layer.
for i, (first_step, sw) in enumerate(zip(self.progressive_schedule, self.progressive_switches)):
temp_loss_per_step = (self.init_temperature - 1) / self.final_temperature_step
sw.set_temperature(min(self.init_temperature,
max(self.init_temperature - temp_loss_per_step * (step - first_step), 1)))
# Save attention images.
if self.attentions is not None and step % 50 == 0:
[save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts, step, "a%i" % (i+1,), l_mult=10) for i in range(len(self.attentions))]
def get_debug_values(self, step):
mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
means = [i[0] for i in mean_hists]
hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
val = {}
for i in range(len(means)):
val["switch_%i_specificity" % (i,)] = means[i]
val["switch_%i_histogram" % (i,)] = hists[i]
val["switch_%i_temperature" % (i,)] = self.progressive_switches[i].switch.temperature
for i, f in enumerate(self.fades):
val["switch_%i_fade" % (i,)] = f
val["enabled_switches"] = self.enabled_switches
return val
class DiscriminatorDownsample(nn.Module):
def __init__(self, base_filters, end_filters):
self.conv0 = ConvGnLelu(base_filters, end_filters, kernel_size=3, bias=False)
self.conv1 = ConvGnLelu(end_filters, end_filters, kernel_size=3, stride=2, bias=False)
def forward(self, x):
return self.conv1(self.conv0(x))
class DiscriminatorUpsample(nn.Module):
def __init__(self, base_filters, end_filters):
self.up = ExpansionBlock(base_filters, end_filters, block=ConvGnLelu)
self.proc = ConvGnLelu(end_filters, end_filters, bias=False)
self.collapse = ConvGnLelu(end_filters, 1, bias=True, norm=False, activation=False)
def forward(self, x, ff):
x = self.up1(x, ff)
return x, self.collapse1(self.proc1(x))
class GrowingUnetDiscBase(nn.Module):
def __init__(self, nf, growing_schedule, growth_fade_in_steps, start_step=0):
super(GrowingUnetDiscBase, self).__init__()
# [64, 128, 128]
self.conv0_0 = ConvGnLelu(3, nf, kernel_size=3, bias=True, activation=False)
self.conv0_1 = ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False)
# [64, 64, 64]
self.conv1_0 = ConvGnLelu(nf, nf * 2, kernel_size=3, bias=False)
self.conv1_1 = ConvGnLelu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False)
self.down_base = DiscriminatorDownsample(nf * 2, nf * 4)
self.up_base = DiscriminatorUpsample(nf * 4, nf * 2)
self.progressive_schedule = growing_schedule
self.growth_fade_in_per_step = 1 / growth_fade_in_steps
self.pnf = nf * 4
self.downsamples = nn.ModuleList([])
self.upsamples = nn.ModuleList([])
for i, step in enumerate(growing_schedule):
if step >= start_step:
self.add_layer(i + 1)
def add_layer(self):
self.downsamples.append(DiscriminatorDownsample(self.pnf, self.pnf))
self.upsamples.append(DiscriminatorUpsample(self.pnf, self.pnf))
def update_for_step(self, step):
self.latest_step = step
# Add any new layers as spelled out by the schedule.
if step != 0:
for i, s in enumerate(self.progressive_schedule):
if s == step:
self.add_layer(i + 1)
def forward(self, x, output_feature_vector=False):
x = self.conv0_0(x)
x = self.conv0_1(x)
x = self.conv1_0(x)
x = self.conv1_1(x)
base_fea = self.down_base(x)
x = base_fea
skips = []
for down in self.downsamples:
x = down(x)
skips.append(x)
losses = []
for i, up in enumerate(self.upsamples):
j = i + 1
x, loss = up(x, skips[-j])
losses.append(loss)
# This variant averages the outputs of the U-net across the upsamples, weighting the contribution
# to the average less for newly growing levels.
_, base_loss = self.up_base(x, base_fea)
res = base_loss.shape[2:]
mean_weight = 1
for i, l in enumerate(losses):
fade_in = 1
if self.latest_step > 0 and self.progressive_schedule[i] != 0:
disc_age = self.latest_step - self.progressive_schedule[i]
fade_in = min(1, disc_age * self.growth_fade_in_per_step)
mean_weight += fade_in
base_loss += F.interpolate(l, size=res, mode="bilinear", align_corners=False) * fade_in
base_loss /= mean_weight
return base_loss.view(-1, 1)
def pixgan_parameters(self):
return 1, 4

View File

@ -6,6 +6,7 @@ import torch.nn.functional as F
import torchvision
from models.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu
from trainer.networks import register_model
from utils.util import checkpoint, sequential_checkpoint
@ -370,3 +371,27 @@ class RRDBDiscriminator(nn.Module):
if self.pred_ is not None:
self.pred_ = F.sigmoid(self.pred_)
torchvision.utils.save_image(self.pred_.cpu().float(), os.path.join(path, "%i_predictions.png" % (step,)))
@register_model
def register_RRDBNetBypass(opt_net, opt):
additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not'
output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
gc = opt_net['gc'] if 'gc' in opt_net.keys() else 32
initial_stride = opt_net['initial_stride'] if 'initial_stride' in opt_net.keys() else 1
return RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode,
output_mode=output_mode, body_block=RRDBWithBypass, scale=opt_net['scale'], growth_channels=gc,
initial_stride=initial_stride)
@register_model
def register_RRDBNet(opt_net, opt):
additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not'
output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
gc = opt_net['gc'] if 'gc' in opt_net.keys() else 32
initial_stride = opt_net['initial_stride'] if 'initial_stride' in opt_net.keys() else 1
return RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode,
output_mode=output_mode, body_block=RRDB, scale=opt_net['scale'], growth_channels=gc,
initial_stride=initial_stride)

View File

@ -1,17 +1,19 @@
import torch
from torch import nn
from models.switched_conv.switched_conv import BareConvSwitch, compute_attention_specificity, AttentionNorm
import torch.nn.functional as F
import functools
from collections import OrderedDict
from models.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu, MultiConvBlock, \
SiLU, UpconvBlock, ReferenceJoinBlock
from models.switched_conv.switched_conv_util import save_attention_to_image_rgb
import os
from models.spinenet_arch import SpineNet
from collections import OrderedDict
import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from models.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, MultiConvBlock
from models.switched_conv.switched_conv import BareConvSwitch, compute_attention_specificity, AttentionNorm
from models.switched_conv.switched_conv_util import save_attention_to_image_rgb
from trainer.networks import register_model
from utils.util import checkpoint
# VGG-style layer with Conv(stride2)->BN->Activation->Conv->BN->Activation
# Doubles the input filter count.
class HalvingProcessingBlock(nn.Module):
@ -261,142 +263,6 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
return val
# This class encapsulates an encoder based on an object detection network backbone whose purpose is to generated a
# structured embedding encoding what is in an image patch. This embedding can then be used to perform structured
# alterations to the underlying image.
#
# Caveat: Since this uses a pre-defined (and potentially pre-trained) SpineNet backbone, it has a minimum-supported
# image size, which is 128x128. In order to use 64x64 patches, you must set interpolate_first=True. though this will
# degrade quality.
class BackboneEncoder(nn.Module):
def __init__(self, interpolate_first=True, pretrained_backbone=None):
super(BackboneEncoder, self).__init__()
self.interpolate_first = interpolate_first
# Uses dual spinenets, one for the input patch and the other for the reference image.
self.patch_spine = SpineNet('49', in_channels=3, use_input_norm=True)
self.ref_spine = SpineNet('49', in_channels=3, use_input_norm=True)
self.merge_process1 = ConvGnSilu(512, 512, kernel_size=1, activation=True, norm=False, bias=True)
self.merge_process2 = ConvGnSilu(512, 384, kernel_size=1, activation=True, norm=True, bias=False)
self.merge_process3 = ConvGnSilu(384, 256, kernel_size=1, activation=False, norm=False, bias=True)
if pretrained_backbone is not None:
loaded_params = torch.load(pretrained_backbone)
self.ref_spine.load_state_dict(loaded_params['state_dict'], strict=True)
self.patch_spine.load_state_dict(loaded_params['state_dict'], strict=True)
# Returned embedding will have been reduced in size by a factor of 8 (4 if interpolate_first=True).
# Output channels are always 256.
# ex, 64x64 input with interpolate_first=True will result in tensor of shape [bx256x16x16]
def forward(self, x, ref, ref_center_point):
if self.interpolate_first:
x = F.interpolate(x, scale_factor=2, mode="bicubic")
# Don't interpolate ref - assume it is fed in at the proper resolution.
# ref = F.interpolate(ref, scale_factor=2, mode="bicubic")
# [ref] will have a 'mask' channel which we cannot use with pretrained spinenet.
ref = ref[:, :3, :, :]
ref_emb = self.ref_spine(ref)[0]
ref_code = gather_2d(ref_emb, ref_center_point // 8) # Divide by 8 to bring the center point to the correct location.
patch = self.patch_spine(x)[0]
ref_code_expanded = ref_code.view(-1, 256, 1, 1).repeat(1, 1, patch.shape[2], patch.shape[3])
combined = self.merge_process1(torch.cat([patch, ref_code_expanded], dim=1))
combined = self.merge_process2(combined)
combined = self.merge_process3(combined)
return combined
class BackboneEncoderNoRef(nn.Module):
def __init__(self, interpolate_first=True, pretrained_backbone=None):
super(BackboneEncoderNoRef, self).__init__()
self.interpolate_first = interpolate_first
self.patch_spine = SpineNet('49', in_channels=3, use_input_norm=True)
if pretrained_backbone is not None:
loaded_params = torch.load(pretrained_backbone)
self.patch_spine.load_state_dict(loaded_params['state_dict'], strict=True)
# Returned embedding will have been reduced in size by a factor of 8 (4 if interpolate_first=True).
# Output channels are always 256.
# ex, 64x64 input with interpolate_first=True will result in tensor of shape [bx256x16x16]
def forward(self, x):
if self.interpolate_first:
x = F.interpolate(x, scale_factor=2, mode="bicubic")
patch = self.patch_spine(x)[0]
return patch
class BackboneSpinenetNoHead(nn.Module):
def __init__(self):
super(BackboneSpinenetNoHead, self).__init__()
# Uses dual spinenets, one for the input patch and the other for the reference image.
self.patch_spine = SpineNet('49', in_channels=3, use_input_norm=False, double_reduce_early=False)
self.ref_spine = SpineNet('49', in_channels=4, use_input_norm=False, double_reduce_early=False)
self.merge_process1 = ConvGnSilu(512, 512, kernel_size=1, activation=True, norm=False, bias=True)
self.merge_process2 = ConvGnSilu(512, 384, kernel_size=1, activation=True, norm=True, bias=False)
self.merge_process3 = ConvGnSilu(384, 256, kernel_size=1, activation=False, norm=False, bias=True)
def forward(self, x, ref, ref_center_point):
ref_emb = self.ref_spine(ref)[0]
ref_code = gather_2d(ref_emb, ref_center_point // 4) # Divide by 8 to bring the center point to the correct location.
patch = self.patch_spine(x)[0]
ref_code_expanded = ref_code.view(-1, 256, 1, 1).repeat(1, 1, patch.shape[2], patch.shape[3])
combined = self.merge_process1(torch.cat([patch, ref_code_expanded], dim=1))
combined = self.merge_process2(combined)
combined = self.merge_process3(combined)
return combined
class ResBlock(nn.Module):
def __init__(self, nf, downsample):
super(ResBlock, self).__init__()
nf_int = nf * 2
nf_out = nf * 2 if downsample else nf
stride = 2 if downsample else 1
self.c1 = ConvGnSilu(nf, nf_int, kernel_size=3, bias=False, activation=True, norm=True)
self.c2 = ConvGnSilu(nf_int, nf_int, stride=stride, kernel_size=3, bias=False, activation=True, norm=True)
self.c3 = ConvGnSilu(nf_int, nf_out, kernel_size=3, bias=False, activation=False, norm=True)
if downsample:
self.downsample = ConvGnSilu(nf, nf_out, kernel_size=1, stride=stride, bias=False, activation=False, norm=True)
else:
self.downsample = None
self.act = SiLU()
def forward(self, x):
identity = x
branch = self.c1(x)
branch = self.c2(branch)
branch = self.c3(branch)
if self.downsample:
identity = self.downsample(identity)
return self.act(identity + branch)
class BackboneResnet(nn.Module):
def __init__(self):
super(BackboneResnet, self).__init__()
self.initial_conv = ConvGnSilu(3, 64, kernel_size=7, bias=True, activation=False, norm=False)
self.sequence = nn.Sequential(
ResBlock(64, downsample=False),
ResBlock(64, downsample=True),
ResBlock(128, downsample=False),
ResBlock(128, downsample=True),
ResBlock(256, downsample=False),
ResBlock(256, downsample=False))
def forward(self, x):
fea = self.initial_conv(x)
return self.sequence(fea)
# Computes a linear latent by performing processing on the reference image and returning the filters of a single point,
# which should be centered on the image patch being processed.
#
@ -706,3 +572,23 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
val["switch_%i_specificity" % (i,)] = means[i]
val["switch_%i_histogram" % (i,)] = hists[i]
return val
@register_model
def register_ConfigurableSwitchedResidualGenerator2(opt_net, opt):
return ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'],
switch_filters=opt_net['switch_filters'],
switch_reductions=opt_net['switch_reductions'],
switch_processing_layers=opt_net[
'switch_processing_layers'],
trans_counts=opt_net['trans_counts'],
trans_kernel_sizes=opt_net['trans_kernel_sizes'],
trans_layers=opt_net['trans_layers'],
transformation_filters=opt_net['transformation_filters'],
attention_norm=opt_net['attention_norm'],
initial_temp=opt_net['temperature'],
final_temperature_step=opt_net['temperature_final_step'],
heightened_temp_min=opt_net['heightened_temp_min'],
heightened_final_step=opt_net['heightened_final_step'],
upsample_factor=scale,
add_scalable_noise_to_transforms=opt_net['add_noise'],
for_video=opt_net['for_video'])

View File

@ -10,7 +10,8 @@ from kornia import filters
from torch import nn
from data.byol_attachment import RandomApply
from utils.util import checkpoint
from trainer.networks import register_model, create_model
from utils.util import checkpoint, opt_get
def default(val, def_val):
@ -269,3 +270,11 @@ class BYOL(nn.Module):
loss = loss_one + loss_two
return loss.mean()
@register_model
def register_byol(opt_net, opt):
subnet = create_model(opt, opt_net['subnet'])
return BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False),
do_augmentation=opt_get(opt_net, ['gpu_augmentation'], False))

View File

@ -7,6 +7,7 @@ from torch import nn
from data.byol_attachment import reconstructed_shared_regions
from models.byol.byol_model_wrapper import singleton, EMA, get_module_device, set_requires_grad, \
update_moving_average
from trainer.networks import create_model, register_model
from utils.util import checkpoint
# loss function
@ -179,3 +180,10 @@ class StructuralBYOL(nn.Module):
enc = self.online_encoder(image)
proj = self.online_predictor(enc)
return enc, proj
@register_model
def register_structural_byol(opt_net, opt):
subnet = create_model(opt, opt_net['subnet'])
return StructuralBYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
pretrained_state_dict=opt_get(opt_net, ["pretrained_path"]),
freeze_until=opt_get(opt_net, ['freeze_until'], 0))

View File

@ -1,10 +1,8 @@
import torch
import torch.nn as nn
from models.RRDBNet_arch import RRDB, RRDBWithBypass
from models.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvGnSilu, ResidualBlockGN
import torch.nn.functional as F
from models.SwitchedResidualGenerator_arch import gather_2d
from utils.util import checkpoint
@ -519,6 +517,7 @@ class RefDiscriminatorVgg128(nn.Module):
def forward(self, x, ref, ref_center_point):
ref = self.ref_head(ref)
ref_center_point = ref_center_point // 16
from models.SwitchedResidualGenerator_arch import gather_2d
ref_vector = gather_2d(ref, ref_center_point)
ref_vector = self.ref_linear(ref_vector)

View File

View File

@ -10,6 +10,7 @@ from models.arch_util import ConvGnLelu
# Produces a convolutional feature (`f`) and a reduced feature map with double the filters.
from models.glean.stylegan2_latent_bank import Stylegan2LatentBank
from models.stylegan.stylegan2_rosinality import EqualLinear
from trainer.networks import register_model
from utils.util import checkpoint, sequential_checkpoint
@ -108,3 +109,8 @@ class GleanGenerator(nn.Module):
rrdb_fea, conv_fea, latents = self.encoder(x)
latent_bank_fea = self.latent_bank(conv_fea, latents)
return self.decoder(rrdb_fea, latent_bank_fea)
@register_model
def register_glean(opt_net, opt):
return GleanGenerator(opt_net['nf'], opt_net['pretrained_stylegan'])

View File

@ -11,6 +11,7 @@ __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
'wide_resnet50_2', 'wide_resnet101_2']
from trainer.networks import register_model
from utils.util import checkpoint
model_urls = {
@ -188,3 +189,8 @@ def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
kwargs['width_per_group'] = 64 * 2
return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
pretrained, progress, **kwargs)
@register_model
def register_resnet52(opt_net, opt):
return resnet50(pretrained=opt_net['pretrained'])

View File

@ -7,6 +7,7 @@ from torch.nn.init import kaiming_normal
from torchvision.models.resnet import BasicBlock, Bottleneck
from models.arch_util import ConvGnSilu, ConvBnSilu, ConvBnRelu
from trainer.networks import register_model
def constant_init(module, val, bias=0):
@ -359,3 +360,13 @@ class SpinenetWithLogits(SpineNet):
def forward(self, x):
fea = super().forward(x)[self.output_to_attach]
return self.tail(fea)
@register_model
def register_spinenet(opt_net, opt):
return SpineNet(str(opt_net['arch']), in_channels=3, use_input_norm=opt_net['use_input_norm'])
@register_model
def register_spinenet_with_logits(opt_net, opt):
return SpinenetWithLogits(str(opt_net['arch']), opt_net['output_to_attach'], opt_net['num_labels'],
in_channels=3, use_input_norm=opt_net['use_input_norm'])

View File

@ -4,6 +4,7 @@ import torch.nn as nn
import torch.nn.functional as F
import models.srflow.module_util as mutil
from models.arch_util import default_init_weights, ConvGnSilu, ConvGnLelu
from trainer.networks import register_model
from utils.util import opt_get
@ -240,3 +241,18 @@ class RRDBLatentWrapper(nn.Module):
fea = torch.cat(blocklist, dim=1)
fea = self.postprocess(fea)
return fea
@register_model
def register_rrdb_latent_wrapper(opt_net, opt):
return RRDBLatentWrapper(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], with_bypass=opt_net['with_bypass'],
blocks=opt_net['blocks_for_latent'], scale=opt_net['scale'],
pretrain_rrdb_path=opt_net['pretrain_path'])
@register_model
def register_rrdb_srflow(opt_net, opt):
return RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'],
initial_conv_stride=opt_net['initial_stride'])

View File

@ -8,6 +8,7 @@ from models.srflow.RRDBNet_arch import RRDBNet
from models.srflow.FlowUpsamplerNet import FlowUpsamplerNet
import models.srflow.thops as thops
import models.srflow.flow as flow
from trainer.networks import register_model
from utils.util import opt_get
@ -166,3 +167,9 @@ class SRFlowNet(nn.Module):
logdet=logdet)
return x, logdet, lr_enc['out']
@register_model
def register_srflow(opt_net, opt):
return SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'],
K=opt_net['K'], opt=opt)

View File

@ -11,6 +11,7 @@ from collections import OrderedDict
from models.SwitchedResidualGenerator_arch import HalvingProcessingBlock, ConfigurableSwitchComputer
from models.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, MultiConvBlock
from models.switched_conv.switched_conv import BareConvSwitch, AttentionNorm
from trainer.networks import register_model
from utils.util import checkpoint
@ -204,3 +205,19 @@ class Interpolate(nn.Module):
def forward(self, x):
return F.interpolate(x, scale_factor=self.factor, mode=self.mode)
@register_model
def register_srg2classic(opt_net, opt):
return ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'],
switch_filters=opt_net['switch_filters'],
switch_reductions=opt_net['switch_reductions'],
switch_processing_layers=opt_net['switch_processing_layers'],
trans_counts=opt_net['trans_counts'],
trans_kernel_sizes=opt_net['trans_kernel_sizes'],
trans_layers=opt_net['trans_layers'],
transformation_filters=opt_net['transformation_filters'],
initial_temp=opt_net['temperature'],
final_temperature_step=opt_net['temperature_final_step'],
heightened_temp_min=opt_net['heightened_temp_min'],
heightened_final_step=opt_net['heightened_final_step'],
upsample_factor=scale,
add_scalable_noise_to_transforms=opt_net['add_noise'])

View File

@ -17,6 +17,7 @@ from torch import nn
from torch.autograd import grad as torch_grad
from vector_quantize_pytorch import VectorQuantize
from trainer.networks import register_model
from utils.util import checkpoint
try:
@ -893,3 +894,12 @@ class StyleGan2PathLengthLoss(L.ConfigurableLoss):
self.pl_mean = self.pl_length_ma.update_average(self.pl_mean, avg_pl_length)
return 0
@register_model
def register_stylegan2_lucidrains(opt_net, opt):
is_structured = opt_net['structured'] if 'structured' in opt_net.keys() else False
attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
return StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'],
style_depth=opt_net['style_depth'], structure_input=is_structured,
attn_layers=attn)

View File

@ -0,0 +1,15 @@
import munch
import torch
from trainer.networks import register_model
@register_model
def register_flownet2(opt_net):
from models.flownet2.models import FlowNet2
ld = 'load_path' in opt_net.keys()
args = munch.Munch({'fp16': False, 'rgb_max': 1.0, 'checkpoint': not ld})
netG = FlowNet2(args)
if ld:
sd = torch.load(opt_net['load_path'])
netG.load_state_dict(sd['state_dict'])

View File

@ -4,6 +4,7 @@ import torch
import torch.nn as nn
import torchvision
from trainer.networks import register_model
from utils.util import sequential_checkpoint
from models.arch_util import ConvGnSilu, make_layer
@ -71,3 +72,8 @@ class TecoGen(nn.Module):
def get_debug_values(self, step, net_name):
return {'branch_std': self.join.std()}
@register_model
def register_tecogen(opt_net, opt):
return TecoGen(opt_net['nf'], opt_net['scale'])

View File

@ -3,6 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from trainer.injectors import Injector
from trainer.networks import register_model
from utils.util import checkpoint
@ -147,3 +148,8 @@ class iGPT2(nn.Module):
return logits, x
@register_model
def register_igpt2(opt_net, opt):
return iGPT2(opt_net['embed_dim'], opt_net['num_heads'], opt_net['num_layers'], opt_net['num_pixels'] ** 2,
opt_net['num_vocab'], centroids_file=opt_net['centroids_file'])

View File

@ -57,7 +57,7 @@ class ExtensibleTrainer(BaseModel):
new_net = None
if net['type'] == 'generator':
if new_net is None:
new_net = networks.define_G(opt, net, opt['scale']).to(self.device)
new_net = networks.create_model(opt, net, opt['scale']).to(self.device)
self.netsG[name] = new_net
elif net['type'] == 'discriminator':
if new_net is None:

View File

@ -1,140 +1,79 @@
import functools
import importlib
import logging
import pkgutil
import sys
from collections import OrderedDict
from inspect import isfunction, getmembers
import munch
import torch
import torchvision
from munch import munchify
import models.stylegan.stylegan2_lucidrains as stylegan2
import models.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch
import models.RRDBNet_arch as RRDBNet_arch
import models.SwitchedResidualGenerator_arch as SwitchedGen_arch
import models.discriminator_vgg_arch as SRGAN_arch
import models.feature_arch as feature_arch
from models import srg2_classic
import models.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch
from models.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator
from models.tecogan.teco_resgen import TecoGen
from utils.util import opt_get
logger = logging.getLogger('base')
# Generator
def define_G(opt, opt_net, scale=None):
if scale is None:
scale = opt['scale']
which_model = opt_net['which_model_G']
if 'RRDBNet' in which_model:
if which_model == 'RRDBNetBypass':
block = RRDBNet_arch.RRDBWithBypass
elif which_model == 'RRDBNetLambda':
from models.lambda_rrdb import LambdaRRDB
block = LambdaRRDB
else:
block = RRDBNet_arch.RRDB
additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not'
output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
gc = opt_net['gc'] if 'gc' in opt_net.keys() else 32
initial_stride = opt_net['initial_stride'] if 'initial_stride' in opt_net.keys() else 1
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode,
output_mode=output_mode, body_block=block, scale=opt_net['scale'], growth_channels=gc,
initial_stride=initial_stride)
elif which_model == "ConfigurableSwitchedResidualGenerator2":
netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
switch_reductions=opt_net['switch_reductions'],
switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
transformation_filters=opt_net['transformation_filters'], attention_norm=opt_net['attention_norm'],
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'],
for_video=opt_net['for_video'])
elif which_model == "srg2classic":
netG = srg2_classic.ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
switch_reductions=opt_net['switch_reductions'],
switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
transformation_filters=opt_net['transformation_filters'],
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
elif which_model == "flownet2":
from models.flownet2 import FlowNet2
ld = 'load_path' in opt_net.keys()
args = munch.Munch({'fp16': False, 'rgb_max': 1.0, 'checkpoint': not ld})
netG = FlowNet2(args)
if ld:
sd = torch.load(opt_net['load_path'])
netG.load_state_dict(sd['state_dict'])
elif which_model == "backbone_encoder":
netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet'])
elif which_model == "backbone_encoder_no_ref":
netG = SwitchedGen_arch.BackboneEncoderNoRef(pretrained_backbone=opt_net['pretrained_spinenet'])
elif which_model == "backbone_encoder_no_head":
netG = SwitchedGen_arch.BackboneSpinenetNoHead()
elif which_model == "backbone_resnet":
netG = SwitchedGen_arch.BackboneResnet()
elif which_model == "tecogen":
netG = TecoGen(opt_net['nf'], opt_net['scale'])
elif which_model == 'stylegan2':
is_structured = opt_net['structured'] if 'structured' in opt_net.keys() else False
attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
netG = stylegan2.StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'],
style_depth=opt_net['style_depth'], structure_input=is_structured,
attn_layers=attn)
elif which_model == 'srflow':
from models.srflow import SRFlowNet_arch
netG = SRFlowNet_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'],
K=opt_net['K'], opt=opt)
elif which_model == 'rrdb_latent_wrapper':
from models.srflow.RRDBNet_arch import RRDBLatentWrapper
netG = RRDBLatentWrapper(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], with_bypass=opt_net['with_bypass'],
blocks=opt_net['blocks_for_latent'], scale=opt_net['scale'], pretrain_rrdb_path=opt_net['pretrain_path'])
elif which_model == 'rrdb_centipede':
output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], scale=opt_net['scale'],
headless=True, output_mode=output_mode)
elif which_model == 'rrdb_srflow':
from models.srflow.RRDBNet_arch import RRDBNet
netG = RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'],
initial_conv_stride=opt_net['initial_stride'])
elif which_model == 'igpt2':
from models.transformers.igpt.gpt2 import iGPT2
netG = iGPT2(opt_net['embed_dim'], opt_net['num_heads'], opt_net['num_layers'], opt_net['num_pixels'] ** 2, opt_net['num_vocab'], centroids_file=opt_net['centroids_file'])
elif which_model == 'byol':
from models.byol.byol_model_wrapper import BYOL
subnet = define_G(opt, opt_net['subnet'])
netG = BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False),
do_augmentation=opt_get(opt_net, ['gpu_augmentation'], False))
elif which_model == 'structural_byol':
from models.byol.byol_structural import StructuralBYOL
subnet = define_G(opt, opt_net['subnet'])
netG = StructuralBYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
pretrained_state_dict=opt_get(opt_net, ["pretrained_path"]),
freeze_until=opt_get(opt_net, ['freeze_until'], 0))
elif which_model == 'spinenet':
from models.spinenet_arch import SpineNet
netG = SpineNet(str(opt_net['arch']), in_channels=3, use_input_norm=opt_net['use_input_norm'])
elif which_model == 'spinenet_with_logits':
from models.spinenet_arch import SpinenetWithLogits
netG = SpinenetWithLogits(str(opt_net['arch']), opt_net['output_to_attach'], opt_net['num_labels'],
in_channels=3, use_input_norm=opt_net['use_input_norm'])
elif which_model == 'resnet52':
from models.resnet_with_checkpointing import resnet50
netG = resnet50(pretrained=opt_net['pretrained'])
elif which_model == 'glean':
from models.glean.glean import GleanGenerator
netG = GleanGenerator(opt_net['nf'], opt_net['pretrained_stylegan'])
class RegisteredModelNameError(Exception):
def __init__(self, name_error):
super().__init__(f'Registered DLAS modules must start with `register_`. Incorrect registration: {name_error}')
# Decorator that allows API clients to show DLAS how to build a nn.Module from an opt dict.
# Functions with this decorator should have a specific naming format:
# `register_<name>` where <name> is the name that will be used in configuration files to reference this model.
# Functions with this decorator are expected to take a single argument:
# - opt: A dict with the configuration options for building the module.
# They should return:
# - A torch.nn.Module object for the model being defined.
def register_model(func):
if func.__name__.startswith("register_"):
func._dlas_model_name = func.__name__[9:]
assert func._dlas_model_name
else:
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
return netG
raise RegisteredModelNameError(func.__name__)
func._dlas_registered_model = True
return func
def find_registered_model_fns(base_path='models'):
found_fns = {}
module_iter = pkgutil.walk_packages([base_path])
for mod in module_iter:
if mod.ispkg:
EXCLUSION_LIST = ['flownet2']
if mod.name not in EXCLUSION_LIST:
found_fns.update(find_registered_model_fns(f'{base_path}/{mod.name}'))
else:
mod_name = f'{base_path}/{mod.name}'.replace('/', '.')
importlib.import_module(mod_name)
for mod_fn in getmembers(sys.modules[mod_name], isfunction):
if hasattr(mod_fn[1], "_dlas_registered_model"):
found_fns[mod_fn[1]._dlas_model_name] = mod_fn[1]
return found_fns
class CreateModelError(Exception):
def __init__(self, name, available):
super().__init__(f'Could not find the specified model name: {name}. Tip: If your model is in a'
f' subdirectory, that directory must contain an __init__.py to be scanned. Available models:'
f'{available}')
def create_model(opt, opt_net, scale=None):
which_model = opt_net['which_model']
# For backwards compatibility.
if not which_model:
which_model = opt_net['which_model_G']
if not which_model:
which_model = opt_net['which_model_D']
registered_fns = find_registered_model_fns()
if which_model not in registered_fns.keys():
raise CreateModelError(which_model, list(registered_fns.keys()))
return registered_fns[which_model](opt_net, opt)
class GradDiscWrapper(torch.nn.Module):

View File

@ -2,7 +2,7 @@ import argparse
import functools
import torch
from utils import options as option
from trainer.networks import define_G
from trainer.networks import create_model
class TracedModule:
@ -96,7 +96,7 @@ if __name__ == "__main__":
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)
netG = define_G(opt)
netG = create_model(opt)
dummyInput = torch.rand(1,3,32,32)
mode = 'onnx'