diff --git a/.idea/vcs.xml b/.idea/vcs.xml
index fc78291c..8d59ed0d 100644
--- a/.idea/vcs.xml
+++ b/.idea/vcs.xml
@@ -3,6 +3,7 @@
+
\ No newline at end of file
diff --git a/codes/models/ProgressiveSrg_arch.py b/codes/models/ProgressiveSrg_arch.py
deleted file mode 100644
index 7e5c9dde..00000000
--- a/codes/models/ProgressiveSrg_arch.py
+++ /dev/null
@@ -1,275 +0,0 @@
-import models.SwitchedResidualGenerator_arch as srg
-import torch
-import torch.nn as nn
-from switched_conv.switched_conv_util import save_attention_to_image
-from switched_conv.switched_conv import compute_attention_specificity
-from models.arch_util import ConvGnLelu, ExpansionBlock, MultiConvBlock
-import functools
-import torch.nn.functional as F
-
-# Some notes about this new architecture:
-# 1) Discriminator is going to need to get update_for_step() called.
-# 2) Not sure if pixgan part of discriminator is going to work properly, make sure to test at multiple add levels.
-# 3) Also not sure if growth modules will be properly saved/trained, be sure to test this.
-# 4) start_step will need to get set properly when constructing these models, even when resuming - OR another method needs to be added to resume properly.
-
-class GrowingSRGBase(nn.Module):
- def __init__(self, progressive_step_schedule, switch_reductions, growth_fade_in_steps, switch_filters, switch_processing_layers, trans_counts,
- trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, upsample_factor=1,
- add_scalable_noise_to_transforms=False, start_step=0):
- super(GrowingSRGBase, self).__init__()
- switches = []
- self.initial_conv = ConvGnLelu(3, transformation_filters, norm=False, activation=False, bias=True)
- self.upconv1 = ConvGnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
- self.upconv2 = ConvGnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
- self.hr_conv = ConvGnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
- self.final_conv = ConvGnLelu(transformation_filters, 3, norm=False, activation=False, bias=True)
-
- self.switch_filters = switch_filters
- self.switch_processing_layers = switch_processing_layers
- self.trans_layers = trans_layers
- self.transformation_filters = transformation_filters
- self.progressive_schedule = progressive_step_schedule
- self.switch_reductions = switch_reductions # This lists the reductions for all switches (even ones not activated yet).
- self.growth_fade_in_per_step = 1 / growth_fade_in_steps
- self.transformation_counts = trans_counts
- self.init_temperature = initial_temp
- self.final_temperature_step = final_temperature_step
- self.attentions = None
- self.upsample_factor = upsample_factor
- self.add_noise_to_transform = add_scalable_noise_to_transforms
- self.start_step = start_step
- self.latest_step = start_step
- self.fades = []
- self.counter = 0
- assert self.upsample_factor == 2 or self.upsample_factor == 4
-
- switches = []
- for i, (step, reductions) in enumerate(zip(progressive_step_schedule, switch_reductions)):
- multiplx_fn = functools.partial(srg.ConvBasisMultiplexer, self.transformation_filters, self.switch_filters,
- reductions, self.switch_processing_layers, self.transformation_counts)
- pretransform_fn = functools.partial(ConvGnLelu, self.transformation_filters, self.transformation_filters, norm=False,
- bias=False, weight_init_factor=.1)
- transform_fn = functools.partial(srg.MultiConvBlock, self.transformation_filters, int(self.transformation_filters * 1.5),
- self.transformation_filters, kernel_size=3, depth=self.trans_layers,
- weight_init_factor=.1)
- switches.append(srg.ConfigurableSwitchComputer(self.transformation_filters, multiplx_fn,
- pre_transform_block=pretransform_fn,
- transform_block=transform_fn,
- transform_count=self.transformation_counts, init_temp=self.init_temperature,
- add_scalable_noise_to_transforms=self.add_noise_to_transform,
- attention_norm=False))
- self.progressive_switches = nn.ModuleList(switches)
-
- def get_param_groups(self):
- param_groups = []
- base_param_group = []
- for k, v in self.named_parameters():
- if "progressive_switches" not in k and v.requires_grad:
- base_param_group.append(v)
- param_groups.append({'params': base_param_group})
- for i, sw in enumerate(self.progressive_switches):
- sw_param_group = []
- for k, v in sw.named_parameters():
- if v.requires_grad:
- sw_param_group.append(v)
- param_groups.append({'params': sw_param_group})
- return param_groups
-
- # This is a hacky way of modifying the underlying model while training. Since changing the model means changing
- # the optimizer and the scheduler, these things are fed in. For ProgressiveSrg, this function adds an additional
- # switch to the end of the chain with depth=3 and an online time set at the end of the function.
- def update_model(self, opt, sched):
- multiplx_fn = functools.partial(srg.ConvBasisMultiplexer, self.transformation_filters, self.switch_filters,
- 3, self.switch_processing_layers, self.transformation_counts)
- pretransform_fn = functools.partial(ConvGnLelu, self.transformation_filters, self.transformation_filters, norm=False,
- bias=False, weight_init_factor=.1)
- transform_fn = functools.partial(srg.MultiConvBlock, self.transformation_filters, int(self.transformation_filters * 1.5),
- self.transformation_filters, kernel_size=3, depth=self.trans_layers,
- weight_init_factor=.1)
- new_sw = srg.ConfigurableSwitchComputer(self.transformation_filters, multiplx_fn,
- pre_transform_block=pretransform_fn,
- transform_block=transform_fn,
- transform_count=self.transformation_counts, init_temp=self.init_temperature,
- add_scalable_noise_to_transforms=self.add_noise_to_transform,
- attention_norm=False).to('cuda')
- self.progressive_switches.append(new_sw)
- new_sw_param_group = []
- for k, v in new_sw.named_parameters():
- if v.requires_grad:
- new_sw_param_group.append(v)
- opt.add_param_group({'params': new_sw_param_group})
- self.progressive_schedule.append(150000)
- sched.group_starts.append(150000)
-
- def get_progressive_starts(self):
- # The base param group starts at step 0, the rest are defined via progressive_switches.
- return [0] + self.progressive_schedule
-
- # This method turns requires_grad on and off for different switches, allowing very large models to be trained while
- # using less memory. When used in conjunction with gradient accumulation, it becomes a form of model parallelism.
- # `groups` controls the proportion of switches that are enabled: 1/groups of them will be enabled.
- # Switches that are younger than 40000 steps are not eligible to be turned off.
- def do_switched_grad(self, groups=1):
- # If requires_grad is already disabled, don't bother.
- if not self.initial_conv.conv.weight.requires_grad or groups == 1:
- return
- self.counter = (self.counter + 1) % groups
- enabled = []
- for i, sw in enumerate(self.progressive_switches):
- if self.latest_step - self.progressive_schedule[i] > 40000 and i % groups != self.counter:
- for p in sw.parameters():
- p.requires_grad = False
- else:
- enabled.append(i)
- for p in sw.parameters():
- p.requires_grad = True
-
- def forward(self, x):
- self.do_switched_grad(2)
-
- x = self.initial_conv(x)
-
- self.attentions = []
- self.fades = []
- self.enabled_switches = 0
- for i, sw in enumerate(self.progressive_switches):
- fade_in = 1 if self.progressive_schedule[i] == 0 else 0
- if self.latest_step > 0 and self.progressive_schedule[i] != 0:
- switch_age = self.latest_step - self.progressive_schedule[i]
- fade_in = min(1, switch_age * self.growth_fade_in_per_step)
-
- if fade_in > 0:
- self.enabled_switches += 1
- x, att = sw.forward(x, True, fixed_scale=fade_in)
- self.attentions.append(att)
- self.fades.append(fade_in)
-
- x = self.upconv1(F.interpolate(x, scale_factor=2, mode="nearest"))
- if self.upsample_factor > 2:
- x = F.interpolate(x, scale_factor=2, mode="nearest")
- x = self.upconv2(x)
- x = self.final_conv(self.hr_conv(x))
- return x, x
-
- def update_for_step(self, step, experiments_path='.'):
- self.latest_step = step + self.start_step
-
- # Set the temperature of the switches, per-layer.
- for i, (first_step, sw) in enumerate(zip(self.progressive_schedule, self.progressive_switches)):
- temp_loss_per_step = (self.init_temperature - 1) / self.final_temperature_step
- sw.set_temperature(min(self.init_temperature,
- max(self.init_temperature - temp_loss_per_step * (step - first_step), 1)))
-
- # Save attention images.
- if self.attentions is not None and step % 50 == 0:
- [save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts, step, "a%i" % (i+1,), l_mult=10) for i in range(len(self.attentions))]
-
- def get_debug_values(self, step):
- mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
- means = [i[0] for i in mean_hists]
- hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
- val = {}
- for i in range(len(means)):
- val["switch_%i_specificity" % (i,)] = means[i]
- val["switch_%i_histogram" % (i,)] = hists[i]
- val["switch_%i_temperature" % (i,)] = self.progressive_switches[i].switch.temperature
- for i, f in enumerate(self.fades):
- val["switch_%i_fade" % (i,)] = f
- val["enabled_switches"] = self.enabled_switches
- return val
-
-class DiscriminatorDownsample(nn.Module):
- def __init__(self, base_filters, end_filters):
- super(DiscriminatorDownsample, self).__init__()
- self.conv0 = ConvGnLelu(base_filters, end_filters, kernel_size=3, bias=False)
- self.conv1 = ConvGnLelu(end_filters, end_filters, kernel_size=3, stride=2, bias=False)
-
- def forward(self, x):
- return self.conv1(self.conv0(x))
-
-
-class DiscriminatorUpsample(nn.Module):
- def __init__(self, base_filters, end_filters):
- super(DiscriminatorUpsample, self).__init__()
- self.up = ExpansionBlock(base_filters, end_filters, block=ConvGnLelu)
- self.proc = ConvGnLelu(end_filters, end_filters, bias=False)
- self.collapse = ConvGnLelu(end_filters, 1, bias=True, norm=False, activation=False)
-
- def forward(self, x, ff):
- x = self.up(x, ff)
- return x, self.collapse(self.proc(x))
-
-
-class GrowingUnetDiscBase(nn.Module):
- def __init__(self, nf, growing_schedule, growth_fade_in_steps, start_step=0):
- super(GrowingUnetDiscBase, self).__init__()
- # [64, 128, 128]
- self.conv0_0 = ConvGnLelu(3, nf, kernel_size=3, bias=True, activation=False)
- self.conv0_1 = ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False)
- # [64, 64, 64]
- self.conv1_0 = ConvGnLelu(nf, nf * 2, kernel_size=3, bias=False)
- self.conv1_1 = ConvGnLelu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False)
-
- self.down_base = DiscriminatorDownsample(nf * 2, nf * 4)
- self.up_base = DiscriminatorUpsample(nf * 4, nf * 2)
-
- self.progressive_schedule = growing_schedule
- self.growth_fade_in_per_step = 1 / growth_fade_in_steps
- self.pnf = nf * 4
- self.downsamples = nn.ModuleList([])
- self.upsamples = nn.ModuleList([])
-
- for i, step in enumerate(growing_schedule):
- if step >= start_step:
- self.add_layer()
-
- def add_layer(self):
- self.downsamples.append(DiscriminatorDownsample(self.pnf, self.pnf))
- self.upsamples.append(DiscriminatorUpsample(self.pnf, self.pnf))
-
- def update_for_step(self, step):
- self.latest_step = step
-
- # Add any new layers as spelled out by the schedule.
- if step != 0:
- for i, s in enumerate(self.progressive_schedule):
- if s == step:
- self.add_layer()
-
- def forward(self, x, output_feature_vector=False):
- x = self.conv0_0(x)
- x = self.conv0_1(x)
- x = self.conv1_0(x)
- x = self.conv1_1(x)
- base_fea = self.down_base(x)
- x = base_fea
-
- skips = []
- for down in self.downsamples:
- x = down(x)
- skips.append(x)
-
- losses = []
- for i, up in enumerate(self.upsamples):
- j = i + 1
- x, loss = up(x, skips[-j])
- losses.append(loss)
-
- # This variant averages the outputs of the U-net across the upsamples, weighting the contribution
- # to the average less for newly growing levels.
- _, base_loss = self.up_base(x, base_fea)
- res = base_loss.shape[2:]
-
- mean_weight = 1
- for i, l in enumerate(losses):
- fade_in = 1
- if self.latest_step > 0 and self.progressive_schedule[i] != 0:
- disc_age = self.latest_step - self.progressive_schedule[i]
- fade_in = min(1, disc_age * self.growth_fade_in_per_step)
- mean_weight += fade_in
- base_loss += F.interpolate(l, size=res, mode="bilinear", align_corners=False) * fade_in
- base_loss /= mean_weight
-
- return base_loss.view(-1, 1)
-
- def pixgan_parameters(self):
- return 1, 4
\ No newline at end of file
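
For reference, the requires_grad round-robin that do_switched_grad implemented above can be reproduced in isolation. Below is a minimal, self-contained sketch of the same trick; the module, sizes, and names are invented for illustration and are not part of this repo. Each step, 1/groups of the blocks are trainable, cycling via a counter:

    import torch
    from torch import nn

    class RoundRobinGrad(nn.Module):
        # Each forward pass trains only 1/groups of the blocks, cycling with a counter.
        def __init__(self, num_blocks=4, width=16):
            super().__init__()
            self.blocks = nn.ModuleList(nn.Linear(width, width) for _ in range(num_blocks))
            self.counter = 0

        def forward(self, x, groups=2):
            self.counter = (self.counter + 1) % groups
            for i, blk in enumerate(self.blocks):
                enabled = (i % groups) == self.counter
                for p in blk.parameters():
                    p.requires_grad = enabled
                x = x + blk(x)
            return x

    model = RoundRobinGrad()
    loss = model(torch.randn(2, 16)).sum()
    loss.backward()
    print([b.weight.grad is not None for b in model.blocks])  # only enabled blocks got grads
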
diff --git a/codes/models/RRDBNet_arch.py b/codes/models/RRDBNet_arch.py
index 6a43f158..ba4d576f 100644
--- a/codes/models/RRDBNet_arch.py
+++ b/codes/models/RRDBNet_arch.py
@@ -6,6 +6,7 @@ import torch.nn.functional as F
import torchvision
from models.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu
+from trainer.networks import register_model
from utils.util import checkpoint, sequential_checkpoint
@@ -370,3 +371,27 @@ class RRDBDiscriminator(nn.Module):
if self.pred_ is not None:
self.pred_ = F.sigmoid(self.pred_)
torchvision.utils.save_image(self.pred_.cpu().float(), os.path.join(path, "%i_predictions.png" % (step,)))
+
+
+@register_model
+def register_RRDBNetBypass(opt_net, opt):
+ additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not'
+ output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
+ gc = opt_net['gc'] if 'gc' in opt_net.keys() else 32
+ initial_stride = opt_net['initial_stride'] if 'initial_stride' in opt_net.keys() else 1
+ return RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
+ mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode,
+ output_mode=output_mode, body_block=RRDBWithBypass, scale=opt_net['scale'], growth_channels=gc,
+ initial_stride=initial_stride)
+
+
+@register_model
+def register_RRDBNet(opt_net, opt):
+ additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not'
+ output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
+ gc = opt_net['gc'] if 'gc' in opt_net.keys() else 32
+ initial_stride = opt_net['initial_stride'] if 'initial_stride' in opt_net.keys() else 1
+ return RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
+ mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode,
+ output_mode=output_mode, body_block=RRDB, scale=opt_net['scale'], growth_channels=gc,
+ initial_stride=initial_stride)
\ No newline at end of file
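
The register functions above pull their hyperparameters out of an opt_net dict, with hand-rolled defaults for the optional keys. A hypothetical config of the shape they consume; key names mirror the lookups above, values are purely illustrative:

    # Hypothetical opt_net for the RRDBNetBypass registration above.
    opt_net = {
        'which_model': 'RRDBNetBypass',  # selects register_RRDBNetBypass via create_model
        'in_nc': 3, 'out_nc': 3,         # input/output channel counts
        'nf': 64,                        # mid_channels
        'nb': 23,                        # num_blocks
        'scale': 4,
        # Optional keys and their defaults when omitted:
        # 'additive_mode': 'not', 'output_mode': 'hq_only', 'gc': 32, 'initial_stride': 1
    }
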
diff --git a/codes/models/SwitchedResidualGenerator_arch.py b/codes/models/SwitchedResidualGenerator_arch.py
index feb20733..433f3621 100644
--- a/codes/models/SwitchedResidualGenerator_arch.py
+++ b/codes/models/SwitchedResidualGenerator_arch.py
@@ -1,17 +1,19 @@
-import torch
-from torch import nn
-from models.switched_conv.switched_conv import BareConvSwitch, compute_attention_specificity, AttentionNorm
-import torch.nn.functional as F
import functools
-from collections import OrderedDict
-from models.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu, MultiConvBlock, \
- SiLU, UpconvBlock, ReferenceJoinBlock
-from models.switched_conv.switched_conv_util import save_attention_to_image_rgb
import os
-from models.spinenet_arch import SpineNet
+from collections import OrderedDict
+
+import torch
+import torch.nn.functional as F
import torchvision
+from torch import nn
+
+from models.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, MultiConvBlock
+from models.switched_conv.switched_conv import BareConvSwitch, compute_attention_specificity, AttentionNorm
+from models.switched_conv.switched_conv_util import save_attention_to_image_rgb
+from trainer.networks import register_model
from utils.util import checkpoint
+
# VGG-style layer with Conv(stride2)->BN->Activation->Conv->BN->Activation
# Doubles the input filter count.
class HalvingProcessingBlock(nn.Module):
@@ -261,142 +263,6 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
return val
-# This class encapsulates an encoder based on an object detection network backbone whose purpose is to generate a
-# structured embedding encoding what is in an image patch. This embedding can then be used to perform structured
-# alterations to the underlying image.
-#
-# Caveat: Since this uses a pre-defined (and potentially pre-trained) SpineNet backbone, it has a minimum-supported
-# image size, which is 128x128. In order to use 64x64 patches, you must set interpolate_first=True, though this will
-# degrade quality.
-class BackboneEncoder(nn.Module):
- def __init__(self, interpolate_first=True, pretrained_backbone=None):
- super(BackboneEncoder, self).__init__()
- self.interpolate_first = interpolate_first
-
- # Uses dual spinenets, one for the input patch and the other for the reference image.
- self.patch_spine = SpineNet('49', in_channels=3, use_input_norm=True)
- self.ref_spine = SpineNet('49', in_channels=3, use_input_norm=True)
-
- self.merge_process1 = ConvGnSilu(512, 512, kernel_size=1, activation=True, norm=False, bias=True)
- self.merge_process2 = ConvGnSilu(512, 384, kernel_size=1, activation=True, norm=True, bias=False)
- self.merge_process3 = ConvGnSilu(384, 256, kernel_size=1, activation=False, norm=False, bias=True)
-
- if pretrained_backbone is not None:
- loaded_params = torch.load(pretrained_backbone)
- self.ref_spine.load_state_dict(loaded_params['state_dict'], strict=True)
- self.patch_spine.load_state_dict(loaded_params['state_dict'], strict=True)
-
- # Returned embedding will have been reduced in size by a factor of 8 (4 if interpolate_first=True).
- # Output channels are always 256.
- # ex, 64x64 input with interpolate_first=True will result in tensor of shape [bx256x16x16]
- def forward(self, x, ref, ref_center_point):
- if self.interpolate_first:
- x = F.interpolate(x, scale_factor=2, mode="bicubic")
- # Don't interpolate ref - assume it is fed in at the proper resolution.
- # ref = F.interpolate(ref, scale_factor=2, mode="bicubic")
-
- # [ref] will have a 'mask' channel which we cannot use with pretrained spinenet.
- ref = ref[:, :3, :, :]
- ref_emb = self.ref_spine(ref)[0]
- ref_code = gather_2d(ref_emb, ref_center_point // 8) # Divide by 8 to bring the center point to the correct location.
-
- patch = self.patch_spine(x)[0]
- ref_code_expanded = ref_code.view(-1, 256, 1, 1).repeat(1, 1, patch.shape[2], patch.shape[3])
- combined = self.merge_process1(torch.cat([patch, ref_code_expanded], dim=1))
- combined = self.merge_process2(combined)
- combined = self.merge_process3(combined)
-
- return combined
-
-
-class BackboneEncoderNoRef(nn.Module):
- def __init__(self, interpolate_first=True, pretrained_backbone=None):
- super(BackboneEncoderNoRef, self).__init__()
- self.interpolate_first = interpolate_first
-
- self.patch_spine = SpineNet('49', in_channels=3, use_input_norm=True)
-
- if pretrained_backbone is not None:
- loaded_params = torch.load(pretrained_backbone)
- self.patch_spine.load_state_dict(loaded_params['state_dict'], strict=True)
-
- # Returned embedding will have been reduced in size by a factor of 8 (4 if interpolate_first=True).
- # Output channels are always 256.
- # ex, 64x64 input with interpolate_first=True will result in tensor of shape [bx256x16x16]
- def forward(self, x):
- if self.interpolate_first:
- x = F.interpolate(x, scale_factor=2, mode="bicubic")
-
- patch = self.patch_spine(x)[0]
- return patch
-
-
-class BackboneSpinenetNoHead(nn.Module):
- def __init__(self):
- super(BackboneSpinenetNoHead, self).__init__()
- # Uses dual spinenets, one for the input patch and the other for the reference image.
- self.patch_spine = SpineNet('49', in_channels=3, use_input_norm=False, double_reduce_early=False)
- self.ref_spine = SpineNet('49', in_channels=4, use_input_norm=False, double_reduce_early=False)
-
- self.merge_process1 = ConvGnSilu(512, 512, kernel_size=1, activation=True, norm=False, bias=True)
- self.merge_process2 = ConvGnSilu(512, 384, kernel_size=1, activation=True, norm=True, bias=False)
- self.merge_process3 = ConvGnSilu(384, 256, kernel_size=1, activation=False, norm=False, bias=True)
-
- def forward(self, x, ref, ref_center_point):
- ref_emb = self.ref_spine(ref)[0]
- ref_code = gather_2d(ref_emb, ref_center_point // 4) # Divide by 4 to bring the center point to the correct location.
-
- patch = self.patch_spine(x)[0]
- ref_code_expanded = ref_code.view(-1, 256, 1, 1).repeat(1, 1, patch.shape[2], patch.shape[3])
- combined = self.merge_process1(torch.cat([patch, ref_code_expanded], dim=1))
- combined = self.merge_process2(combined)
- combined = self.merge_process3(combined)
- return combined
-
-
-class ResBlock(nn.Module):
- def __init__(self, nf, downsample):
- super(ResBlock, self).__init__()
- nf_int = nf * 2
- nf_out = nf * 2 if downsample else nf
- stride = 2 if downsample else 1
- self.c1 = ConvGnSilu(nf, nf_int, kernel_size=3, bias=False, activation=True, norm=True)
- self.c2 = ConvGnSilu(nf_int, nf_int, stride=stride, kernel_size=3, bias=False, activation=True, norm=True)
- self.c3 = ConvGnSilu(nf_int, nf_out, kernel_size=3, bias=False, activation=False, norm=True)
- if downsample:
- self.downsample = ConvGnSilu(nf, nf_out, kernel_size=1, stride=stride, bias=False, activation=False, norm=True)
- else:
- self.downsample = None
- self.act = SiLU()
-
- def forward(self, x):
- identity = x
- branch = self.c1(x)
- branch = self.c2(branch)
- branch = self.c3(branch)
-
- if self.downsample:
- identity = self.downsample(identity)
- return self.act(identity + branch)
-
-
-class BackboneResnet(nn.Module):
- def __init__(self):
- super(BackboneResnet, self).__init__()
- self.initial_conv = ConvGnSilu(3, 64, kernel_size=7, bias=True, activation=False, norm=False)
- self.sequence = nn.Sequential(
- ResBlock(64, downsample=False),
- ResBlock(64, downsample=True),
- ResBlock(128, downsample=False),
- ResBlock(128, downsample=True),
- ResBlock(256, downsample=False),
- ResBlock(256, downsample=False))
-
- def forward(self, x):
- fea = self.initial_conv(x)
- return self.sequence(fea)
-
-
# Computes a linear latent by performing processing on the reference image and returning the filters of a single point,
# which should be centered on the image patch being processed.
#
@@ -705,4 +571,24 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
for i in range(len(means)):
val["switch_%i_specificity" % (i,)] = means[i]
val["switch_%i_histogram" % (i,)] = hists[i]
- return val
\ No newline at end of file
+ return val
+
+@register_model
+def register_ConfigurableSwitchedResidualGenerator2(opt_net, opt):
+ return ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'],
+ switch_filters=opt_net['switch_filters'],
+ switch_reductions=opt_net['switch_reductions'],
+ switch_processing_layers=opt_net[
+ 'switch_processing_layers'],
+ trans_counts=opt_net['trans_counts'],
+ trans_kernel_sizes=opt_net['trans_kernel_sizes'],
+ trans_layers=opt_net['trans_layers'],
+ transformation_filters=opt_net['transformation_filters'],
+ attention_norm=opt_net['attention_norm'],
+ initial_temp=opt_net['temperature'],
+ final_temperature_step=opt_net['temperature_final_step'],
+ heightened_temp_min=opt_net['heightened_temp_min'],
+ heightened_final_step=opt_net['heightened_final_step'],
+ upsample_factor=opt['scale'],
+ add_scalable_noise_to_transforms=opt_net['add_noise'],
+ for_video=opt_net['for_video'])
\ No newline at end of file
diff --git a/codes/models/byol/byol_model_wrapper.py b/codes/models/byol/byol_model_wrapper.py
index 398db439..cb2c7b77 100644
--- a/codes/models/byol/byol_model_wrapper.py
+++ b/codes/models/byol/byol_model_wrapper.py
@@ -10,7 +10,8 @@ from kornia import filters
from torch import nn
from data.byol_attachment import RandomApply
-from utils.util import checkpoint
+from trainer.networks import register_model, create_model
+from utils.util import checkpoint, opt_get
def default(val, def_val):
@@ -269,3 +270,11 @@ class BYOL(nn.Module):
loss = loss_one + loss_two
return loss.mean()
+
+
+@register_model
+def register_byol(opt_net, opt):
+ subnet = create_model(opt, opt_net['subnet'])
+ return BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
+ structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False),
+ do_augmentation=opt_get(opt_net, ['gpu_augmentation'], False))
\ No newline at end of file
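
register_byol composes models recursively: create_model is invoked again on the 'subnet' sub-dict, so any registered model can serve as the BYOL backbone. A hypothetical nested config sketch (values illustrative):

    opt_net = {
        'which_model': 'byol',
        'image_size': 256,
        'hidden_layer': 'avgpool',    # backbone layer tapped for the projection MLP
        'use_structural_mlp': False,  # read via opt_get, defaults to False
        'subnet': {
            'which_model': 'spinenet',  # any other registered model works here
            'arch': 49,
            'use_input_norm': True,
        },
    }
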
diff --git a/codes/models/byol/byol_structural.py b/codes/models/byol/byol_structural.py
index eb56a5d7..3aeb2602 100644
--- a/codes/models/byol/byol_structural.py
+++ b/codes/models/byol/byol_structural.py
@@ -7,6 +7,7 @@ from torch import nn
from data.byol_attachment import reconstructed_shared_regions
from models.byol.byol_model_wrapper import singleton, EMA, get_module_device, set_requires_grad, \
update_moving_average
+from trainer.networks import create_model, register_model
-from utils.util import checkpoint
+from utils.util import checkpoint, opt_get
# loss function
@@ -178,4 +179,11 @@ class StructuralBYOL(nn.Module):
def get_projection(self, image):
enc = self.online_encoder(image)
proj = self.online_predictor(enc)
- return enc, proj
\ No newline at end of file
+ return enc, proj
+
+@register_model
+def register_structural_byol(opt_net, opt):
+ subnet = create_model(opt, opt_net['subnet'])
+ return StructuralBYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
+ pretrained_state_dict=opt_get(opt_net, ["pretrained_path"]),
+ freeze_until=opt_get(opt_net, ['freeze_until'], 0))
diff --git a/codes/models/discriminator_vgg_arch.py b/codes/models/discriminator_vgg_arch.py
index 3917ddc7..1f98bd0e 100644
--- a/codes/models/discriminator_vgg_arch.py
+++ b/codes/models/discriminator_vgg_arch.py
@@ -1,10 +1,8 @@
import torch
import torch.nn as nn
-from models.RRDBNet_arch import RRDB, RRDBWithBypass
from models.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvGnSilu, ResidualBlockGN
import torch.nn.functional as F
-from models.SwitchedResidualGenerator_arch import gather_2d
from utils.util import checkpoint
@@ -519,6 +517,7 @@ class RefDiscriminatorVgg128(nn.Module):
def forward(self, x, ref, ref_center_point):
ref = self.ref_head(ref)
ref_center_point = ref_center_point // 16
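+ # Deferred import: a module-level import of SwitchedResidualGenerator_arch would create an import cycle through trainer.networks.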
+ from models.SwitchedResidualGenerator_arch import gather_2d
ref_vector = gather_2d(ref, ref_center_point)
ref_vector = self.ref_linear(ref_vector)
diff --git a/codes/models/glean/__init__.py b/codes/models/glean/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/codes/models/glean/glean.py b/codes/models/glean/glean.py
index d314bc8e..389082af 100644
--- a/codes/models/glean/glean.py
+++ b/codes/models/glean/glean.py
@@ -10,6 +10,7 @@ from models.arch_util import ConvGnLelu
# Produces a convolutional feature (`f`) and a reduced feature map with double the filters.
from models.glean.stylegan2_latent_bank import Stylegan2LatentBank
from models.stylegan.stylegan2_rosinality import EqualLinear
+from trainer.networks import register_model
from utils.util import checkpoint, sequential_checkpoint
@@ -108,3 +109,8 @@ class GleanGenerator(nn.Module):
rrdb_fea, conv_fea, latents = self.encoder(x)
latent_bank_fea = self.latent_bank(conv_fea, latents)
return self.decoder(rrdb_fea, latent_bank_fea)
+
+
+@register_model
+def register_glean(opt_net, opt):
+ return GleanGenerator(opt_net['nf'], opt_net['pretrained_stylegan'])
diff --git a/codes/models/resnet_with_checkpointing.py b/codes/models/resnet_with_checkpointing.py
index 39a5523f..4ab8f3c2 100644
--- a/codes/models/resnet_with_checkpointing.py
+++ b/codes/models/resnet_with_checkpointing.py
@@ -11,6 +11,7 @@ __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
'wide_resnet50_2', 'wide_resnet101_2']
+from trainer.networks import register_model
from utils.util import checkpoint
model_urls = {
@@ -188,3 +189,8 @@ def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
kwargs['width_per_group'] = 64 * 2
return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
pretrained, progress, **kwargs)
+
+
+@register_model
+def register_resnet52(opt_net, opt):
+ return resnet50(pretrained=opt_net['pretrained'])
diff --git a/codes/models/spinenet_arch.py b/codes/models/spinenet_arch.py
index 8bd723ef..8c887fb3 100644
--- a/codes/models/spinenet_arch.py
+++ b/codes/models/spinenet_arch.py
@@ -7,6 +7,7 @@ from torch.nn.init import kaiming_normal
from torchvision.models.resnet import BasicBlock, Bottleneck
from models.arch_util import ConvGnSilu, ConvBnSilu, ConvBnRelu
+from trainer.networks import register_model
def constant_init(module, val, bias=0):
@@ -359,3 +360,13 @@ class SpinenetWithLogits(SpineNet):
def forward(self, x):
fea = super().forward(x)[self.output_to_attach]
return self.tail(fea)
+
+@register_model
+def register_spinenet(opt_net, opt):
+ return SpineNet(str(opt_net['arch']), in_channels=3, use_input_norm=opt_net['use_input_norm'])
+
+
+@register_model
+def register_spinenet_with_logits(opt_net, opt):
+ return SpinenetWithLogits(str(opt_net['arch']), opt_net['output_to_attach'], opt_net['num_labels'],
+ in_channels=3, use_input_norm=opt_net['use_input_norm'])
diff --git a/codes/models/srflow/RRDBNet_arch.py b/codes/models/srflow/RRDBNet_arch.py
index 36eee072..e566a0f2 100644
--- a/codes/models/srflow/RRDBNet_arch.py
+++ b/codes/models/srflow/RRDBNet_arch.py
@@ -4,6 +4,7 @@ import torch.nn as nn
import torch.nn.functional as F
import models.srflow.module_util as mutil
from models.arch_util import default_init_weights, ConvGnSilu, ConvGnLelu
+from trainer.networks import register_model
from utils.util import opt_get
@@ -239,4 +240,19 @@ class RRDBLatentWrapper(nn.Module):
blocklist.append(rrdbResults['last_lr_fea'])
fea = torch.cat(blocklist, dim=1)
fea = self.postprocess(fea)
- return fea
\ No newline at end of file
+ return fea
+
+
+@register_model
+def register_rrdb_latent_wrapper(opt_net, opt):
+ return RRDBLatentWrapper(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
+ nf=opt_net['nf'], nb=opt_net['nb'], with_bypass=opt_net['with_bypass'],
+ blocks=opt_net['blocks_for_latent'], scale=opt_net['scale'],
+ pretrain_rrdb_path=opt_net['pretrain_path'])
+
+
+@register_model
+def register_rrdb_srflow(opt_net, opt):
+ return RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
+ nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'],
+ initial_conv_stride=opt_net['initial_stride'])
\ No newline at end of file
diff --git a/codes/models/srflow/SRFlowNet_arch.py b/codes/models/srflow/SRFlowNet_arch.py
index 73314407..3002dbdc 100644
--- a/codes/models/srflow/SRFlowNet_arch.py
+++ b/codes/models/srflow/SRFlowNet_arch.py
@@ -8,6 +8,7 @@ from models.srflow.RRDBNet_arch import RRDBNet
from models.srflow.FlowUpsamplerNet import FlowUpsamplerNet
import models.srflow.thops as thops
import models.srflow.flow as flow
+from trainer.networks import register_model
from utils.util import opt_get
@@ -166,3 +167,9 @@ class SRFlowNet(nn.Module):
logdet=logdet)
return x, logdet, lr_enc['out']
+
+
+@register_model
+def register_srflow(opt_net, opt):
+ return SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'],
+ K=opt_net['K'], opt=opt)
diff --git a/codes/models/srg2_classic.py b/codes/models/srg2_classic.py
index bece2781..b0245023 100644
--- a/codes/models/srg2_classic.py
+++ b/codes/models/srg2_classic.py
@@ -11,6 +11,7 @@ from collections import OrderedDict
from models.SwitchedResidualGenerator_arch import HalvingProcessingBlock, ConfigurableSwitchComputer
from models.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, MultiConvBlock
from models.switched_conv.switched_conv import BareConvSwitch, AttentionNorm
+from trainer.networks import register_model
from utils.util import checkpoint
@@ -204,3 +205,19 @@ class Interpolate(nn.Module):
def forward(self, x):
return F.interpolate(x, scale_factor=self.factor, mode=self.mode)
+@register_model
+def register_srg2classic(opt_net, opt):
+ return ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'],
+ switch_filters=opt_net['switch_filters'],
+ switch_reductions=opt_net['switch_reductions'],
+ switch_processing_layers=opt_net['switch_processing_layers'],
+ trans_counts=opt_net['trans_counts'],
+ trans_kernel_sizes=opt_net['trans_kernel_sizes'],
+ trans_layers=opt_net['trans_layers'],
+ transformation_filters=opt_net['transformation_filters'],
+ initial_temp=opt_net['temperature'],
+ final_temperature_step=opt_net['temperature_final_step'],
+ heightened_temp_min=opt_net['heightened_temp_min'],
+ heightened_final_step=opt_net['heightened_final_step'],
+ upsample_factor=opt['scale'],
+ add_scalable_noise_to_transforms=opt_net['add_noise'])
\ No newline at end of file
diff --git a/codes/models/stylegan/stylegan2_lucidrains.py b/codes/models/stylegan/stylegan2_lucidrains.py
index 9b08854a..d58221e8 100644
--- a/codes/models/stylegan/stylegan2_lucidrains.py
+++ b/codes/models/stylegan/stylegan2_lucidrains.py
@@ -17,6 +17,7 @@ from torch import nn
from torch.autograd import grad as torch_grad
from vector_quantize_pytorch import VectorQuantize
+from trainer.networks import register_model
from utils.util import checkpoint
try:
@@ -893,3 +894,12 @@ class StyleGan2PathLengthLoss(L.ConfigurableLoss):
self.pl_mean = self.pl_length_ma.update_average(self.pl_mean, avg_pl_length)
return 0
+
+
+@register_model
+def register_stylegan2_lucidrains(opt_net, opt):
+ is_structured = opt_net['structured'] if 'structured' in opt_net.keys() else False
+ attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
+ return StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'],
+ style_depth=opt_net['style_depth'], structure_input=is_structured,
+ attn_layers=attn)
diff --git a/codes/models/tecogan/flownet2.py b/codes/models/tecogan/flownet2.py
new file mode 100644
index 00000000..c5c108b3
--- /dev/null
+++ b/codes/models/tecogan/flownet2.py
@@ -0,0 +1,15 @@
+import munch
+import torch
+
+from trainer.networks import register_model
+
+
+@register_model
+def register_flownet2(opt_net, opt):
+ from models.flownet2.models import FlowNet2
+ ld = 'load_path' in opt_net.keys()
+ args = munch.Munch({'fp16': False, 'rgb_max': 1.0, 'checkpoint': not ld})
+ netG = FlowNet2(args)
+ if ld:
+ sd = torch.load(opt_net['load_path'])
+ netG.load_state_dict(sd['state_dict'])
+ return netG
\ No newline at end of file
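
The Munch object above stands in for the argparse.Namespace that FlowNet2 expects; Munch is simply a dict with attribute access. A quick standalone illustration:

    import munch
    args = munch.Munch({'fp16': False, 'rgb_max': 1.0})
    print(args.fp16, args['rgb_max'])  # attribute and key access are interchangeable
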
diff --git a/codes/models/tecogan/teco_resgen.py b/codes/models/tecogan/teco_resgen.py
index 9803df33..2f640148 100644
--- a/codes/models/tecogan/teco_resgen.py
+++ b/codes/models/tecogan/teco_resgen.py
@@ -4,6 +4,7 @@ import torch
import torch.nn as nn
import torchvision
+from trainer.networks import register_model
from utils.util import sequential_checkpoint
from models.arch_util import ConvGnSilu, make_layer
@@ -71,3 +72,8 @@ class TecoGen(nn.Module):
def get_debug_values(self, step, net_name):
return {'branch_std': self.join.std()}
+
+
+@register_model
+def register_tecogen(opt_net, opt):
+ return TecoGen(opt_net['nf'], opt_net['scale'])
\ No newline at end of file
diff --git a/codes/models/transformers/igpt/gpt2.py b/codes/models/transformers/igpt/gpt2.py
index 552185e3..21a7e4fd 100644
--- a/codes/models/transformers/igpt/gpt2.py
+++ b/codes/models/transformers/igpt/gpt2.py
@@ -3,6 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from trainer.injectors import Injector
+from trainer.networks import register_model
from utils.util import checkpoint
@@ -147,3 +148,8 @@ class iGPT2(nn.Module):
return logits, x
+
+@register_model
+def register_igpt2(opt_net, opt):
+ return iGPT2(opt_net['embed_dim'], opt_net['num_heads'], opt_net['num_layers'], opt_net['num_pixels'] ** 2,
+ opt_net['num_vocab'], centroids_file=opt_net['centroids_file'])
diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index ffec11b5..8298091f 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -57,7 +57,7 @@ class ExtensibleTrainer(BaseModel):
new_net = None
if net['type'] == 'generator':
if new_net is None:
- new_net = networks.define_G(opt, net, opt['scale']).to(self.device)
+ new_net = networks.create_model(opt, net, opt['scale']).to(self.device)
self.netsG[name] = new_net
elif net['type'] == 'discriminator':
if new_net is None:
diff --git a/codes/trainer/networks.py b/codes/trainer/networks.py
index 7c67eb67..422d6f91 100644
--- a/codes/trainer/networks.py
+++ b/codes/trainer/networks.py
@@ -1,140 +1,79 @@
import functools
+import importlib
import logging
+import pkgutil
+import sys
from collections import OrderedDict
+from inspect import isfunction, getmembers
-import munch
import torch
import torchvision
-from munch import munchify
-import models.stylegan.stylegan2_lucidrains as stylegan2
-import models.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch
-import models.RRDBNet_arch as RRDBNet_arch
-import models.SwitchedResidualGenerator_arch as SwitchedGen_arch
import models.discriminator_vgg_arch as SRGAN_arch
import models.feature_arch as feature_arch
-from models import srg2_classic
+import models.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch
from models.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator
-from models.tecogan.teco_resgen import TecoGen
-from utils.util import opt_get
logger = logging.getLogger('base')
-# Generator
-def define_G(opt, opt_net, scale=None):
- if scale is None:
- scale = opt['scale']
- which_model = opt_net['which_model_G']
- if 'RRDBNet' in which_model:
- if which_model == 'RRDBNetBypass':
- block = RRDBNet_arch.RRDBWithBypass
- elif which_model == 'RRDBNetLambda':
- from models.lambda_rrdb import LambdaRRDB
- block = LambdaRRDB
- else:
- block = RRDBNet_arch.RRDB
- additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not'
- output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
- gc = opt_net['gc'] if 'gc' in opt_net.keys() else 32
- initial_stride = opt_net['initial_stride'] if 'initial_stride' in opt_net.keys() else 1
- netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
- mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode,
- output_mode=output_mode, body_block=block, scale=opt_net['scale'], growth_channels=gc,
- initial_stride=initial_stride)
- elif which_model == "ConfigurableSwitchedResidualGenerator2":
- netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
- switch_reductions=opt_net['switch_reductions'],
- switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
- trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
- transformation_filters=opt_net['transformation_filters'], attention_norm=opt_net['attention_norm'],
- initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
- heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
- upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'],
- for_video=opt_net['for_video'])
- elif which_model == "srg2classic":
- netG = srg2_classic.ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
- switch_reductions=opt_net['switch_reductions'],
- switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
- trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
- transformation_filters=opt_net['transformation_filters'],
- initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
- heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
- upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
- elif which_model == "flownet2":
- from models.flownet2 import FlowNet2
- ld = 'load_path' in opt_net.keys()
- args = munch.Munch({'fp16': False, 'rgb_max': 1.0, 'checkpoint': not ld})
- netG = FlowNet2(args)
- if ld:
- sd = torch.load(opt_net['load_path'])
- netG.load_state_dict(sd['state_dict'])
- elif which_model == "backbone_encoder":
- netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet'])
- elif which_model == "backbone_encoder_no_ref":
- netG = SwitchedGen_arch.BackboneEncoderNoRef(pretrained_backbone=opt_net['pretrained_spinenet'])
- elif which_model == "backbone_encoder_no_head":
- netG = SwitchedGen_arch.BackboneSpinenetNoHead()
- elif which_model == "backbone_resnet":
- netG = SwitchedGen_arch.BackboneResnet()
- elif which_model == "tecogen":
- netG = TecoGen(opt_net['nf'], opt_net['scale'])
- elif which_model == 'stylegan2':
- is_structured = opt_net['structured'] if 'structured' in opt_net.keys() else False
- attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
- netG = stylegan2.StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'],
- style_depth=opt_net['style_depth'], structure_input=is_structured,
- attn_layers=attn)
- elif which_model == 'srflow':
- from models.srflow import SRFlowNet_arch
- netG = SRFlowNet_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'],
- K=opt_net['K'], opt=opt)
- elif which_model == 'rrdb_latent_wrapper':
- from models.srflow.RRDBNet_arch import RRDBLatentWrapper
- netG = RRDBLatentWrapper(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
- nf=opt_net['nf'], nb=opt_net['nb'], with_bypass=opt_net['with_bypass'],
- blocks=opt_net['blocks_for_latent'], scale=opt_net['scale'], pretrain_rrdb_path=opt_net['pretrain_path'])
- elif which_model == 'rrdb_centipede':
- output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
- netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
- mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], scale=opt_net['scale'],
- headless=True, output_mode=output_mode)
- elif which_model == 'rrdb_srflow':
- from models.srflow.RRDBNet_arch import RRDBNet
- netG = RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
- nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'],
- initial_conv_stride=opt_net['initial_stride'])
- elif which_model == 'igpt2':
- from models.transformers.igpt.gpt2 import iGPT2
- netG = iGPT2(opt_net['embed_dim'], opt_net['num_heads'], opt_net['num_layers'], opt_net['num_pixels'] ** 2, opt_net['num_vocab'], centroids_file=opt_net['centroids_file'])
- elif which_model == 'byol':
- from models.byol.byol_model_wrapper import BYOL
- subnet = define_G(opt, opt_net['subnet'])
- netG = BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
- structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False),
- do_augmentation=opt_get(opt_net, ['gpu_augmentation'], False))
- elif which_model == 'structural_byol':
- from models.byol.byol_structural import StructuralBYOL
- subnet = define_G(opt, opt_net['subnet'])
- netG = StructuralBYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
- pretrained_state_dict=opt_get(opt_net, ["pretrained_path"]),
- freeze_until=opt_get(opt_net, ['freeze_until'], 0))
- elif which_model == 'spinenet':
- from models.spinenet_arch import SpineNet
- netG = SpineNet(str(opt_net['arch']), in_channels=3, use_input_norm=opt_net['use_input_norm'])
- elif which_model == 'spinenet_with_logits':
- from models.spinenet_arch import SpinenetWithLogits
- netG = SpinenetWithLogits(str(opt_net['arch']), opt_net['output_to_attach'], opt_net['num_labels'],
- in_channels=3, use_input_norm=opt_net['use_input_norm'])
- elif which_model == 'resnet52':
- from models.resnet_with_checkpointing import resnet50
- netG = resnet50(pretrained=opt_net['pretrained'])
- elif which_model == 'glean':
- from models.glean.glean import GleanGenerator
- netG = GleanGenerator(opt_net['nf'], opt_net['pretrained_stylegan'])
+class RegisteredModelNameError(Exception):
+ def __init__(self, name_error):
+ super().__init__(f'Registered DLAS modules must start with `register_`. Incorrect registration: {name_error}')
+
+
+# Decorator that allows API clients to show DLAS how to build a nn.Module from an opt dict.
+# Functions with this decorator should have a specific naming format:
+# `register_<name>`, where <name> is the name that will be used in configuration files to reference this model.
+# Functions with this decorator are expected to take two arguments:
+# - opt_net: A dict with the configuration options for building the module.
+# - opt: The top-level configuration dict, for any global options the builder needs.
+# They should return:
+# - A torch.nn.Module object for the model being defined.
+def register_model(func):
+ if func.__name__.startswith("register_"):
+ func._dlas_model_name = func.__name__[9:]
+ assert func._dlas_model_name
else:
- raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
- return netG
+ raise RegisteredModelNameError(func.__name__)
+ func._dlas_registered_model = True
+ return func
+
+
+def find_registered_model_fns(base_path='models'):
+ found_fns = {}
+ module_iter = pkgutil.walk_packages([base_path])
+ EXCLUSION_LIST = ['flownet2']
+ for mod in module_iter:
+ if mod.ispkg:
+ if mod.name not in EXCLUSION_LIST:
+ found_fns.update(find_registered_model_fns(f'{base_path}/{mod.name}'))
+ else:
+ mod_name = f'{base_path}/{mod.name}'.replace('/', '.')
+ importlib.import_module(mod_name)
+ for mod_fn in getmembers(sys.modules[mod_name], isfunction):
+ if hasattr(mod_fn[1], "_dlas_registered_model"):
+ found_fns[mod_fn[1]._dlas_model_name] = mod_fn[1]
+ return found_fns
+
+
+class CreateModelError(Exception):
+ def __init__(self, name, available):
+ super().__init__(f'Could not find the specified model name: {name}. Tip: If your model is in a'
+ f' subdirectory, that directory must contain an __init__.py to be scanned. Available models: '
+ f'{available}')
+
+
+def create_model(opt, opt_net, scale=None):
+ which_model = opt_net['which_model']
+ # For backwards compatibility.
+ if not which_model:
+ which_model = opt_net['which_model_G']
+ if not which_model:
+ which_model = opt_net['which_model_D']
+ registered_fns = find_registered_model_fns()
+ if which_model not in registered_fns.keys():
+ raise CreateModelError(which_model, list(registered_fns.keys()))
+ return registered_fns[which_model](opt_net, opt)
class GradDiscWrapper(torch.nn.Module):
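
Taken together, register_model and create_model replace define_G's long if/else chain with module discovery. A condensed, standalone sketch of the same pattern, with the pkgutil package walk replaced by a plain dict so it runs on its own; the 'tiny' model is invented for illustration:

    from torch import nn

    REGISTRY = {}

    def register_model(func):
        # Same contract as above: the function name must start with 'register_'.
        assert func.__name__.startswith('register_')
        REGISTRY[func.__name__[9:]] = func  # strip the 'register_' prefix
        return func

    @register_model
    def register_tiny(opt_net, opt):
        return nn.Linear(opt_net['in_dim'], opt_net['out_dim'])

    def create_model(opt, opt_net):
        which = opt_net['which_model']
        if which not in REGISTRY:
            raise KeyError(f'Unknown model: {which}. Available: {list(REGISTRY)}')
        return REGISTRY[which](opt_net, opt)

    net = create_model({}, {'which_model': 'tiny', 'in_dim': 8, 'out_dim': 4})
    print(net)  # Linear(in_features=8, out_features=4, bias=True)
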
diff --git a/codes/utils/distill_torchscript.py b/codes/utils/distill_torchscript.py
index a3b1972f..819cb358 100644
--- a/codes/utils/distill_torchscript.py
+++ b/codes/utils/distill_torchscript.py
@@ -2,7 +2,7 @@ import argparse
import functools
import torch
from utils import options as option
-from trainer.networks import define_G
+from trainer.networks import create_model
class TracedModule:
@@ -96,7 +96,7 @@ if __name__ == "__main__":
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)
- netG = define_G(opt)
+ netG = create_model(opt)
dummyInput = torch.rand(1,3,32,32)
mode = 'onnx'
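
Downstream, this script traces the created model for export. A minimal sketch of that step under assumed input shapes; the Conv2d here is a stand-in for the real generator:

    import torch

    net = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1).eval()
    dummy = torch.rand(1, 3, 32, 32)
    traced = torch.jit.trace(net, dummy)        # TorchScript trace
    traced.save('traced.pt')
    torch.onnx.export(net, dummy, 'net.onnx')   # or ONNX export, as mode='onnx' suggests
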