From 5f2c722a10c1caa676e1ed97b37618d3dc319cac Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Thu, 9 Jul 2020 17:34:51 -0600
Subject: [PATCH] SRG2 revival

Big update to SRG2 architecture to pull in a lot of things that have been learned:
- Use group norm instead of batch norm
- Initialize the weights on the transformations low like is done in RRDB rather than using the scalar. Models live or die by their early stages, and this ones early stage is pretty weak
- Transform multiplexer to use u-net like architecture.
- Just use one set of configuration variables instead of a list - flat networks performed fine in this regard.
---
 codes/models/archs/NestedSwitchGenerator.py   |   3 +-
 .../archs/SwitchedResidualGenerator_arch.py   | 121 +++++++++---------
 codes/models/archs/arch_util.py               |  47 ++++++-
 codes/models/networks.py                      |   2 +-
 codes/train.py                                |   2 +-
 codes/utils/numeric_stability.py              |  23 ++--
 6 files changed, 124 insertions(+), 74 deletions(-)

diff --git a/codes/models/archs/NestedSwitchGenerator.py b/codes/models/archs/NestedSwitchGenerator.py
index f663f4cc..ec2d4dc8 100644
--- a/codes/models/archs/NestedSwitchGenerator.py
+++ b/codes/models/archs/NestedSwitchGenerator.py
@@ -1,6 +1,7 @@
 import torch
 from torch import nn
-from models.archs.SwitchedResidualGenerator_arch import ConvBnLelu, ConvBnRelu, MultiConvBlock, initialize_weights
+from models.archs.arch_util import ConvBnLelu, ConvBnRelu
+from models.archs.SwitchedResidualGenerator_arch import MultiConvBlock
 from switched_conv import BareConvSwitch, compute_attention_specificity
 from switched_conv_util import save_attention_to_image
 from functools import partial
diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index 3ebb6ea5..ec0f4fe1 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -4,20 +4,20 @@ from switched_conv import BareConvSwitch, compute_attention_specificity
 import torch.nn.functional as F
 import functools
 from collections import OrderedDict
-from models.archs.arch_util import initialize_weights, ConvBnRelu, ConvBnLelu, ConvBnSilu
+from models.archs.arch_util import ConvBnLelu, ConvGnSilu
 from models.archs.RRDBNet_arch import ResidualDenseBlock_5C
 from models.archs.spinenet_arch import SpineNet
 from switched_conv_util import save_attention_to_image
 
 
 class MultiConvBlock(nn.Module):
-    def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth, scale_init=1, bn=False):
+    def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth, scale_init=1, bn=False, weight_init_factor=1):
         assert depth >= 2
         super(MultiConvBlock, self).__init__()
         self.noise_scale = nn.Parameter(torch.full((1,), fill_value=.01))
-        self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, bn=bn, bias=False)] +
-                                     [ConvBnLelu(filters_mid, filters_mid, kernel_size, bn=bn, bias=False) for i in range(depth-2)] +
-                                     [ConvBnLelu(filters_mid, filters_out, kernel_size, lelu=False, bn=False, bias=False)])
+        self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, bn=bn, bias=False, weight_init_factor=weight_init_factor)] +
+                                     [ConvBnLelu(filters_mid, filters_mid, kernel_size, bn=bn, bias=False, weight_init_factor=weight_init_factor) for i in range(depth-2)] +
+                                     [ConvBnLelu(filters_mid, filters_out, kernel_size, lelu=False, bn=False, bias=False, weight_init_factor=weight_init_factor)])
         self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init))
         self.bias = nn.Parameter(torch.zeros(1))
 
@@ -35,43 +35,56 @@ class MultiConvBlock(nn.Module):
 class HalvingProcessingBlock(nn.Module):
     def __init__(self, filters):
         super(HalvingProcessingBlock, self).__init__()
-        self.bnconv1 = ConvBnLelu(filters, filters * 2, stride=2, bn=False, bias=False)
-        self.bnconv2 = ConvBnLelu(filters * 2, filters * 2, bn=True, bias=False)
+        self.bnconv1 = ConvGnSilu(filters, filters * 2, stride=2, gn=False, bias=False)
+        self.bnconv2 = ConvGnSilu(filters * 2, filters * 2, gn=True, bias=False)
 
     def forward(self, x):
         x = self.bnconv1(x)
         return self.bnconv2(x)
 
 
