diff --git a/.idea/vcs.xml b/.idea/vcs.xml index fc78291c..8d59ed0d 100644 --- a/.idea/vcs.xml +++ b/.idea/vcs.xml @@ -3,6 +3,7 @@ + \ No newline at end of file diff --git a/codes/models/ProgressiveSrg_arch.py b/codes/models/ProgressiveSrg_arch.py deleted file mode 100644 index 7e5c9dde..00000000 --- a/codes/models/ProgressiveSrg_arch.py +++ /dev/null @@ -1,275 +0,0 @@ -import models.SwitchedResidualGenerator_arch as srg -import torch -import torch.nn as nn -from switched_conv.switched_conv_util import save_attention_to_image -from switched_conv.switched_conv import compute_attention_specificity -from models.arch_util import ConvGnLelu, ExpansionBlock, MultiConvBlock -import functools -import torch.nn.functional as F - -# Some notes about this new architecture: -# 1) Discriminator is going to need to get update_for_step() called. -# 2) Not sure if pixgan part of discriminator is going to work properly, make sure to test at multiple add levels. -# 3) Also not sure if growth modules will be properly saved/trained, be sure to test this. -# 4) start_step will need to get set properly when constructing these models, even when resuming - OR another method needs to be added to resume properly. - -class GrowingSRGBase(nn.Module): - def __init__(self, progressive_step_schedule, switch_reductions, growth_fade_in_steps, switch_filters, switch_processing_layers, trans_counts, - trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, upsample_factor=1, - add_scalable_noise_to_transforms=False, start_step=0): - super(GrowingSRGBase, self).__init__() - switches = [] - self.initial_conv = ConvGnLelu(3, transformation_filters, norm=False, activation=False, bias=True) - self.upconv1 = ConvGnLelu(transformation_filters, transformation_filters, norm=False, bias=True) - self.upconv2 = ConvGnLelu(transformation_filters, transformation_filters, norm=False, bias=True) - self.hr_conv = ConvGnLelu(transformation_filters, transformation_filters, norm=False, bias=True) - self.final_conv = ConvGnLelu(transformation_filters, 3, norm=False, activation=False, bias=True) - - self.switch_filters = switch_filters - self.switch_processing_layers = switch_processing_layers - self.trans_layers = trans_layers - self.transformation_filters = transformation_filters - self.progressive_schedule = progressive_step_schedule - self.switch_reductions = switch_reductions # This lists the reductions for all switches (even ones not activated yet). - self.growth_fade_in_per_step = 1 / growth_fade_in_steps - self.transformation_counts = trans_counts - self.init_temperature = initial_temp - self.final_temperature_step = final_temperature_step - self.attentions = None - self.upsample_factor = upsample_factor - self.add_noise_to_transform = add_scalable_noise_to_transforms - self.start_step = start_step - self.latest_step = start_step - self.fades = [] - self.counter = 0 - assert self.upsample_factor == 2 or self.upsample_factor == 4 - - switches = [] - for i, (step, reductions) in enumerate(zip(progressive_step_schedule, switch_reductions)): - multiplx_fn = functools.partial(srg.ConvBasisMultiplexer, self.transformation_filters, self.switch_filters, - reductions, self.switch_processing_layers, self.transformation_counts) - pretransform_fn = functools.partial(ConvGnLelu, self.transformation_filters, self.transformation_filters, norm=False, - bias=False, weight_init_factor=.1) - transform_fn = functools.partial(srg.MultiConvBlock, self.transformation_filters, int(self.transformation_filters * 1.5), - self.transformation_filters, kernel_size=3, depth=self.trans_layers, - weight_init_factor=.1) - switches.append(srg.ConfigurableSwitchComputer(self.transformation_filters, multiplx_fn, - pre_transform_block=pretransform_fn, - transform_block=transform_fn, - transform_count=self.transformation_counts, init_temp=self.init_temperature, - add_scalable_noise_to_transforms=self.add_noise_to_transform, - attention_norm=False)) - self.progressive_switches = nn.ModuleList(switches) - - def get_param_groups(self): - param_groups = [] - base_param_group = [] - for k, v in self.named_parameters(): - if "progressive_switches" not in k and v.requires_grad: - base_param_group.append(v) - param_groups.append({'params': base_param_group}) - for i, sw in enumerate(self.progressive_switches): - sw_param_group = [] - for k, v in sw.named_parameters(): - if v.requires_grad: - sw_param_group.append(v) - param_groups.append({'params': sw_param_group}) - return param_groups - - # This is a hacky way of modifying the underlying model while training. Since changing the model means changing - # the optimizer and the scheduler, these things are fed in. For ProgressiveSrg, this function adds an additional -# switch to the end of the chain with depth=3 and an online time set at the end fo the function. - def update_model(self, opt, sched): - multiplx_fn = functools.partial(srg.ConvBasisMultiplexer, self.transformation_filters, self.switch_filters, - 3, self.switch_processing_layers, self.transformation_counts) - pretransform_fn = functools.partial(ConvGnLelu, self.transformation_filters, self.transformation_filters, norm=False, - bias=False, weight_init_factor=.1) - transform_fn = functools.partial(srg.MultiConvBlock, self.transformation_filters, int(self.transformation_filters * 1.5), - self.transformation_filters, kernel_size=3, depth=self.trans_layers, - weight_init_factor=.1) - new_sw = srg.ConfigurableSwitchComputer(self.transformation_filters, multiplx_fn, - pre_transform_block=pretransform_fn, - transform_block=transform_fn, - transform_count=self.transformation_counts, init_temp=self.init_temperature, - add_scalable_noise_to_transforms=self.add_noise_to_transform, - attention_norm=False).to('cuda') - self.progressive_switches.append(new_sw) - new_sw_param_group = [] - for k, v in new_sw.named_parameters(): - if v.requires_grad: - new_sw_param_group.append(v) - opt.add_param_group({'params': new_sw_param_group}) - self.progressive_schedule.append(150000) - sched.group_starts.append(150000) - - def get_progressive_starts(self): - # The base param group starts at step 0, the rest are defined via progressive_switches. - return [0] + self.progressive_schedule - - # This method turns requires_grad on and off for different switches, allowing very large models to be trained while - # using less memory. When used in conjunction with gradient accumulation, it becomes a form of model parallelism. - # controls the proportion of switches that are enabled. 1/groups will be enabled. - # Switches that are younger than 40000 steps are not eligible to be turned off. - def do_switched_grad(self, groups=1): - # If requires_grad is already disabled, don't bother. - if not self.initial_conv.conv.weight.requires_grad or groups == 1: - return - self.counter = (self.counter + 1) % groups - enabled = [] - for i, sw in enumerate(self.progressive_switches): - if self.latest_step - self.progressive_schedule[i] > 40000 and i % groups != self.counter: - for p in sw.parameters(): - p.requires_grad = False - else: - enabled.append(i) - for p in sw.parameters(): - p.requires_grad = True - - def forward(self, x): - self.do_switched_grad(2) - - x = self.initial_conv(x) - - self.attentions = [] - self.fades = [] - self.enabled_switches = 0 - for i, sw in enumerate(self.progressive_switches): - fade_in = 1 if self.progressive_schedule[i] == 0 else 0 - if self.latest_step > 0 and self.progressive_schedule[i] != 0: - switch_age = self.latest_step - self.progressive_schedule[i] - fade_in = min(1, switch_age * self.growth_fade_in_per_step) - - if fade_in > 0: - self.enabled_switches += 1 - x, att = sw.forward(x, True, fixed_scale=fade_in) - self.attentions.append(att) - self.fades.append(fade_in) - - x = self.upconv1(F.interpolate(x, scale_factor=2, mode="nearest")) - if self.upsample_factor > 2: - x = F.interpolate(x, scale_factor=2, mode="nearest") - x = self.upconv2(x) - x = self.final_conv(self.hr_conv(x)) - return x, x - - def update_for_step(self, step, experiments_path='.'): - self.latest_step = step + self.start_step - - # Set the temperature of the switches, per-layer. - for i, (first_step, sw) in enumerate(zip(self.progressive_schedule, self.progressive_switches)): - temp_loss_per_step = (self.init_temperature - 1) / self.final_temperature_step - sw.set_temperature(min(self.init_temperature, - max(self.init_temperature - temp_loss_per_step * (step - first_step), 1))) - - # Save attention images. - if self.attentions is not None and step % 50 == 0: - [save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts, step, "a%i" % (i+1,), l_mult=10) for i in range(len(self.attentions))] - - def get_debug_values(self, step): - mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions] - means = [i[0] for i in mean_hists] - hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists] - val = {} - for i in range(len(means)): - val["switch_%i_specificity" % (i,)] = means[i] - val["switch_%i_histogram" % (i,)] = hists[i] - val["switch_%i_temperature" % (i,)] = self.progressive_switches[i].switch.temperature - for i, f in enumerate(self.fades): - val["switch_%i_fade" % (i,)] = f - val["enabled_switches"] = self.enabled_switches - return val - -class DiscriminatorDownsample(nn.Module): - def __init__(self, base_filters, end_filters): - self.conv0 = ConvGnLelu(base_filters, end_filters, kernel_size=3, bias=False) - self.conv1 = ConvGnLelu(end_filters, end_filters, kernel_size=3, stride=2, bias=False) - - def forward(self, x): - return self.conv1(self.conv0(x)) - - -class DiscriminatorUpsample(nn.Module): - def __init__(self, base_filters, end_filters): - self.up = ExpansionBlock(base_filters, end_filters, block=ConvGnLelu) - self.proc = ConvGnLelu(end_filters, end_filters, bias=False) - self.collapse = ConvGnLelu(end_filters, 1, bias=True, norm=False, activation=False) - - def forward(self, x, ff): - x = self.up1(x, ff) - return x, self.collapse1(self.proc1(x)) - - -class GrowingUnetDiscBase(nn.Module): - def __init__(self, nf, growing_schedule, growth_fade_in_steps, start_step=0): - super(GrowingUnetDiscBase, self).__init__() - # [64, 128, 128] - self.conv0_0 = ConvGnLelu(3, nf, kernel_size=3, bias=True, activation=False) - self.conv0_1 = ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False) - # [64, 64, 64] - self.conv1_0 = ConvGnLelu(nf, nf * 2, kernel_size=3, bias=False) - self.conv1_1 = ConvGnLelu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False) - - self.down_base = DiscriminatorDownsample(nf * 2, nf * 4) - self.up_base = DiscriminatorUpsample(nf * 4, nf * 2) - - self.progressive_schedule = growing_schedule - self.growth_fade_in_per_step = 1 / growth_fade_in_steps - self.pnf = nf * 4 - self.downsamples = nn.ModuleList([]) - self.upsamples = nn.ModuleList([]) - - for i, step in enumerate(growing_schedule): - if step >= start_step: - self.add_layer(i + 1) - - def add_layer(self): - self.downsamples.append(DiscriminatorDownsample(self.pnf, self.pnf)) - self.upsamples.append(DiscriminatorUpsample(self.pnf, self.pnf)) - - def update_for_step(self, step): - self.latest_step = step - - # Add any new layers as spelled out by the schedule. - if step != 0: - for i, s in enumerate(self.progressive_schedule): - if s == step: - self.add_layer(i + 1) - - def forward(self, x, output_feature_vector=False): - x = self.conv0_0(x) - x = self.conv0_1(x) - x = self.conv1_0(x) - x = self.conv1_1(x) - base_fea = self.down_base(x) - x = base_fea - - skips = [] - for down in self.downsamples: - x = down(x) - skips.append(x) - - losses = [] - for i, up in enumerate(self.upsamples): - j = i + 1 - x, loss = up(x, skips[-j]) - losses.append(loss) - - # This variant averages the outputs of the U-net across the upsamples, weighting the contribution - # to the average less for newly growing levels. - _, base_loss = self.up_base(x, base_fea) - res = base_loss.shape[2:] - - mean_weight = 1 - for i, l in enumerate(losses): - fade_in = 1 - if self.latest_step > 0 and self.progressive_schedule[i] != 0: - disc_age = self.latest_step - self.progressive_schedule[i] - fade_in = min(1, disc_age * self.growth_fade_in_per_step) - mean_weight += fade_in - base_loss += F.interpolate(l, size=res, mode="bilinear", align_corners=False) * fade_in - base_loss /= mean_weight - - return base_loss.view(-1, 1) - - def pixgan_parameters(self): - return 1, 4 \ No newline at end of file diff --git a/codes/models/RRDBNet_arch.py b/codes/models/RRDBNet_arch.py index 6a43f158..ba4d576f 100644 --- a/codes/models/RRDBNet_arch.py +++ b/codes/models/RRDBNet_arch.py @@ -6,6 +6,7 @@ import torch.nn.functional as F import torchvision from models.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu +from trainer.networks import register_model from utils.util import checkpoint, sequential_checkpoint @@ -370,3 +371,27 @@ class RRDBDiscriminator(nn.Module): if self.pred_ is not None: self.pred_ = F.sigmoid(self.pred_) torchvision.utils.save_image(self.pred_.cpu().float(), os.path.join(path, "%i_predictions.png" % (step,))) + + +@register_model +def register_RRDBNetBypass(opt_net, opt): + additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not' + output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only' + gc = opt_net['gc'] if 'gc' in opt_net.keys() else 32 + initial_stride = opt_net['initial_stride'] if 'initial_stride' in opt_net.keys() else 1 + return RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'], + mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode, + output_mode=output_mode, body_block=RRDBWithBypass, scale=opt_net['scale'], growth_channels=gc, + initial_stride=initial_stride) + + +@register_model +def register_RRDBNet(opt_net, opt): + additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not' + output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only' + gc = opt_net['gc'] if 'gc' in opt_net.keys() else 32 + initial_stride = opt_net['initial_stride'] if 'initial_stride' in opt_net.keys() else 1 + return RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'], + mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode, + output_mode=output_mode, body_block=RRDB, scale=opt_net['scale'], growth_channels=gc, + initial_stride=initial_stride) \ No newline at end of file diff --git a/codes/models/SwitchedResidualGenerator_arch.py b/codes/models/SwitchedResidualGenerator_arch.py index feb20733..433f3621 100644 --- a/codes/models/SwitchedResidualGenerator_arch.py +++ b/codes/models/SwitchedResidualGenerator_arch.py @@ -1,17 +1,19 @@ -import torch -from torch import nn -from models.switched_conv.switched_conv import BareConvSwitch, compute_attention_specificity, AttentionNorm -import torch.nn.functional as F import functools -from collections import OrderedDict -from models.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu, MultiConvBlock, \ - SiLU, UpconvBlock, ReferenceJoinBlock -from models.switched_conv.switched_conv_util import save_attention_to_image_rgb import os -from models.spinenet_arch import SpineNet +from collections import OrderedDict + +import torch +import torch.nn.functional as F import torchvision +from torch import nn + +from models.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, MultiConvBlock +from models.switched_conv.switched_conv import BareConvSwitch, compute_attention_specificity, AttentionNorm +from models.switched_conv.switched_conv_util import save_attention_to_image_rgb +from trainer.networks import register_model from utils.util import checkpoint + # VGG-style layer with Conv(stride2)->BN->Activation->Conv->BN->Activation # Doubles the input filter count. class HalvingProcessingBlock(nn.Module): @@ -261,142 +263,6 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module): return val -# This class encapsulates an encoder based on an object detection network backbone whose purpose is to generated a -# structured embedding encoding what is in an image patch. This embedding can then be used to perform structured -# alterations to the underlying image. -# -# Caveat: Since this uses a pre-defined (and potentially pre-trained) SpineNet backbone, it has a minimum-supported -# image size, which is 128x128. In order to use 64x64 patches, you must set interpolate_first=True. though this will -# degrade quality. -class BackboneEncoder(nn.Module): - def __init__(self, interpolate_first=True, pretrained_backbone=None): - super(BackboneEncoder, self).__init__() - self.interpolate_first = interpolate_first - - # Uses dual spinenets, one for the input patch and the other for the reference image. - self.patch_spine = SpineNet('49', in_channels=3, use_input_norm=True) - self.ref_spine = SpineNet('49', in_channels=3, use_input_norm=True) - - self.merge_process1 = ConvGnSilu(512, 512, kernel_size=1, activation=True, norm=False, bias=True) - self.merge_process2 = ConvGnSilu(512, 384, kernel_size=1, activation=True, norm=True, bias=False) - self.merge_process3 = ConvGnSilu(384, 256, kernel_size=1, activation=False, norm=False, bias=True) - - if pretrained_backbone is not None: - loaded_params = torch.load(pretrained_backbone) - self.ref_spine.load_state_dict(loaded_params['state_dict'], strict=True) - self.patch_spine.load_state_dict(loaded_params['state_dict'], strict=True) - - # Returned embedding will have been reduced in size by a factor of 8 (4 if interpolate_first=True). - # Output channels are always 256. - # ex, 64x64 input with interpolate_first=True will result in tensor of shape [bx256x16x16] - def forward(self, x, ref, ref_center_point): - if self.interpolate_first: - x = F.interpolate(x, scale_factor=2, mode="bicubic") - # Don't interpolate ref - assume it is fed in at the proper resolution. - # ref = F.interpolate(ref, scale_factor=2, mode="bicubic") - - # [ref] will have a 'mask' channel which we cannot use with pretrained spinenet. - ref = ref[:, :3, :, :] - ref_emb = self.ref_spine(ref)[0] - ref_code = gather_2d(ref_emb, ref_center_point // 8) # Divide by 8 to bring the center point to the correct location. - - patch = self.patch_spine(x)[0] - ref_code_expanded = ref_code.view(-1, 256, 1, 1).repeat(1, 1, patch.shape[2], patch.shape[3]) - combined = self.merge_process1(torch.cat([patch, ref_code_expanded], dim=1)) - combined = self.merge_process2(combined) - combined = self.merge_process3(combined) - - return combined - - -class BackboneEncoderNoRef(nn.Module): - def __init__(self, interpolate_first=True, pretrained_backbone=None): - super(BackboneEncoderNoRef, self).__init__() - self.interpolate_first = interpolate_first - - self.patch_spine = SpineNet('49', in_channels=3, use_input_norm=True) - - if pretrained_backbone is not None: - loaded_params = torch.load(pretrained_backbone) - self.patch_spine.load_state_dict(loaded_params['state_dict'], strict=True) - - # Returned embedding will have been reduced in size by a factor of 8 (4 if interpolate_first=True). - # Output channels are always 256. - # ex, 64x64 input with interpolate_first=True will result in tensor of shape [bx256x16x16] - def forward(self, x): - if self.interpolate_first: - x = F.interpolate(x, scale_factor=2, mode="bicubic") - - patch = self.patch_spine(x)[0] - return patch - - -class BackboneSpinenetNoHead(nn.Module): - def __init__(self): - super(BackboneSpinenetNoHead, self).__init__() - # Uses dual spinenets, one for the input patch and the other for the reference image. - self.patch_spine = SpineNet('49', in_channels=3, use_input_norm=False, double_reduce_early=False) - self.ref_spine = SpineNet('49', in_channels=4, use_input_norm=False, double_reduce_early=False) - - self.merge_process1 = ConvGnSilu(512, 512, kernel_size=1, activation=True, norm=False, bias=True) - self.merge_process2 = ConvGnSilu(512, 384, kernel_size=1, activation=True, norm=True, bias=False) - self.merge_process3 = ConvGnSilu(384, 256, kernel_size=1, activation=False, norm=False, bias=True) - - def forward(self, x, ref, ref_center_point): - ref_emb = self.ref_spine(ref)[0] - ref_code = gather_2d(ref_emb, ref_center_point // 4) # Divide by 8 to bring the center point to the correct location. - - patch = self.patch_spine(x)[0] - ref_code_expanded = ref_code.view(-1, 256, 1, 1).repeat(1, 1, patch.shape[2], patch.shape[3]) - combined = self.merge_process1(torch.cat([patch, ref_code_expanded], dim=1)) - combined = self.merge_process2(combined) - combined = self.merge_process3(combined) - return combined - - -class ResBlock(nn.Module): - def __init__(self, nf, downsample): - super(ResBlock, self).__init__() - nf_int = nf * 2 - nf_out = nf * 2 if downsample else nf - stride = 2 if downsample else 1 - self.c1 = ConvGnSilu(nf, nf_int, kernel_size=3, bias=False, activation=True, norm=True) - self.c2 = ConvGnSilu(nf_int, nf_int, stride=stride, kernel_size=3, bias=False, activation=True, norm=True) - self.c3 = ConvGnSilu(nf_int, nf_out, kernel_size=3, bias=False, activation=False, norm=True) - if downsample: - self.downsample = ConvGnSilu(nf, nf_out, kernel_size=1, stride=stride, bias=False, activation=False, norm=True) - else: - self.downsample = None - self.act = SiLU() - - def forward(self, x): - identity = x - branch = self.c1(x) - branch = self.c2(branch) - branch = self.c3(branch) - - if self.downsample: - identity = self.downsample(identity) - return self.act(identity + branch) - - -class BackboneResnet(nn.Module): - def __init__(self): - super(BackboneResnet, self).__init__() - self.initial_conv = ConvGnSilu(3, 64, kernel_size=7, bias=True, activation=False, norm=False) - self.sequence = nn.Sequential( - ResBlock(64, downsample=False), - ResBlock(64, downsample=True), - ResBlock(128, downsample=False), - ResBlock(128, downsample=True), - ResBlock(256, downsample=False), - ResBlock(256, downsample=False)) - - def forward(self, x): - fea = self.initial_conv(x) - return self.sequence(fea) - - # Computes a linear latent by performing processing on the reference image and returning the filters of a single point, # which should be centered on the image patch being processed. # @@ -705,4 +571,24 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module): for i in range(len(means)): val["switch_%i_specificity" % (i,)] = means[i] val["switch_%i_histogram" % (i,)] = hists[i] - return val \ No newline at end of file + return val + +@register_model +def register_ConfigurableSwitchedResidualGenerator2(opt_net, opt): + return ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], + switch_filters=opt_net['switch_filters'], + switch_reductions=opt_net['switch_reductions'], + switch_processing_layers=opt_net[ + 'switch_processing_layers'], + trans_counts=opt_net['trans_counts'], + trans_kernel_sizes=opt_net['trans_kernel_sizes'], + trans_layers=opt_net['trans_layers'], + transformation_filters=opt_net['transformation_filters'], + attention_norm=opt_net['attention_norm'], + initial_temp=opt_net['temperature'], + final_temperature_step=opt_net['temperature_final_step'], + heightened_temp_min=opt_net['heightened_temp_min'], + heightened_final_step=opt_net['heightened_final_step'], + upsample_factor=scale, + add_scalable_noise_to_transforms=opt_net['add_noise'], + for_video=opt_net['for_video']) \ No newline at end of file diff --git a/codes/models/byol/byol_model_wrapper.py b/codes/models/byol/byol_model_wrapper.py index 398db439..cb2c7b77 100644 --- a/codes/models/byol/byol_model_wrapper.py +++ b/codes/models/byol/byol_model_wrapper.py @@ -10,7 +10,8 @@ from kornia import filters from torch import nn from data.byol_attachment import RandomApply -from utils.util import checkpoint +from trainer.networks import register_model, create_model +from utils.util import checkpoint, opt_get def default(val, def_val): @@ -269,3 +270,11 @@ class BYOL(nn.Module): loss = loss_one + loss_two return loss.mean() + + +@register_model +def register_byol(opt_net, opt): + subnet = create_model(opt, opt_net['subnet']) + return BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'], + structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False), + do_augmentation=opt_get(opt_net, ['gpu_augmentation'], False)) \ No newline at end of file diff --git a/codes/models/byol/byol_structural.py b/codes/models/byol/byol_structural.py index eb56a5d7..3aeb2602 100644 --- a/codes/models/byol/byol_structural.py +++ b/codes/models/byol/byol_structural.py @@ -7,6 +7,7 @@ from torch import nn from data.byol_attachment import reconstructed_shared_regions from models.byol.byol_model_wrapper import singleton, EMA, get_module_device, set_requires_grad, \ update_moving_average +from trainer.networks import create_model, register_model from utils.util import checkpoint # loss function @@ -178,4 +179,11 @@ class StructuralBYOL(nn.Module): def get_projection(self, image): enc = self.online_encoder(image) proj = self.online_predictor(enc) - return enc, proj \ No newline at end of file + return enc, proj + +@register_model +def register_structural_byol(opt_net, opt): + subnet = create_model(opt, opt_net['subnet']) + return StructuralBYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'], + pretrained_state_dict=opt_get(opt_net, ["pretrained_path"]), + freeze_until=opt_get(opt_net, ['freeze_until'], 0)) diff --git a/codes/models/discriminator_vgg_arch.py b/codes/models/discriminator_vgg_arch.py index 3917ddc7..1f98bd0e 100644 --- a/codes/models/discriminator_vgg_arch.py +++ b/codes/models/discriminator_vgg_arch.py @@ -1,10 +1,8 @@ import torch import torch.nn as nn -from models.RRDBNet_arch import RRDB, RRDBWithBypass from models.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvGnSilu, ResidualBlockGN import torch.nn.functional as F -from models.SwitchedResidualGenerator_arch import gather_2d from utils.util import checkpoint @@ -519,6 +517,7 @@ class RefDiscriminatorVgg128(nn.Module): def forward(self, x, ref, ref_center_point): ref = self.ref_head(ref) ref_center_point = ref_center_point // 16 + from models.SwitchedResidualGenerator_arch import gather_2d ref_vector = gather_2d(ref, ref_center_point) ref_vector = self.ref_linear(ref_vector) diff --git a/codes/models/glean/__init__.py b/codes/models/glean/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/codes/models/glean/glean.py b/codes/models/glean/glean.py index d314bc8e..389082af 100644 --- a/codes/models/glean/glean.py +++ b/codes/models/glean/glean.py @@ -10,6 +10,7 @@ from models.arch_util import ConvGnLelu # Produces a convolutional feature (`f`) and a reduced feature map with double the filters. from models.glean.stylegan2_latent_bank import Stylegan2LatentBank from models.stylegan.stylegan2_rosinality import EqualLinear +from trainer.networks import register_model from utils.util import checkpoint, sequential_checkpoint @@ -108,3 +109,8 @@ class GleanGenerator(nn.Module): rrdb_fea, conv_fea, latents = self.encoder(x) latent_bank_fea = self.latent_bank(conv_fea, latents) return self.decoder(rrdb_fea, latent_bank_fea) + + +@register_model +def register_glean(opt_net, opt): + return GleanGenerator(opt_net['nf'], opt_net['pretrained_stylegan']) diff --git a/codes/models/resnet_with_checkpointing.py b/codes/models/resnet_with_checkpointing.py index 39a5523f..4ab8f3c2 100644 --- a/codes/models/resnet_with_checkpointing.py +++ b/codes/models/resnet_with_checkpointing.py @@ -11,6 +11,7 @@ __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'] +from trainer.networks import register_model from utils.util import checkpoint model_urls = { @@ -188,3 +189,8 @@ def wide_resnet101_2(pretrained=False, progress=True, **kwargs): kwargs['width_per_group'] = 64 * 2 return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) + + +@register_model +def register_resnet52(opt_net, opt): + return resnet50(pretrained=opt_net['pretrained']) diff --git a/codes/models/spinenet_arch.py b/codes/models/spinenet_arch.py index 8bd723ef..8c887fb3 100644 --- a/codes/models/spinenet_arch.py +++ b/codes/models/spinenet_arch.py @@ -7,6 +7,7 @@ from torch.nn.init import kaiming_normal from torchvision.models.resnet import BasicBlock, Bottleneck from models.arch_util import ConvGnSilu, ConvBnSilu, ConvBnRelu +from trainer.networks import register_model def constant_init(module, val, bias=0): @@ -359,3 +360,13 @@ class SpinenetWithLogits(SpineNet): def forward(self, x): fea = super().forward(x)[self.output_to_attach] return self.tail(fea) + +@register_model +def register_spinenet(opt_net, opt): + return SpineNet(str(opt_net['arch']), in_channels=3, use_input_norm=opt_net['use_input_norm']) + + +@register_model +def register_spinenet_with_logits(opt_net, opt): + return SpinenetWithLogits(str(opt_net['arch']), opt_net['output_to_attach'], opt_net['num_labels'], + in_channels=3, use_input_norm=opt_net['use_input_norm']) diff --git a/codes/models/srflow/RRDBNet_arch.py b/codes/models/srflow/RRDBNet_arch.py index 36eee072..e566a0f2 100644 --- a/codes/models/srflow/RRDBNet_arch.py +++ b/codes/models/srflow/RRDBNet_arch.py @@ -4,6 +4,7 @@ import torch.nn as nn import torch.nn.functional as F import models.srflow.module_util as mutil from models.arch_util import default_init_weights, ConvGnSilu, ConvGnLelu +from trainer.networks import register_model from utils.util import opt_get @@ -239,4 +240,19 @@ class RRDBLatentWrapper(nn.Module): blocklist.append(rrdbResults['last_lr_fea']) fea = torch.cat(blocklist, dim=1) fea = self.postprocess(fea) - return fea \ No newline at end of file + return fea + + +@register_model +def register_rrdb_latent_wrapper(opt_net, opt): + return RRDBLatentWrapper(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], + nf=opt_net['nf'], nb=opt_net['nb'], with_bypass=opt_net['with_bypass'], + blocks=opt_net['blocks_for_latent'], scale=opt_net['scale'], + pretrain_rrdb_path=opt_net['pretrain_path']) + + +@register_model +def register_rrdb_srflow(opt_net, opt): + return RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], + nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'], + initial_conv_stride=opt_net['initial_stride']) \ No newline at end of file diff --git a/codes/models/srflow/SRFlowNet_arch.py b/codes/models/srflow/SRFlowNet_arch.py index 73314407..3002dbdc 100644 --- a/codes/models/srflow/SRFlowNet_arch.py +++ b/codes/models/srflow/SRFlowNet_arch.py @@ -8,6 +8,7 @@ from models.srflow.RRDBNet_arch import RRDBNet from models.srflow.FlowUpsamplerNet import FlowUpsamplerNet import models.srflow.thops as thops import models.srflow.flow as flow +from trainer.networks import register_model from utils.util import opt_get @@ -166,3 +167,9 @@ class SRFlowNet(nn.Module): logdet=logdet) return x, logdet, lr_enc['out'] + + +@register_model +def register_srflow(opt_net, opt): + return SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'], + K=opt_net['K'], opt=opt) diff --git a/codes/models/srg2_classic.py b/codes/models/srg2_classic.py index bece2781..b0245023 100644 --- a/codes/models/srg2_classic.py +++ b/codes/models/srg2_classic.py @@ -11,6 +11,7 @@ from collections import OrderedDict from models.SwitchedResidualGenerator_arch import HalvingProcessingBlock, ConfigurableSwitchComputer from models.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, MultiConvBlock from models.switched_conv.switched_conv import BareConvSwitch, AttentionNorm +from trainer.networks import register_model from utils.util import checkpoint @@ -204,3 +205,19 @@ class Interpolate(nn.Module): def forward(self, x): return F.interpolate(x, scale_factor=self.factor, mode=self.mode) +@register_model +def register_srg2classic(opt_net, opt): + return ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], + switch_filters=opt_net['switch_filters'], + switch_reductions=opt_net['switch_reductions'], + switch_processing_layers=opt_net['switch_processing_layers'], + trans_counts=opt_net['trans_counts'], + trans_kernel_sizes=opt_net['trans_kernel_sizes'], + trans_layers=opt_net['trans_layers'], + transformation_filters=opt_net['transformation_filters'], + initial_temp=opt_net['temperature'], + final_temperature_step=opt_net['temperature_final_step'], + heightened_temp_min=opt_net['heightened_temp_min'], + heightened_final_step=opt_net['heightened_final_step'], + upsample_factor=scale, + add_scalable_noise_to_transforms=opt_net['add_noise']) \ No newline at end of file diff --git a/codes/models/stylegan/stylegan2_lucidrains.py b/codes/models/stylegan/stylegan2_lucidrains.py index 9b08854a..d58221e8 100644 --- a/codes/models/stylegan/stylegan2_lucidrains.py +++ b/codes/models/stylegan/stylegan2_lucidrains.py @@ -17,6 +17,7 @@ from torch import nn from torch.autograd import grad as torch_grad from vector_quantize_pytorch import VectorQuantize +from trainer.networks import register_model from utils.util import checkpoint try: @@ -893,3 +894,12 @@ class StyleGan2PathLengthLoss(L.ConfigurableLoss): self.pl_mean = self.pl_length_ma.update_average(self.pl_mean, avg_pl_length) return 0 + + +@register_model +def register_stylegan2_lucidrains(opt_net, opt): + is_structured = opt_net['structured'] if 'structured' in opt_net.keys() else False + attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else [] + return StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'], + style_depth=opt_net['style_depth'], structure_input=is_structured, + attn_layers=attn) diff --git a/codes/models/tecogan/flownet2.py b/codes/models/tecogan/flownet2.py new file mode 100644 index 00000000..c5c108b3 --- /dev/null +++ b/codes/models/tecogan/flownet2.py @@ -0,0 +1,15 @@ +import munch +import torch + +from trainer.networks import register_model + + +@register_model +def register_flownet2(opt_net): + from models.flownet2.models import FlowNet2 + ld = 'load_path' in opt_net.keys() + args = munch.Munch({'fp16': False, 'rgb_max': 1.0, 'checkpoint': not ld}) + netG = FlowNet2(args) + if ld: + sd = torch.load(opt_net['load_path']) + netG.load_state_dict(sd['state_dict']) \ No newline at end of file diff --git a/codes/models/tecogan/teco_resgen.py b/codes/models/tecogan/teco_resgen.py index 9803df33..2f640148 100644 --- a/codes/models/tecogan/teco_resgen.py +++ b/codes/models/tecogan/teco_resgen.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn import torchvision +from trainer.networks import register_model from utils.util import sequential_checkpoint from models.arch_util import ConvGnSilu, make_layer @@ -71,3 +72,8 @@ class TecoGen(nn.Module): def get_debug_values(self, step, net_name): return {'branch_std': self.join.std()} + + +@register_model +def register_tecogen(opt_net, opt): + return TecoGen(opt_net['nf'], opt_net['scale']) \ No newline at end of file diff --git a/codes/models/transformers/igpt/gpt2.py b/codes/models/transformers/igpt/gpt2.py index 552185e3..21a7e4fd 100644 --- a/codes/models/transformers/igpt/gpt2.py +++ b/codes/models/transformers/igpt/gpt2.py @@ -3,6 +3,7 @@ import torch.nn as nn import torch.nn.functional as F import numpy as np from trainer.injectors import Injector +from trainer.networks import register_model from utils.util import checkpoint @@ -147,3 +148,8 @@ class iGPT2(nn.Module): return logits, x + +@register_model +def register_igpt2(opt_net, opt): + return iGPT2(opt_net['embed_dim'], opt_net['num_heads'], opt_net['num_layers'], opt_net['num_pixels'] ** 2, + opt_net['num_vocab'], centroids_file=opt_net['centroids_file']) diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py index ffec11b5..8298091f 100644 --- a/codes/trainer/ExtensibleTrainer.py +++ b/codes/trainer/ExtensibleTrainer.py @@ -57,7 +57,7 @@ class ExtensibleTrainer(BaseModel): new_net = None if net['type'] == 'generator': if new_net is None: - new_net = networks.define_G(opt, net, opt['scale']).to(self.device) + new_net = networks.create_model(opt, net, opt['scale']).to(self.device) self.netsG[name] = new_net elif net['type'] == 'discriminator': if new_net is None: diff --git a/codes/trainer/networks.py b/codes/trainer/networks.py index 7c67eb67..422d6f91 100644 --- a/codes/trainer/networks.py +++ b/codes/trainer/networks.py @@ -1,140 +1,79 @@ import functools +import importlib import logging +import pkgutil +import sys from collections import OrderedDict +from inspect import isfunction, getmembers -import munch import torch import torchvision -from munch import munchify -import models.stylegan.stylegan2_lucidrains as stylegan2 -import models.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch -import models.RRDBNet_arch as RRDBNet_arch -import models.SwitchedResidualGenerator_arch as SwitchedGen_arch import models.discriminator_vgg_arch as SRGAN_arch import models.feature_arch as feature_arch -from models import srg2_classic +import models.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch from models.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator -from models.tecogan.teco_resgen import TecoGen -from utils.util import opt_get logger = logging.getLogger('base') -# Generator -def define_G(opt, opt_net, scale=None): - if scale is None: - scale = opt['scale'] - which_model = opt_net['which_model_G'] - if 'RRDBNet' in which_model: - if which_model == 'RRDBNetBypass': - block = RRDBNet_arch.RRDBWithBypass - elif which_model == 'RRDBNetLambda': - from models.lambda_rrdb import LambdaRRDB - block = LambdaRRDB - else: - block = RRDBNet_arch.RRDB - additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not' - output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only' - gc = opt_net['gc'] if 'gc' in opt_net.keys() else 32 - initial_stride = opt_net['initial_stride'] if 'initial_stride' in opt_net.keys() else 1 - netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'], - mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode, - output_mode=output_mode, body_block=block, scale=opt_net['scale'], growth_channels=gc, - initial_stride=initial_stride) - elif which_model == "ConfigurableSwitchedResidualGenerator2": - netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'], - switch_reductions=opt_net['switch_reductions'], - switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'], - trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'], - transformation_filters=opt_net['transformation_filters'], attention_norm=opt_net['attention_norm'], - initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'], - heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'], - upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'], - for_video=opt_net['for_video']) - elif which_model == "srg2classic": - netG = srg2_classic.ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'], - switch_reductions=opt_net['switch_reductions'], - switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'], - trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'], - transformation_filters=opt_net['transformation_filters'], - initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'], - heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'], - upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise']) - elif which_model == "flownet2": - from models.flownet2 import FlowNet2 - ld = 'load_path' in opt_net.keys() - args = munch.Munch({'fp16': False, 'rgb_max': 1.0, 'checkpoint': not ld}) - netG = FlowNet2(args) - if ld: - sd = torch.load(opt_net['load_path']) - netG.load_state_dict(sd['state_dict']) - elif which_model == "backbone_encoder": - netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet']) - elif which_model == "backbone_encoder_no_ref": - netG = SwitchedGen_arch.BackboneEncoderNoRef(pretrained_backbone=opt_net['pretrained_spinenet']) - elif which_model == "backbone_encoder_no_head": - netG = SwitchedGen_arch.BackboneSpinenetNoHead() - elif which_model == "backbone_resnet": - netG = SwitchedGen_arch.BackboneResnet() - elif which_model == "tecogen": - netG = TecoGen(opt_net['nf'], opt_net['scale']) - elif which_model == 'stylegan2': - is_structured = opt_net['structured'] if 'structured' in opt_net.keys() else False - attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else [] - netG = stylegan2.StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'], - style_depth=opt_net['style_depth'], structure_input=is_structured, - attn_layers=attn) - elif which_model == 'srflow': - from models.srflow import SRFlowNet_arch - netG = SRFlowNet_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'], - K=opt_net['K'], opt=opt) - elif which_model == 'rrdb_latent_wrapper': - from models.srflow.RRDBNet_arch import RRDBLatentWrapper - netG = RRDBLatentWrapper(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], - nf=opt_net['nf'], nb=opt_net['nb'], with_bypass=opt_net['with_bypass'], - blocks=opt_net['blocks_for_latent'], scale=opt_net['scale'], pretrain_rrdb_path=opt_net['pretrain_path']) - elif which_model == 'rrdb_centipede': - output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only' - netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'], - mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], scale=opt_net['scale'], - headless=True, output_mode=output_mode) - elif which_model == 'rrdb_srflow': - from models.srflow.RRDBNet_arch import RRDBNet - netG = RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], - nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'], - initial_conv_stride=opt_net['initial_stride']) - elif which_model == 'igpt2': - from models.transformers.igpt.gpt2 import iGPT2 - netG = iGPT2(opt_net['embed_dim'], opt_net['num_heads'], opt_net['num_layers'], opt_net['num_pixels'] ** 2, opt_net['num_vocab'], centroids_file=opt_net['centroids_file']) - elif which_model == 'byol': - from models.byol.byol_model_wrapper import BYOL - subnet = define_G(opt, opt_net['subnet']) - netG = BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'], - structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False), - do_augmentation=opt_get(opt_net, ['gpu_augmentation'], False)) - elif which_model == 'structural_byol': - from models.byol.byol_structural import StructuralBYOL - subnet = define_G(opt, opt_net['subnet']) - netG = StructuralBYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'], - pretrained_state_dict=opt_get(opt_net, ["pretrained_path"]), - freeze_until=opt_get(opt_net, ['freeze_until'], 0)) - elif which_model == 'spinenet': - from models.spinenet_arch import SpineNet - netG = SpineNet(str(opt_net['arch']), in_channels=3, use_input_norm=opt_net['use_input_norm']) - elif which_model == 'spinenet_with_logits': - from models.spinenet_arch import SpinenetWithLogits - netG = SpinenetWithLogits(str(opt_net['arch']), opt_net['output_to_attach'], opt_net['num_labels'], - in_channels=3, use_input_norm=opt_net['use_input_norm']) - elif which_model == 'resnet52': - from models.resnet_with_checkpointing import resnet50 - netG = resnet50(pretrained=opt_net['pretrained']) - elif which_model == 'glean': - from models.glean.glean import GleanGenerator - netG = GleanGenerator(opt_net['nf'], opt_net['pretrained_stylegan']) +class RegisteredModelNameError(Exception): + def __init__(self, name_error): + super().__init__(f'Registered DLAS modules must start with `register_`. Incorrect registration: {name_error}') + + +# Decorator that allows API clients to show DLAS how to build a nn.Module from an opt dict. +# Functions with this decorator should have a specific naming format: +# `register_` where is the name that will be used in configuration files to reference this model. +# Functions with this decorator are expected to take a single argument: +# - opt: A dict with the configuration options for building the module. +# They should return: +# - A torch.nn.Module object for the model being defined. +def register_model(func): + if func.__name__.startswith("register_"): + func._dlas_model_name = func.__name__[9:] + assert func._dlas_model_name else: - raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model)) - return netG + raise RegisteredModelNameError(func.__name__) + func._dlas_registered_model = True + return func + + +def find_registered_model_fns(base_path='models'): + found_fns = {} + module_iter = pkgutil.walk_packages([base_path]) + for mod in module_iter: + if mod.ispkg: + EXCLUSION_LIST = ['flownet2'] + if mod.name not in EXCLUSION_LIST: + found_fns.update(find_registered_model_fns(f'{base_path}/{mod.name}')) + else: + mod_name = f'{base_path}/{mod.name}'.replace('/', '.') + importlib.import_module(mod_name) + for mod_fn in getmembers(sys.modules[mod_name], isfunction): + if hasattr(mod_fn[1], "_dlas_registered_model"): + found_fns[mod_fn[1]._dlas_model_name] = mod_fn[1] + return found_fns + + +class CreateModelError(Exception): + def __init__(self, name, available): + super().__init__(f'Could not find the specified model name: {name}. Tip: If your model is in a' + f' subdirectory, that directory must contain an __init__.py to be scanned. Available models:' + f'{available}') + + +def create_model(opt, opt_net, scale=None): + which_model = opt_net['which_model'] + # For backwards compatibility. + if not which_model: + which_model = opt_net['which_model_G'] + if not which_model: + which_model = opt_net['which_model_D'] + registered_fns = find_registered_model_fns() + if which_model not in registered_fns.keys(): + raise CreateModelError(which_model, list(registered_fns.keys())) + return registered_fns[which_model](opt_net, opt) class GradDiscWrapper(torch.nn.Module): diff --git a/codes/utils/distill_torchscript.py b/codes/utils/distill_torchscript.py index a3b1972f..819cb358 100644 --- a/codes/utils/distill_torchscript.py +++ b/codes/utils/distill_torchscript.py @@ -2,7 +2,7 @@ import argparse import functools import torch from utils import options as option -from trainer.networks import define_G +from trainer.networks import create_model class TracedModule: @@ -96,7 +96,7 @@ if __name__ == "__main__": opt = option.parse(parser.parse_args().opt, is_train=False) opt = option.dict_to_nonedict(opt) - netG = define_G(opt) + netG = create_model(opt) dummyInput = torch.rand(1,3,32,32) mode = 'onnx'