forked from mrq/DL-Art-School
Migrate generators to dynamic model registration
This commit is contained in:
parent
a947f064cc
commit
10fdfa1563
|
@ -3,6 +3,7 @@
|
|||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/codes/models/flownet2" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/codes/models/switched_conv" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/codes/switched_conv" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
|
@ -1,275 +0,0 @@
|
|||
import models.SwitchedResidualGenerator_arch as srg
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from switched_conv.switched_conv_util import save_attention_to_image
|
||||
from switched_conv.switched_conv import compute_attention_specificity
|
||||
from models.arch_util import ConvGnLelu, ExpansionBlock, MultiConvBlock
|
||||
import functools
|
||||
import torch.nn.functional as F
|
||||
|
||||
# Some notes about this new architecture:
|
||||
# 1) Discriminator is going to need to get update_for_step() called.
|
||||
# 2) Not sure if pixgan part of discriminator is going to work properly, make sure to test at multiple add levels.
|
||||
# 3) Also not sure if growth modules will be properly saved/trained, be sure to test this.
|
||||
# 4) start_step will need to get set properly when constructing these models, even when resuming - OR another method needs to be added to resume properly.
|
||||
|
||||
class GrowingSRGBase(nn.Module):
|
||||
def __init__(self, progressive_step_schedule, switch_reductions, growth_fade_in_steps, switch_filters, switch_processing_layers, trans_counts,
|
||||
trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, upsample_factor=1,
|
||||
add_scalable_noise_to_transforms=False, start_step=0):
|
||||
super(GrowingSRGBase, self).__init__()
|
||||
switches = []
|
||||
self.initial_conv = ConvGnLelu(3, transformation_filters, norm=False, activation=False, bias=True)
|
||||
self.upconv1 = ConvGnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
|
||||
self.upconv2 = ConvGnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
|
||||
self.hr_conv = ConvGnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
|
||||
self.final_conv = ConvGnLelu(transformation_filters, 3, norm=False, activation=False, bias=True)
|
||||
|
||||
self.switch_filters = switch_filters
|
||||
self.switch_processing_layers = switch_processing_layers
|
||||
self.trans_layers = trans_layers
|
||||
self.transformation_filters = transformation_filters
|
||||
self.progressive_schedule = progressive_step_schedule
|
||||
self.switch_reductions = switch_reductions # This lists the reductions for all switches (even ones not activated yet).
|
||||
self.growth_fade_in_per_step = 1 / growth_fade_in_steps
|
||||
self.transformation_counts = trans_counts
|
||||
self.init_temperature = initial_temp
|
||||
self.final_temperature_step = final_temperature_step
|
||||
self.attentions = None
|
||||
self.upsample_factor = upsample_factor
|
||||
self.add_noise_to_transform = add_scalable_noise_to_transforms
|
||||
self.start_step = start_step
|
||||
self.latest_step = start_step
|
||||
self.fades = []
|
||||
self.counter = 0
|
||||
assert self.upsample_factor == 2 or self.upsample_factor == 4
|
||||
|
||||
switches = []
|
||||
for i, (step, reductions) in enumerate(zip(progressive_step_schedule, switch_reductions)):
|
||||
multiplx_fn = functools.partial(srg.ConvBasisMultiplexer, self.transformation_filters, self.switch_filters,
|
||||
reductions, self.switch_processing_layers, self.transformation_counts)
|
||||
pretransform_fn = functools.partial(ConvGnLelu, self.transformation_filters, self.transformation_filters, norm=False,
|
||||
bias=False, weight_init_factor=.1)
|
||||
transform_fn = functools.partial(srg.MultiConvBlock, self.transformation_filters, int(self.transformation_filters * 1.5),
|
||||
self.transformation_filters, kernel_size=3, depth=self.trans_layers,
|
||||
weight_init_factor=.1)
|
||||
switches.append(srg.ConfigurableSwitchComputer(self.transformation_filters, multiplx_fn,
|
||||
pre_transform_block=pretransform_fn,
|
||||
transform_block=transform_fn,
|
||||
transform_count=self.transformation_counts, init_temp=self.init_temperature,
|
||||
add_scalable_noise_to_transforms=self.add_noise_to_transform,
|
||||
attention_norm=False))
|
||||
self.progressive_switches = nn.ModuleList(switches)
|
||||
|
||||
def get_param_groups(self):
|
||||
param_groups = []
|
||||
base_param_group = []
|
||||
for k, v in self.named_parameters():
|
||||
if "progressive_switches" not in k and v.requires_grad:
|
||||
base_param_group.append(v)
|
||||
param_groups.append({'params': base_param_group})
|
||||
for i, sw in enumerate(self.progressive_switches):
|
||||
sw_param_group = []
|
||||
for k, v in sw.named_parameters():
|
||||
if v.requires_grad:
|
||||
sw_param_group.append(v)
|
||||
param_groups.append({'params': sw_param_group})
|
||||
return param_groups
|
||||
|
||||
# This is a hacky way of modifying the underlying model while training. Since changing the model means changing
|
||||
# the optimizer and the scheduler, these things are fed in. For ProgressiveSrg, this function adds an additional
|
||||
# switch to the end of the chain with depth=3 and an online time set at the end fo the function.
|
||||
def update_model(self, opt, sched):
|
||||
multiplx_fn = functools.partial(srg.ConvBasisMultiplexer, self.transformation_filters, self.switch_filters,
|
||||
3, self.switch_processing_layers, self.transformation_counts)
|
||||
pretransform_fn = functools.partial(ConvGnLelu, self.transformation_filters, self.transformation_filters, norm=False,
|
||||
bias=False, weight_init_factor=.1)
|
||||
transform_fn = functools.partial(srg.MultiConvBlock, self.transformation_filters, int(self.transformation_filters * 1.5),
|
||||
self.transformation_filters, kernel_size=3, depth=self.trans_layers,
|
||||
weight_init_factor=.1)
|
||||
new_sw = srg.ConfigurableSwitchComputer(self.transformation_filters, multiplx_fn,
|
||||
pre_transform_block=pretransform_fn,
|
||||
transform_block=transform_fn,
|
||||
transform_count=self.transformation_counts, init_temp=self.init_temperature,
|
||||
add_scalable_noise_to_transforms=self.add_noise_to_transform,
|
||||
attention_norm=False).to('cuda')
|
||||
self.progressive_switches.append(new_sw)
|
||||
new_sw_param_group = []
|
||||
for k, v in new_sw.named_parameters():
|
||||
if v.requires_grad:
|
||||
new_sw_param_group.append(v)
|
||||
opt.add_param_group({'params': new_sw_param_group})
|
||||
self.progressive_schedule.append(150000)
|
||||
sched.group_starts.append(150000)
|
||||
|
||||
def get_progressive_starts(self):
|
||||
# The base param group starts at step 0, the rest are defined via progressive_switches.
|
||||
return [0] + self.progressive_schedule
|
||||
|
||||
# This method turns requires_grad on and off for different switches, allowing very large models to be trained while
|
||||
# using less memory. When used in conjunction with gradient accumulation, it becomes a form of model parallelism.
|
||||
# <groups> controls the proportion of switches that are enabled. 1/groups will be enabled.
|
||||
# Switches that are younger than 40000 steps are not eligible to be turned off.
|
||||
def do_switched_grad(self, groups=1):
|
||||
# If requires_grad is already disabled, don't bother.
|
||||
if not self.initial_conv.conv.weight.requires_grad or groups == 1:
|
||||
return
|
||||
self.counter = (self.counter + 1) % groups
|
||||
enabled = []
|
||||
for i, sw in enumerate(self.progressive_switches):
|
||||
if self.latest_step - self.progressive_schedule[i] > 40000 and i % groups != self.counter:
|
||||
for p in sw.parameters():
|
||||
p.requires_grad = False
|
||||
else:
|
||||
enabled.append(i)
|
||||
for p in sw.parameters():
|
||||
p.requires_grad = True
|
||||
|
||||
def forward(self, x):
|
||||
self.do_switched_grad(2)
|
||||
|
||||
x = self.initial_conv(x)
|
||||
|
||||
self.attentions = []
|
||||
self.fades = []
|
||||
self.enabled_switches = 0
|
||||
for i, sw in enumerate(self.progressive_switches):
|
||||
fade_in = 1 if self.progressive_schedule[i] == 0 else 0
|
||||
if self.latest_step > 0 and self.progressive_schedule[i] != 0:
|
||||
switch_age = self.latest_step - self.progressive_schedule[i]
|
||||
fade_in = min(1, switch_age * self.growth_fade_in_per_step)
|
||||
|
||||
if fade_in > 0:
|
||||
self.enabled_switches += 1
|
||||
x, att = sw.forward(x, True, fixed_scale=fade_in)
|
||||
self.attentions.append(att)
|
||||
self.fades.append(fade_in)
|
||||
|
||||
x = self.upconv1(F.interpolate(x, scale_factor=2, mode="nearest"))
|
||||
if self.upsample_factor > 2:
|
||||
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
||||
x = self.upconv2(x)
|
||||
x = self.final_conv(self.hr_conv(x))
|
||||
return x, x
|
||||
|
||||
def update_for_step(self, step, experiments_path='.'):
|
||||
self.latest_step = step + self.start_step
|
||||
|
||||
# Set the temperature of the switches, per-layer.
|
||||
for i, (first_step, sw) in enumerate(zip(self.progressive_schedule, self.progressive_switches)):
|
||||
temp_loss_per_step = (self.init_temperature - 1) / self.final_temperature_step
|
||||
sw.set_temperature(min(self.init_temperature,
|
||||
max(self.init_temperature - temp_loss_per_step * (step - first_step), 1)))
|
||||
|
||||
# Save attention images.
|
||||
if self.attentions is not None and step % 50 == 0:
|
||||
[save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts, step, "a%i" % (i+1,), l_mult=10) for i in range(len(self.attentions))]
|
||||
|
||||
def get_debug_values(self, step):
|
||||
mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
|
||||
means = [i[0] for i in mean_hists]
|
||||
hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
|
||||
val = {}
|
||||
for i in range(len(means)):
|
||||
val["switch_%i_specificity" % (i,)] = means[i]
|
||||
val["switch_%i_histogram" % (i,)] = hists[i]
|
||||
val["switch_%i_temperature" % (i,)] = self.progressive_switches[i].switch.temperature
|
||||
for i, f in enumerate(self.fades):
|
||||
val["switch_%i_fade" % (i,)] = f
|
||||
val["enabled_switches"] = self.enabled_switches
|
||||
return val
|
||||
|
||||
class DiscriminatorDownsample(nn.Module):
|
||||
def __init__(self, base_filters, end_filters):
|
||||
self.conv0 = ConvGnLelu(base_filters, end_filters, kernel_size=3, bias=False)
|
||||
self.conv1 = ConvGnLelu(end_filters, end_filters, kernel_size=3, stride=2, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv1(self.conv0(x))
|
||||
|
||||
|
||||
class DiscriminatorUpsample(nn.Module):
|
||||
def __init__(self, base_filters, end_filters):
|
||||
self.up = ExpansionBlock(base_filters, end_filters, block=ConvGnLelu)
|
||||
self.proc = ConvGnLelu(end_filters, end_filters, bias=False)
|
||||
self.collapse = ConvGnLelu(end_filters, 1, bias=True, norm=False, activation=False)
|
||||
|
||||
def forward(self, x, ff):
|
||||
x = self.up1(x, ff)
|
||||
return x, self.collapse1(self.proc1(x))
|
||||
|
||||
|
||||
class GrowingUnetDiscBase(nn.Module):
|
||||
def __init__(self, nf, growing_schedule, growth_fade_in_steps, start_step=0):
|
||||
super(GrowingUnetDiscBase, self).__init__()
|
||||
# [64, 128, 128]
|
||||
self.conv0_0 = ConvGnLelu(3, nf, kernel_size=3, bias=True, activation=False)
|
||||
self.conv0_1 = ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False)
|
||||
# [64, 64, 64]
|
||||
self.conv1_0 = ConvGnLelu(nf, nf * 2, kernel_size=3, bias=False)
|
||||
self.conv1_1 = ConvGnLelu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False)
|
||||
|
||||
self.down_base = DiscriminatorDownsample(nf * 2, nf * 4)
|
||||
self.up_base = DiscriminatorUpsample(nf * 4, nf * 2)
|
||||
|
||||
self.progressive_schedule = growing_schedule
|
||||
self.growth_fade_in_per_step = 1 / growth_fade_in_steps
|
||||
self.pnf = nf * 4
|
||||
self.downsamples = nn.ModuleList([])
|
||||
self.upsamples = nn.ModuleList([])
|
||||
|
||||
for i, step in enumerate(growing_schedule):
|
||||
if step >= start_step:
|
||||
self.add_layer(i + 1)
|
||||
|
||||
def add_layer(self):
|
||||
self.downsamples.append(DiscriminatorDownsample(self.pnf, self.pnf))
|
||||
self.upsamples.append(DiscriminatorUpsample(self.pnf, self.pnf))
|
||||
|
||||
def update_for_step(self, step):
|
||||
self.latest_step = step
|
||||
|
||||
# Add any new layers as spelled out by the schedule.
|
||||
if step != 0:
|
||||
for i, s in enumerate(self.progressive_schedule):
|
||||
if s == step:
|
||||
self.add_layer(i + 1)
|
||||
|
||||
def forward(self, x, output_feature_vector=False):
|
||||
x = self.conv0_0(x)
|
||||
x = self.conv0_1(x)
|
||||
x = self.conv1_0(x)
|
||||
x = self.conv1_1(x)
|
||||
base_fea = self.down_base(x)
|
||||
x = base_fea
|
||||
|
||||
skips = []
|
||||
for down in self.downsamples:
|
||||
x = down(x)
|
||||
skips.append(x)
|
||||
|
||||
losses = []
|
||||
for i, up in enumerate(self.upsamples):
|
||||
j = i + 1
|
||||
x, loss = up(x, skips[-j])
|
||||
losses.append(loss)
|
||||
|
||||
# This variant averages the outputs of the U-net across the upsamples, weighting the contribution
|
||||
# to the average less for newly growing levels.
|
||||
_, base_loss = self.up_base(x, base_fea)
|
||||
res = base_loss.shape[2:]
|
||||
|
||||
mean_weight = 1
|
||||
for i, l in enumerate(losses):
|
||||
fade_in = 1
|
||||
if self.latest_step > 0 and self.progressive_schedule[i] != 0:
|
||||
disc_age = self.latest_step - self.progressive_schedule[i]
|
||||
fade_in = min(1, disc_age * self.growth_fade_in_per_step)
|
||||
mean_weight += fade_in
|
||||
base_loss += F.interpolate(l, size=res, mode="bilinear", align_corners=False) * fade_in
|
||||
base_loss /= mean_weight
|
||||
|
||||
return base_loss.view(-1, 1)
|
||||
|
||||
def pixgan_parameters(self):
|
||||
return 1, 4
|
|
@ -6,6 +6,7 @@ import torch.nn.functional as F
|
|||
import torchvision
|
||||
|
||||
from models.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint, sequential_checkpoint
|
||||
|
||||
|
||||
|
@ -370,3 +371,27 @@ class RRDBDiscriminator(nn.Module):
|
|||
if self.pred_ is not None:
|
||||
self.pred_ = F.sigmoid(self.pred_)
|
||||
torchvision.utils.save_image(self.pred_.cpu().float(), os.path.join(path, "%i_predictions.png" % (step,)))
|
||||
|
||||
|
||||
@register_model
|
||||
def register_RRDBNetBypass(opt_net, opt):
|
||||
additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not'
|
||||
output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
|
||||
gc = opt_net['gc'] if 'gc' in opt_net.keys() else 32
|
||||
initial_stride = opt_net['initial_stride'] if 'initial_stride' in opt_net.keys() else 1
|
||||
return RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
||||
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode,
|
||||
output_mode=output_mode, body_block=RRDBWithBypass, scale=opt_net['scale'], growth_channels=gc,
|
||||
initial_stride=initial_stride)
|
||||
|
||||
|
||||
@register_model
|
||||
def register_RRDBNet(opt_net, opt):
|
||||
additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not'
|
||||
output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
|
||||
gc = opt_net['gc'] if 'gc' in opt_net.keys() else 32
|
||||
initial_stride = opt_net['initial_stride'] if 'initial_stride' in opt_net.keys() else 1
|
||||
return RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
||||
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode,
|
||||
output_mode=output_mode, body_block=RRDB, scale=opt_net['scale'], growth_channels=gc,
|
||||
initial_stride=initial_stride)
|
|
@ -1,17 +1,19 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
from models.switched_conv.switched_conv import BareConvSwitch, compute_attention_specificity, AttentionNorm
|
||||
import torch.nn.functional as F
|
||||
import functools
|
||||
from collections import OrderedDict
|
||||
from models.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu, MultiConvBlock, \
|
||||
SiLU, UpconvBlock, ReferenceJoinBlock
|
||||
from models.switched_conv.switched_conv_util import save_attention_to_image_rgb
|
||||
import os
|
||||
from models.spinenet_arch import SpineNet
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
from torch import nn
|
||||
|
||||
from models.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, MultiConvBlock
|
||||
from models.switched_conv.switched_conv import BareConvSwitch, compute_attention_specificity, AttentionNorm
|
||||
from models.switched_conv.switched_conv_util import save_attention_to_image_rgb
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint
|
||||
|
||||
|
||||
# VGG-style layer with Conv(stride2)->BN->Activation->Conv->BN->Activation
|
||||
# Doubles the input filter count.
|
||||
class HalvingProcessingBlock(nn.Module):
|
||||
|
@ -261,142 +263,6 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
|
|||
return val
|
||||
|
||||
|
||||
# This class encapsulates an encoder based on an object detection network backbone whose purpose is to generated a
|
||||
# structured embedding encoding what is in an image patch. This embedding can then be used to perform structured
|
||||
# alterations to the underlying image.
|
||||
#
|
||||
# Caveat: Since this uses a pre-defined (and potentially pre-trained) SpineNet backbone, it has a minimum-supported
|
||||
# image size, which is 128x128. In order to use 64x64 patches, you must set interpolate_first=True. though this will
|
||||
# degrade quality.
|
||||
class BackboneEncoder(nn.Module):
|
||||
def __init__(self, interpolate_first=True, pretrained_backbone=None):
|
||||
super(BackboneEncoder, self).__init__()
|
||||
self.interpolate_first = interpolate_first
|
||||
|
||||
# Uses dual spinenets, one for the input patch and the other for the reference image.
|
||||
self.patch_spine = SpineNet('49', in_channels=3, use_input_norm=True)
|
||||
self.ref_spine = SpineNet('49', in_channels=3, use_input_norm=True)
|
||||
|
||||
self.merge_process1 = ConvGnSilu(512, 512, kernel_size=1, activation=True, norm=False, bias=True)
|
||||
self.merge_process2 = ConvGnSilu(512, 384, kernel_size=1, activation=True, norm=True, bias=False)
|
||||
self.merge_process3 = ConvGnSilu(384, 256, kernel_size=1, activation=False, norm=False, bias=True)
|
||||
|
||||
if pretrained_backbone is not None:
|
||||
loaded_params = torch.load(pretrained_backbone)
|
||||
self.ref_spine.load_state_dict(loaded_params['state_dict'], strict=True)
|
||||
self.patch_spine.load_state_dict(loaded_params['state_dict'], strict=True)
|
||||
|
||||
# Returned embedding will have been reduced in size by a factor of 8 (4 if interpolate_first=True).
|
||||
# Output channels are always 256.
|
||||
# ex, 64x64 input with interpolate_first=True will result in tensor of shape [bx256x16x16]
|
||||
def forward(self, x, ref, ref_center_point):
|
||||
if self.interpolate_first:
|
||||
x = F.interpolate(x, scale_factor=2, mode="bicubic")
|
||||
# Don't interpolate ref - assume it is fed in at the proper resolution.
|
||||
# ref = F.interpolate(ref, scale_factor=2, mode="bicubic")
|
||||
|
||||
# [ref] will have a 'mask' channel which we cannot use with pretrained spinenet.
|
||||
ref = ref[:, :3, :, :]
|
||||
ref_emb = self.ref_spine(ref)[0]
|
||||
ref_code = gather_2d(ref_emb, ref_center_point // 8) # Divide by 8 to bring the center point to the correct location.
|
||||
|
||||
patch = self.patch_spine(x)[0]
|
||||
ref_code_expanded = ref_code.view(-1, 256, 1, 1).repeat(1, 1, patch.shape[2], patch.shape[3])
|
||||
combined = self.merge_process1(torch.cat([patch, ref_code_expanded], dim=1))
|
||||
combined = self.merge_process2(combined)
|
||||
combined = self.merge_process3(combined)
|
||||
|
||||
return combined
|
||||
|
||||
|
||||
class BackboneEncoderNoRef(nn.Module):
|
||||
def __init__(self, interpolate_first=True, pretrained_backbone=None):
|
||||
super(BackboneEncoderNoRef, self).__init__()
|
||||
self.interpolate_first = interpolate_first
|
||||
|
||||
self.patch_spine = SpineNet('49', in_channels=3, use_input_norm=True)
|
||||
|
||||
if pretrained_backbone is not None:
|
||||
loaded_params = torch.load(pretrained_backbone)
|
||||
self.patch_spine.load_state_dict(loaded_params['state_dict'], strict=True)
|
||||
|
||||
# Returned embedding will have been reduced in size by a factor of 8 (4 if interpolate_first=True).
|
||||
# Output channels are always 256.
|
||||
# ex, 64x64 input with interpolate_first=True will result in tensor of shape [bx256x16x16]
|
||||
def forward(self, x):
|
||||
if self.interpolate_first:
|
||||
x = F.interpolate(x, scale_factor=2, mode="bicubic")
|
||||
|
||||
patch = self.patch_spine(x)[0]
|
||||
return patch
|
||||
|
||||
|
||||
class BackboneSpinenetNoHead(nn.Module):
|
||||
def __init__(self):
|
||||
super(BackboneSpinenetNoHead, self).__init__()
|
||||
# Uses dual spinenets, one for the input patch and the other for the reference image.
|
||||
self.patch_spine = SpineNet('49', in_channels=3, use_input_norm=False, double_reduce_early=False)
|
||||
self.ref_spine = SpineNet('49', in_channels=4, use_input_norm=False, double_reduce_early=False)
|
||||
|
||||
self.merge_process1 = ConvGnSilu(512, 512, kernel_size=1, activation=True, norm=False, bias=True)
|
||||
self.merge_process2 = ConvGnSilu(512, 384, kernel_size=1, activation=True, norm=True, bias=False)
|
||||
self.merge_process3 = ConvGnSilu(384, 256, kernel_size=1, activation=False, norm=False, bias=True)
|
||||
|
||||
def forward(self, x, ref, ref_center_point):
|
||||
ref_emb = self.ref_spine(ref)[0]
|
||||
ref_code = gather_2d(ref_emb, ref_center_point // 4) # Divide by 8 to bring the center point to the correct location.
|
||||
|
||||
patch = self.patch_spine(x)[0]
|
||||
ref_code_expanded = ref_code.view(-1, 256, 1, 1).repeat(1, 1, patch.shape[2], patch.shape[3])
|
||||
combined = self.merge_process1(torch.cat([patch, ref_code_expanded], dim=1))
|
||||
combined = self.merge_process2(combined)
|
||||
combined = self.merge_process3(combined)
|
||||
return combined
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, nf, downsample):
|
||||
super(ResBlock, self).__init__()
|
||||
nf_int = nf * 2
|
||||
nf_out = nf * 2 if downsample else nf
|
||||
stride = 2 if downsample else 1
|
||||
self.c1 = ConvGnSilu(nf, nf_int, kernel_size=3, bias=False, activation=True, norm=True)
|
||||
self.c2 = ConvGnSilu(nf_int, nf_int, stride=stride, kernel_size=3, bias=False, activation=True, norm=True)
|
||||
self.c3 = ConvGnSilu(nf_int, nf_out, kernel_size=3, bias=False, activation=False, norm=True)
|
||||
if downsample:
|
||||
self.downsample = ConvGnSilu(nf, nf_out, kernel_size=1, stride=stride, bias=False, activation=False, norm=True)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.act = SiLU()
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
branch = self.c1(x)
|
||||
branch = self.c2(branch)
|
||||
branch = self.c3(branch)
|
||||
|
||||
if self.downsample:
|
||||
identity = self.downsample(identity)
|
||||
return self.act(identity + branch)
|
||||
|
||||
|
||||
class BackboneResnet(nn.Module):
|
||||
def __init__(self):
|
||||
super(BackboneResnet, self).__init__()
|
||||
self.initial_conv = ConvGnSilu(3, 64, kernel_size=7, bias=True, activation=False, norm=False)
|
||||
self.sequence = nn.Sequential(
|
||||
ResBlock(64, downsample=False),
|
||||
ResBlock(64, downsample=True),
|
||||
ResBlock(128, downsample=False),
|
||||
ResBlock(128, downsample=True),
|
||||
ResBlock(256, downsample=False),
|
||||
ResBlock(256, downsample=False))
|
||||
|
||||
def forward(self, x):
|
||||
fea = self.initial_conv(x)
|
||||
return self.sequence(fea)
|
||||
|
||||
|
||||
# Computes a linear latent by performing processing on the reference image and returning the filters of a single point,
|
||||
# which should be centered on the image patch being processed.
|
||||
#
|
||||
|
@ -706,3 +572,23 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
|
|||
val["switch_%i_specificity" % (i,)] = means[i]
|
||||
val["switch_%i_histogram" % (i,)] = hists[i]
|
||||
return val
|
||||
|
||||
@register_model
|
||||
def register_ConfigurableSwitchedResidualGenerator2(opt_net, opt):
|
||||
return ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'],
|
||||
switch_filters=opt_net['switch_filters'],
|
||||
switch_reductions=opt_net['switch_reductions'],
|
||||
switch_processing_layers=opt_net[
|
||||
'switch_processing_layers'],
|
||||
trans_counts=opt_net['trans_counts'],
|
||||
trans_kernel_sizes=opt_net['trans_kernel_sizes'],
|
||||
trans_layers=opt_net['trans_layers'],
|
||||
transformation_filters=opt_net['transformation_filters'],
|
||||
attention_norm=opt_net['attention_norm'],
|
||||
initial_temp=opt_net['temperature'],
|
||||
final_temperature_step=opt_net['temperature_final_step'],
|
||||
heightened_temp_min=opt_net['heightened_temp_min'],
|
||||
heightened_final_step=opt_net['heightened_final_step'],
|
||||
upsample_factor=scale,
|
||||
add_scalable_noise_to_transforms=opt_net['add_noise'],
|
||||
for_video=opt_net['for_video'])
|
|
@ -10,7 +10,8 @@ from kornia import filters
|
|||
from torch import nn
|
||||
|
||||
from data.byol_attachment import RandomApply
|
||||
from utils.util import checkpoint
|
||||
from trainer.networks import register_model, create_model
|
||||
from utils.util import checkpoint, opt_get
|
||||
|
||||
|
||||
def default(val, def_val):
|
||||
|
@ -269,3 +270,11 @@ class BYOL(nn.Module):
|
|||
|
||||
loss = loss_one + loss_two
|
||||
return loss.mean()
|
||||
|
||||
|
||||
@register_model
|
||||
def register_byol(opt_net, opt):
|
||||
subnet = create_model(opt, opt_net['subnet'])
|
||||
return BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
|
||||
structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False),
|
||||
do_augmentation=opt_get(opt_net, ['gpu_augmentation'], False))
|
|
@ -7,6 +7,7 @@ from torch import nn
|
|||
from data.byol_attachment import reconstructed_shared_regions
|
||||
from models.byol.byol_model_wrapper import singleton, EMA, get_module_device, set_requires_grad, \
|
||||
update_moving_average
|
||||
from trainer.networks import create_model, register_model
|
||||
from utils.util import checkpoint
|
||||
|
||||
# loss function
|
||||
|
@ -179,3 +180,10 @@ class StructuralBYOL(nn.Module):
|
|||
enc = self.online_encoder(image)
|
||||
proj = self.online_predictor(enc)
|
||||
return enc, proj
|
||||
|
||||
@register_model
|
||||
def register_structural_byol(opt_net, opt):
|
||||
subnet = create_model(opt, opt_net['subnet'])
|
||||
return StructuralBYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
|
||||
pretrained_state_dict=opt_get(opt_net, ["pretrained_path"]),
|
||||
freeze_until=opt_get(opt_net, ['freeze_until'], 0))
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from models.RRDBNet_arch import RRDB, RRDBWithBypass
|
||||
from models.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvGnSilu, ResidualBlockGN
|
||||
import torch.nn.functional as F
|
||||
from models.SwitchedResidualGenerator_arch import gather_2d
|
||||
from utils.util import checkpoint
|
||||
|
||||
|
||||
|
@ -519,6 +517,7 @@ class RefDiscriminatorVgg128(nn.Module):
|
|||
def forward(self, x, ref, ref_center_point):
|
||||
ref = self.ref_head(ref)
|
||||
ref_center_point = ref_center_point // 16
|
||||
from models.SwitchedResidualGenerator_arch import gather_2d
|
||||
ref_vector = gather_2d(ref, ref_center_point)
|
||||
ref_vector = self.ref_linear(ref_vector)
|
||||
|
||||
|
|
0
codes/models/glean/__init__.py
Normal file
0
codes/models/glean/__init__.py
Normal file
|
@ -10,6 +10,7 @@ from models.arch_util import ConvGnLelu
|
|||
# Produces a convolutional feature (`f`) and a reduced feature map with double the filters.
|
||||
from models.glean.stylegan2_latent_bank import Stylegan2LatentBank
|
||||
from models.stylegan.stylegan2_rosinality import EqualLinear
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint, sequential_checkpoint
|
||||
|
||||
|
||||
|
@ -108,3 +109,8 @@ class GleanGenerator(nn.Module):
|
|||
rrdb_fea, conv_fea, latents = self.encoder(x)
|
||||
latent_bank_fea = self.latent_bank(conv_fea, latents)
|
||||
return self.decoder(rrdb_fea, latent_bank_fea)
|
||||
|
||||
|
||||
@register_model
|
||||
def register_glean(opt_net, opt):
|
||||
return GleanGenerator(opt_net['nf'], opt_net['pretrained_stylegan'])
|
||||
|
|
|
@ -11,6 +11,7 @@ __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
|
|||
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
|
||||
'wide_resnet50_2', 'wide_resnet101_2']
|
||||
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint
|
||||
|
||||
model_urls = {
|
||||
|
@ -188,3 +189,8 @@ def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
|
|||
kwargs['width_per_group'] = 64 * 2
|
||||
return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
|
||||
pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def register_resnet52(opt_net, opt):
|
||||
return resnet50(pretrained=opt_net['pretrained'])
|
||||
|
|
|
@ -7,6 +7,7 @@ from torch.nn.init import kaiming_normal
|
|||
|
||||
from torchvision.models.resnet import BasicBlock, Bottleneck
|
||||
from models.arch_util import ConvGnSilu, ConvBnSilu, ConvBnRelu
|
||||
from trainer.networks import register_model
|
||||
|
||||
|
||||
def constant_init(module, val, bias=0):
|
||||
|
@ -359,3 +360,13 @@ class SpinenetWithLogits(SpineNet):
|
|||
def forward(self, x):
|
||||
fea = super().forward(x)[self.output_to_attach]
|
||||
return self.tail(fea)
|
||||
|
||||
@register_model
|
||||
def register_spinenet(opt_net, opt):
|
||||
return SpineNet(str(opt_net['arch']), in_channels=3, use_input_norm=opt_net['use_input_norm'])
|
||||
|
||||
|
||||
@register_model
|
||||
def register_spinenet_with_logits(opt_net, opt):
|
||||
return SpinenetWithLogits(str(opt_net['arch']), opt_net['output_to_attach'], opt_net['num_labels'],
|
||||
in_channels=3, use_input_norm=opt_net['use_input_norm'])
|
||||
|
|
|
@ -4,6 +4,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
import models.srflow.module_util as mutil
|
||||
from models.arch_util import default_init_weights, ConvGnSilu, ConvGnLelu
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
|
@ -240,3 +241,18 @@ class RRDBLatentWrapper(nn.Module):
|
|||
fea = torch.cat(blocklist, dim=1)
|
||||
fea = self.postprocess(fea)
|
||||
return fea
|
||||
|
||||
|
||||
@register_model
|
||||
def register_rrdb_latent_wrapper(opt_net, opt):
|
||||
return RRDBLatentWrapper(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||
nf=opt_net['nf'], nb=opt_net['nb'], with_bypass=opt_net['with_bypass'],
|
||||
blocks=opt_net['blocks_for_latent'], scale=opt_net['scale'],
|
||||
pretrain_rrdb_path=opt_net['pretrain_path'])
|
||||
|
||||
|
||||
@register_model
|
||||
def register_rrdb_srflow(opt_net, opt):
|
||||
return RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||
nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'],
|
||||
initial_conv_stride=opt_net['initial_stride'])
|
|
@ -8,6 +8,7 @@ from models.srflow.RRDBNet_arch import RRDBNet
|
|||
from models.srflow.FlowUpsamplerNet import FlowUpsamplerNet
|
||||
import models.srflow.thops as thops
|
||||
import models.srflow.flow as flow
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
|
@ -166,3 +167,9 @@ class SRFlowNet(nn.Module):
|
|||
logdet=logdet)
|
||||
|
||||
return x, logdet, lr_enc['out']
|
||||
|
||||
|
||||
@register_model
|
||||
def register_srflow(opt_net, opt):
|
||||
return SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'],
|
||||
K=opt_net['K'], opt=opt)
|
||||
|
|
|
@ -11,6 +11,7 @@ from collections import OrderedDict
|
|||
from models.SwitchedResidualGenerator_arch import HalvingProcessingBlock, ConfigurableSwitchComputer
|
||||
from models.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, MultiConvBlock
|
||||
from models.switched_conv.switched_conv import BareConvSwitch, AttentionNorm
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint
|
||||
|
||||
|
||||
|
@ -204,3 +205,19 @@ class Interpolate(nn.Module):
|
|||
def forward(self, x):
|
||||
return F.interpolate(x, scale_factor=self.factor, mode=self.mode)
|
||||
|
||||
@register_model
|
||||
def register_srg2classic(opt_net, opt):
|
||||
return ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'],
|
||||
switch_filters=opt_net['switch_filters'],
|
||||
switch_reductions=opt_net['switch_reductions'],
|
||||
switch_processing_layers=opt_net['switch_processing_layers'],
|
||||
trans_counts=opt_net['trans_counts'],
|
||||
trans_kernel_sizes=opt_net['trans_kernel_sizes'],
|
||||
trans_layers=opt_net['trans_layers'],
|
||||
transformation_filters=opt_net['transformation_filters'],
|
||||
initial_temp=opt_net['temperature'],
|
||||
final_temperature_step=opt_net['temperature_final_step'],
|
||||
heightened_temp_min=opt_net['heightened_temp_min'],
|
||||
heightened_final_step=opt_net['heightened_final_step'],
|
||||
upsample_factor=scale,
|
||||
add_scalable_noise_to_transforms=opt_net['add_noise'])
|
|
@ -17,6 +17,7 @@ from torch import nn
|
|||
from torch.autograd import grad as torch_grad
|
||||
from vector_quantize_pytorch import VectorQuantize
|
||||
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint
|
||||
|
||||
try:
|
||||
|
@ -893,3 +894,12 @@ class StyleGan2PathLengthLoss(L.ConfigurableLoss):
|
|||
|
||||
self.pl_mean = self.pl_length_ma.update_average(self.pl_mean, avg_pl_length)
|
||||
return 0
|
||||
|
||||
|
||||
@register_model
|
||||
def register_stylegan2_lucidrains(opt_net, opt):
|
||||
is_structured = opt_net['structured'] if 'structured' in opt_net.keys() else False
|
||||
attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
|
||||
return StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'],
|
||||
style_depth=opt_net['style_depth'], structure_input=is_structured,
|
||||
attn_layers=attn)
|
||||
|
|
15
codes/models/tecogan/flownet2.py
Normal file
15
codes/models/tecogan/flownet2.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
import munch
|
||||
import torch
|
||||
|
||||
from trainer.networks import register_model
|
||||
|
||||
|
||||
@register_model
|
||||
def register_flownet2(opt_net):
|
||||
from models.flownet2.models import FlowNet2
|
||||
ld = 'load_path' in opt_net.keys()
|
||||
args = munch.Munch({'fp16': False, 'rgb_max': 1.0, 'checkpoint': not ld})
|
||||
netG = FlowNet2(args)
|
||||
if ld:
|
||||
sd = torch.load(opt_net['load_path'])
|
||||
netG.load_state_dict(sd['state_dict'])
|
|
@ -4,6 +4,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torchvision
|
||||
|
||||
from trainer.networks import register_model
|
||||
from utils.util import sequential_checkpoint
|
||||
from models.arch_util import ConvGnSilu, make_layer
|
||||
|
||||
|
@ -71,3 +72,8 @@ class TecoGen(nn.Module):
|
|||
|
||||
def get_debug_values(self, step, net_name):
|
||||
return {'branch_std': self.join.std()}
|
||||
|
||||
|
||||
@register_model
|
||||
def register_tecogen(opt_net, opt):
|
||||
return TecoGen(opt_net['nf'], opt_net['scale'])
|
|
@ -3,6 +3,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from trainer.injectors import Injector
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint
|
||||
|
||||
|
||||
|
@ -147,3 +148,8 @@ class iGPT2(nn.Module):
|
|||
|
||||
return logits, x
|
||||
|
||||
|
||||
@register_model
|
||||
def register_igpt2(opt_net, opt):
|
||||
return iGPT2(opt_net['embed_dim'], opt_net['num_heads'], opt_net['num_layers'], opt_net['num_pixels'] ** 2,
|
||||
opt_net['num_vocab'], centroids_file=opt_net['centroids_file'])
|
||||
|
|
|
@ -57,7 +57,7 @@ class ExtensibleTrainer(BaseModel):
|
|||
new_net = None
|
||||
if net['type'] == 'generator':
|
||||
if new_net is None:
|
||||
new_net = networks.define_G(opt, net, opt['scale']).to(self.device)
|
||||
new_net = networks.create_model(opt, net, opt['scale']).to(self.device)
|
||||
self.netsG[name] = new_net
|
||||
elif net['type'] == 'discriminator':
|
||||
if new_net is None:
|
||||
|
|
|
@ -1,140 +1,79 @@
|
|||
import functools
|
||||
import importlib
|
||||
import logging
|
||||
import pkgutil
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
from inspect import isfunction, getmembers
|
||||
|
||||
import munch
|
||||
import torch
|
||||
import torchvision
|
||||
from munch import munchify
|
||||
import models.stylegan.stylegan2_lucidrains as stylegan2
|
||||
|
||||
import models.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch
|
||||
import models.RRDBNet_arch as RRDBNet_arch
|
||||
import models.SwitchedResidualGenerator_arch as SwitchedGen_arch
|
||||
import models.discriminator_vgg_arch as SRGAN_arch
|
||||
import models.feature_arch as feature_arch
|
||||
from models import srg2_classic
|
||||
import models.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch
|
||||
from models.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator
|
||||
from models.tecogan.teco_resgen import TecoGen
|
||||
from utils.util import opt_get
|
||||
|
||||
logger = logging.getLogger('base')
|
||||
|
||||
# Generator
|
||||
def define_G(opt, opt_net, scale=None):
|
||||
if scale is None:
|
||||
scale = opt['scale']
|
||||
which_model = opt_net['which_model_G']
|
||||
|
||||
if 'RRDBNet' in which_model:
|
||||
if which_model == 'RRDBNetBypass':
|
||||
block = RRDBNet_arch.RRDBWithBypass
|
||||
elif which_model == 'RRDBNetLambda':
|
||||
from models.lambda_rrdb import LambdaRRDB
|
||||
block = LambdaRRDB
|
||||
else:
|
||||
block = RRDBNet_arch.RRDB
|
||||
additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not'
|
||||
output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
|
||||
gc = opt_net['gc'] if 'gc' in opt_net.keys() else 32
|
||||
initial_stride = opt_net['initial_stride'] if 'initial_stride' in opt_net.keys() else 1
|
||||
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
||||
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode,
|
||||
output_mode=output_mode, body_block=block, scale=opt_net['scale'], growth_channels=gc,
|
||||
initial_stride=initial_stride)
|
||||
elif which_model == "ConfigurableSwitchedResidualGenerator2":
|
||||
netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
|
||||
switch_reductions=opt_net['switch_reductions'],
|
||||
switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
|
||||
trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
|
||||
transformation_filters=opt_net['transformation_filters'], attention_norm=opt_net['attention_norm'],
|
||||
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
|
||||
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
|
||||
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'],
|
||||
for_video=opt_net['for_video'])
|
||||
elif which_model == "srg2classic":
|
||||
netG = srg2_classic.ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
|
||||
switch_reductions=opt_net['switch_reductions'],
|
||||
switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
|
||||
trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
|
||||
transformation_filters=opt_net['transformation_filters'],
|
||||
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
|
||||
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
|
||||
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
|
||||
elif which_model == "flownet2":
|
||||
from models.flownet2 import FlowNet2
|
||||
ld = 'load_path' in opt_net.keys()
|
||||
args = munch.Munch({'fp16': False, 'rgb_max': 1.0, 'checkpoint': not ld})
|
||||
netG = FlowNet2(args)
|
||||
if ld:
|
||||
sd = torch.load(opt_net['load_path'])
|
||||
netG.load_state_dict(sd['state_dict'])
|
||||
elif which_model == "backbone_encoder":
|
||||
netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet'])
|
||||
elif which_model == "backbone_encoder_no_ref":
|
||||
netG = SwitchedGen_arch.BackboneEncoderNoRef(pretrained_backbone=opt_net['pretrained_spinenet'])
|
||||
elif which_model == "backbone_encoder_no_head":
|
||||
netG = SwitchedGen_arch.BackboneSpinenetNoHead()
|
||||
elif which_model == "backbone_resnet":
|
||||
netG = SwitchedGen_arch.BackboneResnet()
|
||||
elif which_model == "tecogen":
|
||||
netG = TecoGen(opt_net['nf'], opt_net['scale'])
|
||||
elif which_model == 'stylegan2':
|
||||
is_structured = opt_net['structured'] if 'structured' in opt_net.keys() else False
|
||||
attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
|
||||
netG = stylegan2.StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'],
|
||||
style_depth=opt_net['style_depth'], structure_input=is_structured,
|
||||
attn_layers=attn)
|
||||
elif which_model == 'srflow':
|
||||
from models.srflow import SRFlowNet_arch
|
||||
netG = SRFlowNet_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'],
|
||||
K=opt_net['K'], opt=opt)
|
||||
elif which_model == 'rrdb_latent_wrapper':
|
||||
from models.srflow.RRDBNet_arch import RRDBLatentWrapper
|
||||
netG = RRDBLatentWrapper(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||
nf=opt_net['nf'], nb=opt_net['nb'], with_bypass=opt_net['with_bypass'],
|
||||
blocks=opt_net['blocks_for_latent'], scale=opt_net['scale'], pretrain_rrdb_path=opt_net['pretrain_path'])
|
||||
elif which_model == 'rrdb_centipede':
|
||||
output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
|
||||
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
||||
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], scale=opt_net['scale'],
|
||||
headless=True, output_mode=output_mode)
|
||||
elif which_model == 'rrdb_srflow':
|
||||
from models.srflow.RRDBNet_arch import RRDBNet
|
||||
netG = RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||
nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'],
|
||||
initial_conv_stride=opt_net['initial_stride'])
|
||||
elif which_model == 'igpt2':
|
||||
from models.transformers.igpt.gpt2 import iGPT2
|
||||
netG = iGPT2(opt_net['embed_dim'], opt_net['num_heads'], opt_net['num_layers'], opt_net['num_pixels'] ** 2, opt_net['num_vocab'], centroids_file=opt_net['centroids_file'])
|
||||
elif which_model == 'byol':
|
||||
from models.byol.byol_model_wrapper import BYOL
|
||||
subnet = define_G(opt, opt_net['subnet'])
|
||||
netG = BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
|
||||
structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False),
|
||||
do_augmentation=opt_get(opt_net, ['gpu_augmentation'], False))
|
||||
elif which_model == 'structural_byol':
|
||||
from models.byol.byol_structural import StructuralBYOL
|
||||
subnet = define_G(opt, opt_net['subnet'])
|
||||
netG = StructuralBYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
|
||||
pretrained_state_dict=opt_get(opt_net, ["pretrained_path"]),
|
||||
freeze_until=opt_get(opt_net, ['freeze_until'], 0))
|
||||
elif which_model == 'spinenet':
|
||||
from models.spinenet_arch import SpineNet
|
||||
netG = SpineNet(str(opt_net['arch']), in_channels=3, use_input_norm=opt_net['use_input_norm'])
|
||||
elif which_model == 'spinenet_with_logits':
|
||||
from models.spinenet_arch import SpinenetWithLogits
|
||||
netG = SpinenetWithLogits(str(opt_net['arch']), opt_net['output_to_attach'], opt_net['num_labels'],
|
||||
in_channels=3, use_input_norm=opt_net['use_input_norm'])
|
||||
elif which_model == 'resnet52':
|
||||
from models.resnet_with_checkpointing import resnet50
|
||||
netG = resnet50(pretrained=opt_net['pretrained'])
|
||||
elif which_model == 'glean':
|
||||
from models.glean.glean import GleanGenerator
|
||||
netG = GleanGenerator(opt_net['nf'], opt_net['pretrained_stylegan'])
|
||||
class RegisteredModelNameError(Exception):
|
||||
def __init__(self, name_error):
|
||||
super().__init__(f'Registered DLAS modules must start with `register_`. Incorrect registration: {name_error}')
|
||||
|
||||
|
||||
# Decorator that allows API clients to show DLAS how to build a nn.Module from an opt dict.
|
||||
# Functions with this decorator should have a specific naming format:
|
||||
# `register_<name>` where <name> is the name that will be used in configuration files to reference this model.
|
||||
# Functions with this decorator are expected to take a single argument:
|
||||
# - opt: A dict with the configuration options for building the module.
|
||||
# They should return:
|
||||
# - A torch.nn.Module object for the model being defined.
|
||||
def register_model(func):
|
||||
if func.__name__.startswith("register_"):
|
||||
func._dlas_model_name = func.__name__[9:]
|
||||
assert func._dlas_model_name
|
||||
else:
|
||||
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
||||
return netG
|
||||
raise RegisteredModelNameError(func.__name__)
|
||||
func._dlas_registered_model = True
|
||||
return func
|
||||
|
||||
|
||||
def find_registered_model_fns(base_path='models'):
|
||||
found_fns = {}
|
||||
module_iter = pkgutil.walk_packages([base_path])
|
||||
for mod in module_iter:
|
||||
if mod.ispkg:
|
||||
EXCLUSION_LIST = ['flownet2']
|
||||
if mod.name not in EXCLUSION_LIST:
|
||||
found_fns.update(find_registered_model_fns(f'{base_path}/{mod.name}'))
|
||||
else:
|
||||
mod_name = f'{base_path}/{mod.name}'.replace('/', '.')
|
||||
importlib.import_module(mod_name)
|
||||
for mod_fn in getmembers(sys.modules[mod_name], isfunction):
|
||||
if hasattr(mod_fn[1], "_dlas_registered_model"):
|
||||
found_fns[mod_fn[1]._dlas_model_name] = mod_fn[1]
|
||||
return found_fns
|
||||
|
||||
|
||||
class CreateModelError(Exception):
|
||||
def __init__(self, name, available):
|
||||
super().__init__(f'Could not find the specified model name: {name}. Tip: If your model is in a'
|
||||
f' subdirectory, that directory must contain an __init__.py to be scanned. Available models:'
|
||||
f'{available}')
|
||||
|
||||
|
||||
def create_model(opt, opt_net, scale=None):
|
||||
which_model = opt_net['which_model']
|
||||
# For backwards compatibility.
|
||||
if not which_model:
|
||||
which_model = opt_net['which_model_G']
|
||||
if not which_model:
|
||||
which_model = opt_net['which_model_D']
|
||||
registered_fns = find_registered_model_fns()
|
||||
if which_model not in registered_fns.keys():
|
||||
raise CreateModelError(which_model, list(registered_fns.keys()))
|
||||
return registered_fns[which_model](opt_net, opt)
|
||||
|
||||
|
||||
class GradDiscWrapper(torch.nn.Module):
|
||||
|
|
|
@ -2,7 +2,7 @@ import argparse
|
|||
import functools
|
||||
import torch
|
||||
from utils import options as option
|
||||
from trainer.networks import define_G
|
||||
from trainer.networks import create_model
|
||||
|
||||
|
||||
class TracedModule:
|
||||
|
@ -96,7 +96,7 @@ if __name__ == "__main__":
|
|||
opt = option.parse(parser.parse_args().opt, is_train=False)
|
||||
opt = option.dict_to_nonedict(opt)
|
||||
|
||||
netG = define_G(opt)
|
||||
netG = create_model(opt)
|
||||
dummyInput = torch.rand(1,3,32,32)
|
||||
|
||||
mode = 'onnx'
|
||||
|
|
Loading…
Reference in New Issue
Block a user