-# Creates a nested series of convolutional blocks. Each block processes the input data in-place and adds
-# filter_growth filters. Return is (nn.Sequential, ending_filters)
-def create_sequential_growing_processing_block(filters_init, filter_growth, num_convs):
-    convs = []
-    current_filters = filters_init
-    for i in range(num_convs):
-        convs.append(ConvBnSilu(current_filters, current_filters + filter_growth, bn=True, bias=False))
-        current_filters += filter_growth
-    return nn.Sequential(*convs), current_filters
+class ExpansionBlock(nn.Module):
+    def __init__(self, filters):
+        super(ExpansionBlock, self).__init__()
+        self.decimate = ConvGnSilu(filters, filters // 2, kernel_size=1, bias=False, silu=False, gn=False)
+        self.conjoin = ConvGnSilu(filters, filters // 2, kernel_size=3, bias=True, silu=False, gn=True)
+        self.process = ConvGnSilu(filters // 2, filters // 2, kernel_size=3, bias=False, silu=True, gn=True)
+
+    def forward(self, input, passthrough):
+        x = F.interpolate(input, scale_factor=2, mode="nearest")
+        x = self.decimate(x)
+        x = self.conjoin(torch.cat([x, passthrough], dim=1))
+        return self.process(x)
 
 
+# This is a classic u-net architecture with the goal of assigning each individual pixel an individual transform
+# switching set.
 class ConvBasisMultiplexer(nn.Module):
-    def __init__(self, input_channels, base_filters, growth, reductions, processing_depth, multiplexer_channels, use_bn=True):
+    def __init__(self, input_channels, base_filters, reductions, processing_depth, multiplexer_channels, use_gn=True):
         super(ConvBasisMultiplexer, self).__init__()
-        self.filter_conv = ConvBnSilu(input_channels, base_filters, bias=True)
-        self.reduction_blocks = nn.Sequential(OrderedDict([('block%i:' % (i,), HalvingProcessingBlock(base_filters * 2 ** i)) for i in range(reductions)]))
+        self.filter_conv = ConvGnSilu(input_channels, base_filters, bias=True)
+        self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(base_filters * 2 ** i) for i in range(reductions)])
         reduction_filters = base_filters * 2 ** reductions
-        self.processing_blocks, self.output_filter_count = create_sequential_growing_processing_block(reduction_filters, growth, processing_depth)
+        self.processing_blocks = nn.Sequential(OrderedDict([('block%i' % (i,), ConvGnSilu(reduction_filters, reduction_filters, bias=False)) for i in range(processing_depth)]))
+        self.expansion_blocks = nn.ModuleList([ExpansionBlock(reduction_filters // (2 ** i)) for i in range(reductions)])
 
-        gap = self.output_filter_count - multiplexer_channels
-        # Hey silly - if you're going to interpolate later, do it here instead. Then add some processing layers to let the model adjust it properly.
-        self.cbl1 = ConvBnSilu(self.output_filter_count, self.output_filter_count - (gap // 2), bn=use_bn, bias=False)
-        self.cbl2 = ConvBnSilu(self.output_filter_count - (gap // 2), self.output_filter_count - (3 * gap // 4), bn=use_bn, bias=False)
-        self.cbl3 = ConvBnSilu(self.output_filter_count - (3 * gap // 4), multiplexer_channels, bias=True)
+        gap = base_filters - multiplexer_channels
+        cbl1_out = ((base_filters - (gap // 2)) // 4) * 4   # Must be multiples of 4 to use with group norm.
+        self.cbl1 = ConvGnSilu(base_filters, cbl1_out, gn=use_gn, bias=False, num_groups=4)
+        cbl2_out = ((base_filters - (3 * gap // 4)) // 4) * 4
+        self.cbl2 = ConvGnSilu(cbl1_out, cbl2_out, gn=use_gn, bias=False, num_groups=4)
+        self.cbl3 = ConvGnSilu(cbl2_out, multiplexer_channels, bias=True, gn=False)
 
     def forward(self, x):
         x = self.filter_conv(x)
-        x = self.reduction_blocks(x)
+        reduction_identities = []
+        for b in self.reduction_blocks:
+            reduction_identities.append(x)
+            x = b(x)
         x = self.processing_blocks(x)
+        for i, b in enumerate(self.expansion_blocks):
+            x = b(x, reduction_identities[-i - 1])
+
         x = self.cbl1(x)
         x = self.cbl2(x)
         x = self.cbl3(x)
@@ -94,13 +107,13 @@ class BackboneMultiplexer(nn.Module):
     def __init__(self, backbone: CachedBackboneWrapper, transform_count):
         super(BackboneMultiplexer, self).__init__()
         self.backbone = backbone
-        self.proc = nn.Sequential(ConvBnSilu(256, 256, kernel_size=3, bias=True),
-                                  ConvBnSilu(256, 256, kernel_size=3, bias=False))
-        self.up1 = nn.Sequential(ConvBnSilu(256, 128, kernel_size=3, bias=False, bn=False, silu=False),
-                                 ConvBnSilu(128, 128, kernel_size=3, bias=False))
-        self.up2 = nn.Sequential(ConvBnSilu(128, 64, kernel_size=3, bias=False, bn=False, silu=False),
-                                 ConvBnSilu(64, 64, kernel_size=3, bias=False))
-        self.final = ConvBnSilu(64, transform_count, bias=False, bn=False, silu=False)
+        self.proc = nn.Sequential(ConvGnSilu(256, 256, kernel_size=3, bias=True),
+                                  ConvGnSilu(256, 256, kernel_size=3, bias=False))
+        self.up1 = nn.Sequential(ConvGnSilu(256, 128, kernel_size=3, bias=False, gn=False, silu=False),
+                                 ConvGnSilu(128, 128, kernel_size=3, bias=False))
+        self.up2 = nn.Sequential(ConvGnSilu(128, 64, kernel_size=3, bias=False, gn=False, silu=False),
+                                 ConvGnSilu(64, 64, kernel_size=3, bias=False))
+        self.final = ConvGnSilu(64, transform_count, bias=False, gn=False, silu=False)
 
     def forward(self, x):
         spine = self.backbone.get_forward_result()
@@ -112,13 +125,10 @@ class BackboneMultiplexer(nn.Module):
 
 class ConfigurableSwitchComputer(nn.Module):
     def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, init_temp=20,
-                 enable_negative_transforms=False, add_scalable_noise_to_transforms=False, init_scalar=1):
+                 add_scalable_noise_to_transforms=False):
         super(ConfigurableSwitchComputer, self).__init__()
-        self.enable_negative_transforms = enable_negative_transforms
 
         tc = transform_count
-        if self.enable_negative_transforms:
-            tc = transform_count * 2
         self.multiplexer = multiplexer_net(tc)
 
         self.pre_transform = pre_transform_block()
@@ -128,11 +138,11 @@ class ConfigurableSwitchComputer(nn.Module):
 
         # And the switch itself, including learned scalars
         self.switch = BareConvSwitch(initial_temperature=init_temp)
-        self.switch_scale = nn.Parameter(torch.full((1,), float(init_scalar)))
-        self.post_switch_conv = ConvBnLelu(base_filters, base_filters, bn=False, bias=False)
-        # The post_switch_conv gets a near-zero scale. The network can decide to magnify it (or not) depending on its needs.
-        self.psc_scale = nn.Parameter(torch.full((1,), float(1e-3)))
-        self.bias = nn.Parameter(torch.zeros(1))
+        self.switch_scale = nn.Parameter(torch.full((1,), float(1)))
+        self.post_switch_conv = ConvBnLelu(base_filters, base_filters, bn=False, bias=True)
+        # The post_switch_conv gets a low scale initially. The network can decide to magnify it (or not)
+        # depending on its needs.
+        self.psc_scale = nn.Parameter(torch.full((1,), float(.1)))
 
     def forward(self, x, output_attention_weights=False):
         identity = x
@@ -142,17 +152,12 @@ class ConfigurableSwitchComputer(nn.Module):
 
         x = self.pre_transform(x)
         xformed = [t.forward(x) for t in self.transforms]
-        if self.enable_negative_transforms:
-            xformed.extend([-t for t in xformed])
 
         m = self.multiplexer(identity)
-        # Interpolate the multiplexer across the entire shape of the image.
-        m = F.interpolate(m, size=xformed[0].shape[2:], mode='nearest')
 
         outputs, attention = self.switch(xformed, m, True)
         outputs = identity + outputs * self.switch_scale
-        outputs = identity + self.post_switch_conv(outputs) * self.psc_scale
-        outputs = outputs + self.bias
+        outputs = outputs + self.post_switch_conv(outputs) * self.psc_scale
         if output_attention_weights:
             return outputs, attention
         else:
@@ -163,25 +168,25 @@ class ConfigurableSwitchComputer(nn.Module):
 
 
 class ConfigurableSwitchedResidualGenerator2(nn.Module):
-    def __init__(self, switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
+    def __init__(self, switch_depth, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
                  trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
-                 heightened_final_step=50000, upsample_factor=1, enable_negative_transforms=False,
+                 heightened_final_step=50000, upsample_factor=1,
                  add_scalable_noise_to_transforms=False):
         super(ConfigurableSwitchedResidualGenerator2, self).__init__()
         switches = []
         self.initial_conv = ConvBnLelu(3, transformation_filters, bn=False, lelu=False, bias=True)
-        self.sw_conv = ConvBnLelu(transformation_filters, transformation_filters, lelu=False, bias=True)
         self.upconv1 = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=True)
         self.upconv2 = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=True)
         self.hr_conv = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=True)
         self.final_conv = ConvBnLelu(transformation_filters, 3, bn=False, lelu=False, bias=True)
-        for filters, growth, sw_reduce, sw_proc, trans_count, kernel, layers in zip(switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers):
-            multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, filters, growth, sw_reduce, sw_proc, trans_count)
+        for _ in range(switch_depth):
+            multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions, switch_processing_layers, trans_counts)
+            pretransform_fn = functools.partial(ConvBnLelu, transformation_filters, transformation_filters, bn=False, bias=False, weight_init_factor=.1)
+            transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5), transformation_filters, kernel_size=trans_kernel_sizes, depth=trans_layers, weight_init_factor=.1)
             switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
-                                                       pre_transform_block=functools.partial(ConvBnLelu, transformation_filters, transformation_filters, bn=False, bias=False),
-                                                       transform_block=functools.partial(MultiConvBlock, transformation_filters, transformation_filters + growth, transformation_filters, kernel_size=kernel, depth=layers),
-                                                       transform_count=trans_count, init_temp=initial_temp, enable_negative_transforms=enable_negative_transforms,
-                                                       add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=.1))
+                                                       pre_transform_block=pretransform_fn, transform_block=transform_fn,
+                                                       transform_count=trans_counts, init_temp=initial_temp,
+                                                       add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
 
         self.switches = nn.ModuleList(switches)
         self.transformation_counts = trans_counts
@@ -268,7 +273,7 @@ class ConfigurableSwitchedResidualGenerator3(nn.Module):
         pretransform_fn = functools.partial(nn.Sequential, ConvBnLelu(base_filters, base_filters, kernel_size=3, bn=False, lelu=False, bias=False))
         transform_fn = functools.partial(MultiConvBlock, base_filters, int(base_filters * 1.5), base_filters, kernel_size=3, depth=4)
         self.switch = ConfigurableSwitchComputer(base_filters, multiplx_fn, pretransform_fn, transform_fn, trans_count, init_temp=initial_temp,
-                                            enable_negative_transforms=False, add_scalable_noise_to_transforms=True, init_scalar=.1)
+                                            add_scalable_noise_to_transforms=True, init_scalar=.1)
 
         self.transformation_counts = trans_count
         self.init_temperature = initial_temp
diff --git a/codes/models/archs/arch_util.py b/codes/models/archs/arch_util.py
index afb99703..3cc7df98 100644
--- a/codes/models/archs/arch_util.py
+++ b/codes/models/archs/arch_util.py
@@ -219,7 +219,7 @@ class ConvBnRelu(nn.Module):
 ''' Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard
     kernel sizes. '''
 class ConvBnSilu(nn.Module):
-    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, silu=True, bn=True, bias=True):
+    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, silu=True, bn=True, bias=True, weight_init_factor=1):
         super(ConvBnSilu, self).__init__()
         padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
         assert kernel_size in padding_map.keys()
@@ -237,6 +237,9 @@ class ConvBnSilu(nn.Module):
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.silu else 'linear')
+                m.weight.data *= weight_init_factor
+                if m.bias is not None:
+                    m.bias.data.zero_()
             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
@@ -254,7 +257,7 @@ class ConvBnSilu(nn.Module):
 ''' Convenience class with Conv->BN->LeakyReLU. Includes weight initialization and auto-padding for standard
     kernel sizes. '''
 class ConvBnLelu(nn.Module):
-    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, lelu=True, bn=True, bias=True):
+    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, lelu=True, bn=True, bias=True, weight_init_factor=1):
         super(ConvBnLelu, self).__init__()
         padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
         assert kernel_size in padding_map.keys()
@@ -273,6 +276,9 @@ class ConvBnLelu(nn.Module):
             if isinstance(m, nn.Conv2d):
                 nn.init.kaiming_normal_(m.weight, a=.1, mode='fan_out',
                                         nonlinearity='leaky_relu' if self.lelu else 'linear')
+                m.weight.data *= weight_init_factor
+                if m.bias is not None:
+                    m.bias.data.zero_()
             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
@@ -319,5 +325,42 @@ class ConvGnLelu(nn.Module):
             x = self.gn(x)
         if self.lelu:
             return self.lelu(x)
+        else:
+            return x
+
+''' Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard
+    kernel sizes. '''
+class ConvGnSilu(nn.Module):
+    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, silu=True, gn=True, bias=True, num_groups=8, weight_init_factor=1):
+        super(ConvGnSilu, self).__init__()
+        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
+        assert kernel_size in padding_map.keys()
+        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
+        if gn:
+            self.gn = nn.GroupNorm(num_groups, filters_out)
+        else:
+            self.gn = None
+        if silu:
+            self.silu = SiLU()
+        else:
+            self.silu = None
+
+        # Init params.
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.silu else 'linear')
+                m.weight.data *= weight_init_factor
+                if m.bias is not None:
+                    m.bias.data.zero_()
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        x = self.conv(x)
+        if self.gn:
+            x = self.gn(x)
+        if self.silu:
+            return self.silu(x)
         else:
             return x
\ No newline at end of file
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 23270e72..3495dd16 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -59,7 +59,7 @@ def define_G(opt, net_key='network_G'):
                                                                       heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
                                                                       upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
     elif which_model == "ConfigurableSwitchedResidualGenerator2":
-        netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator2(switch_filters=opt_net['switch_filters'], switch_growths=opt_net['switch_growths'],
+        netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
                                                                       switch_reductions=opt_net['switch_reductions'],
                                                                       switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
                                                                       trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
diff --git a/codes/train.py b/codes/train.py
index 6ad0f763..9259c31a 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_div2k_pixgan_rrdb.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_div2k_pixgan_srg2.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
diff --git a/codes/utils/numeric_stability.py b/codes/utils/numeric_stability.py
index 28b3c0d6..4aa40ba9 100644
--- a/codes/utils/numeric_stability.py
+++ b/codes/utils/numeric_stability.py
@@ -97,20 +97,19 @@ if __name__ == "__main__":
                    torch.randn(1, 3, 64, 64),
                    device='cuda')
     '''
-    '''
     test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator2,
-                                     switch_filters=[32,32,32,32],
-                                     switch_growths=[16,16,16,16],
-                                     switch_reductions=[4,3,2,1],
-                                     switch_processing_layers=[3,3,4,5],
-                                     trans_counts=[16,16,16,16,16],
-                                     trans_kernel_sizes=[3,3,3,3,3],
-                                     trans_layers=[3,3,3,3,3],
+                                     switch_depth=4,
+                                     switch_filters=64,
+                                     switch_reductions=4,
+                                     switch_processing_layers=2,
+                                     trans_counts=8,
+                                     trans_kernel_sizes=3,
+                                     trans_layers=4,
                                      transformation_filters=64,
-                                     initial_temp=10),
+                                     upsample_factor=4),
                    torch.randn(1, 3, 64, 64),
                    device='cuda')
-    '''
+
     '''
     test_stability(functools.partial(srg1.ConfigurableSwitchedResidualGenerator,
                                      switch_filters=[32,32,32,32],
@@ -125,7 +124,9 @@ if __name__ == "__main__":
                    torch.randn(1, 3, 64, 64),
                    device='cuda')
                    '''
+    '''
     test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator3,
                                      64, 16),
                    torch.randn(1, 3, 64, 64),
-                   device='cuda')
\ No newline at end of file
+                   device='cuda')
+    '''
\ No newline at end of file