From 1b1431133bf397ee9a2b21373995dcdd26eb7533 Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Tue, 14 Jul 2020 09:28:24 -0600
Subject: [PATCH] Add DualOutputSRG

Also removes the old multi-return mechanism that Generators supported.
Also fixes AttentionNorm.
---
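Note on the dual-output contract (illustration only, not part of the commit
message): DualOutputSRG returns two images per forward pass. The first
("fea") output comes from a branch that splits off before the final switch
and feeds the pixel and feature (perceptual) losses; the second is the final
output and is what the discriminator sees. Below is a minimal, self-contained
sketch of how a training step consumes the pair; the tiny generator, the L1
pixel criterion and the tensor shapes are stand-ins chosen for illustration,
not the repo's classes or configuration.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class TinyDualOutputGen(nn.Module):
        """Stand-in for DualOutputSRG: returns (feature_img, final_img)."""
        def __init__(self):
            super().__init__()
            self.fea_branch = nn.Conv2d(3, 3, 3, padding=1)
            self.final_branch = nn.Conv2d(3, 3, 3, padding=1)

        def forward(self, x):
            x = F.interpolate(x, scale_factor=4, mode="nearest")
            return self.fea_branch(x), self.final_branch(x)

    netG = TinyDualOutputGen()
    cri_pix = nn.L1Loss()              # pixel criterion (L1 assumed here)
    var_L = torch.randn(1, 3, 16, 16)  # low-res input
    pix = torch.randn(1, 3, 64, 64)    # high-res target

    fea_out, gen_out = netG(var_L)     # two outputs, as consumed in SRGAN_model.py
    l_g_pix = cri_pix(fea_out, pix)    # pixel (and VGG feature) losses use fea_out
    l_g_pix.backward()                 # gen_out would be passed to the discriminator
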
 codes/models/SRGAN_model.py                   |  93 ++++++++-------
 .../archs/SwitchedResidualGenerator_arch.py   | 107 +++++++++++++++++-
 codes/models/archs/discriminator_vgg_arch.py  |   3 -
 codes/models/networks.py                      |   9 ++
 codes/train.py                                |   2 +-
 codes/utils/numeric_stability.py              |   2 +-
 6 files changed, 161 insertions(+), 55 deletions(-)

diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index f95bbfdb..a47d007d 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -70,9 +70,11 @@ class SRGANModel(BaseModel):
                 else:
                     raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
                 self.l_fea_w = train_opt['feature_weight']
-                self.l_fea_w_decay = train_opt['feature_weight_decay']
+                self.l_fea_w_decay_start = train_opt['feature_weight_decay_start']
                 self.l_fea_w_decay_steps = train_opt['feature_weight_decay_steps']
                 self.l_fea_w_minimum = train_opt['feature_weight_minimum']
+                if self.l_fea_w_decay_start:
+                    self.l_fea_w_decay_step_size = (self.l_fea_w - self.l_fea_w_minimum) / self.l_fea_w_decay_steps
             else:
                 logger.info('Remove feature loss.')
                 self.cri_fea = None
@@ -202,16 +204,17 @@ class SRGANModel(BaseModel):
         for p in self.netD.parameters():
             p.requires_grad = False
 
-        if step > self.D_init_iters:
+        if step >= self.D_init_iters:
             self.optimizer_G.zero_grad()
 
         self.swapout_D(step)
         self.swapout_G(step)
 
         # Turning off G-grad is required to enable mega-batching and D_update_ratio to work together for some reason.
-        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
+        if step % self.D_update_ratio == 0 and step >= self.D_init_iters:
             for p in self.netG.parameters():
-                p.requires_grad = True
+                if p.dtype != torch.int64 and p.dtype != torch.bool:
+                    p.requires_grad = True
         else:
             for p in self.netG.parameters():
                 p.requires_grad = False
@@ -227,35 +230,28 @@ class SRGANModel(BaseModel):
             _t = time()
 
         self.fake_GenOut = []
+        self.fea_GenOut = []
         self.fake_H = []
         var_ref_skips = []
         for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
-            fake_GenOut = self.netG(var_L)
+            fea_GenOut, fake_GenOut = self.netG(var_L)
 
             if _profile:
                 print("Gen forward %f" % (time() - _t,))
                 _t = time()
 
-            # Extract the image output. For generators that output skip-through connections, the master output is always
-            # the first element of the tuple.
-            if isinstance(fake_GenOut, tuple):
-                gen_img = fake_GenOut[0]
-                # The following line detaches all generator outputs that are not None.
-                self.fake_GenOut.append(tuple([(x.detach() if x is not None else None) for x in list(fake_GenOut)]))
-                var_ref = (var_ref,)  # This is a tuple for legacy reasons.
-            else:
-                gen_img = fake_GenOut
-                self.fake_GenOut.append(fake_GenOut.detach())
+            self.fake_GenOut.append(fake_GenOut.detach())
+            self.fea_GenOut.append(fea_GenOut.detach())
 
             l_g_total = 0
-            if step % self.D_update_ratio == 0 and step > self.D_init_iters:
+            if step % self.D_update_ratio == 0 and step >= self.D_init_iters:
                 if self.cri_pix:  # pixel loss
-                    l_g_pix = self.l_pix_w * self.cri_pix(gen_img, pix)
+                    l_g_pix = self.l_pix_w * self.cri_pix(fea_GenOut, pix)
                     l_g_pix_log = l_g_pix / self.l_pix_w
                     l_g_total += l_g_pix
                 if self.cri_fea:  # feature loss
                     real_fea = self.netF(pix).detach()
-                    fake_fea = self.netF(gen_img)
+                    fake_fea = self.netF(fea_GenOut)
                     l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                     l_g_fea_log = l_g_fea / self.l_fea_w
                     l_g_total += l_g_fea
@@ -266,8 +262,13 @@ class SRGANModel(BaseModel):
 
                     # Decay the influence of the feature loss. As the model trains, the GAN will play a stronger role
                     # in the resultant image.
-                    if step % self.l_fea_w_decay_steps == 0:
-                        self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
+                    if self.l_fea_w_decay_start and step > self.l_fea_w_decay_start:
+                        self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w - self.l_fea_w_decay_step_size * (step - self.l_fea_w_decay_start))
+
+                    # Note to future self: BCELoss(0, 1) = BCELoss(0, 0) = .6931 (= ln 2, the loss when the discriminator's
+                    # output corresponds to a probability of .5). Effectively the generator has only completely "won" when
+                    # l_d_real and l_d_fake are both equal to this value. If I ever come up with an algorithm that tunes
+                    # fea/gan weights automatically, it should target this value.
 
                 if self.l_gan_w > 0:
                     if self.opt['train']['gan_type'] == 'gan' or self.opt['train']['gan_type'] == 'pixgan':
@@ -304,7 +305,7 @@ class SRGANModel(BaseModel):
             for p in self.netD.parameters():
                 p.requires_grad = True
 
-            noise = torch.randn_like(var_ref[0]) * noise_theta
+            noise = torch.randn_like(var_ref) * noise_theta
             noise.to(self.device)
             self.optimizer_D.zero_grad()
             real_disc_images = []
@@ -312,17 +313,17 @@ class SRGANModel(BaseModel):
             for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
                 # Re-compute generator outputs (post-update).
                 with torch.no_grad():
-                    fake_H = self.netG(var_L)
+                    _, fake_H = self.netG(var_L)
                     # The following line detaches all generator outputs that are not None.
-                    fake_H = tuple([(x.detach() if x is not None else None) for x in list(fake_H)])
+                    fake_H = fake_H.detach()
 
                     if _profile:
                         print("Gen forward for disc %f" % (time() - _t,))
                         _t = time()
 
                 # Apply noise to the inputs to slow discriminator convergence.
-                var_ref = (var_ref + noise,)
-                fake_H = (fake_H[0] + noise,) + fake_H[1:]
+                var_ref = var_ref + noise
+                fake_H = fake_H + noise
                 if self.opt['train']['gan_type'] == 'gan':
                     # need to forward and backward separately, since batch norm statistics differ
                     # real
@@ -340,10 +341,10 @@ class SRGANModel(BaseModel):
                 if self.opt['train']['gan_type'] == 'pixgan':
                     # randomly determine portions of the image to swap to keep the discriminator honest.
                     pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters()
-                    disc_output_shape = (var_ref[0].shape[0], pixdisc_channels, var_ref[0].shape[2] // pixdisc_output_reduction, var_ref[0].shape[3] // pixdisc_output_reduction)
-                    b, _, w, h = var_ref[0].shape
-                    real = torch.ones((b, pixdisc_channels, w, h), device=var_ref[0].device)
-                    fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref[0].device)
+                    disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction)
+                    b, _, w, h = var_ref.shape
+                    real = torch.ones((b, pixdisc_channels, w, h), device=var_ref.device)
+                    fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref.device)
                     SWAP_MAX_DIM = w // 4
                     SWAP_MIN_DIM = 16
                     assert SWAP_MAX_DIM > 0
@@ -360,9 +361,9 @@ class SRGANModel(BaseModel):
                                 swap_w = w - swap_x
                             if swap_y + swap_h > h:
                                 swap_h = h - swap_y
-                            t = fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone()
-                            fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)]
-                            var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t
+                            t = fake_H[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone()
+                            fake_H[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)]
+                            var_ref[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t
                             real[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 0.0
                             fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0
 
@@ -422,8 +423,8 @@ class SRGANModel(BaseModel):
                         _t = time()
 
                 # Append var_ref here, so that we can inspect the alterations the disc made if pixgan
-                var_ref_skips.append(var_ref[0].detach())
-                self.fake_H.append(fake_H[0].detach())
+                var_ref_skips.append(var_ref.detach())
+                self.fake_H.append(fake_H.detach())
             self.optimizer_D.step()
 
 
@@ -436,32 +437,28 @@ class SRGANModel(BaseModel):
             sample_save_path = os.path.join(self.opt['path']['models'], "..", "temp")
             os.makedirs(os.path.join(sample_save_path, "hr"), exist_ok=True)
             os.makedirs(os.path.join(sample_save_path, "lr"), exist_ok=True)
+            os.makedirs(os.path.join(sample_save_path, "gen_fea"), exist_ok=True)
             os.makedirs(os.path.join(sample_save_path, "gen"), exist_ok=True)
             os.makedirs(os.path.join(sample_save_path, "disc_fake"), exist_ok=True)
             os.makedirs(os.path.join(sample_save_path, "pix"), exist_ok=True)
             os.makedirs(os.path.join(sample_save_path, "disc"), exist_ok=True)
-            multi_gen = False
-            if isinstance(self.fake_GenOut[0], tuple):
-                os.makedirs(os.path.join(sample_save_path, "ref"), exist_ok=True)
-                multi_gen = True
+            os.makedirs(os.path.join(sample_save_path, "ref"), exist_ok=True)
 
             # fed_LQ is not chunked.
             for i in range(self.mega_batch_factor):
                 utils.save_image(self.var_H[i].cpu(), os.path.join(sample_save_path, "hr", "%05i_%02i.png" % (step, i)))
                 utils.save_image(self.var_L[i].cpu(), os.path.join(sample_save_path, "lr", "%05i_%02i.png" % (step, i)))
                 utils.save_image(self.pix[i].cpu(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i)))
-                if multi_gen:
-                    utils.save_image(self.fake_GenOut[i][0].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
-                    if self.l_gan_w > 0 and step > self.G_warmup and self.opt['train']['gan_type'] == 'pixgan':
-                        utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i)))
-                        utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "fake%05i_%02i.png" % (step, i)))
-                        utils.save_image(F.interpolate(fake_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "fake%05i_%02i.png" % (step, i)))
-                        utils.save_image(F.interpolate(real_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "real%05i_%02i.png" % (step, i)))
-                else:
-                    utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
+                utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
+                utils.save_image(self.fea_GenOut[i].cpu(), os.path.join(sample_save_path, "gen_fea", "%05i_%02i.png" % (step, i)))
+                if self.l_gan_w > 0 and step > self.G_warmup and self.opt['train']['gan_type'] == 'pixgan':
+                    utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i)))
+                    utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "fake%05i_%02i.png" % (step, i)))
+                    utils.save_image(F.interpolate(fake_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "fake%05i_%02i.png" % (step, i)))
+                    utils.save_image(F.interpolate(real_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "real%05i_%02i.png" % (step, i)))
 
         # Log metrics
-        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
+        if step % self.D_update_ratio == 0 and step >= self.D_init_iters:
             if self.cri_pix:
                 self.add_log_entry('l_g_pix', l_g_pix_log.item())
             if self.cri_fea:
diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index 0cde385e..e42ec2a0 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -196,7 +196,8 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
         if self.upsample_factor > 2:
             x = F.interpolate(x, scale_factor=2, mode="nearest")
         x = self.upconv2(x)
-        return self.final_conv(self.hr_conv(x)),
+        x = self.final_conv(self.hr_conv(x))
+        return x, x
 
     def set_temperature(self, temp):
         [sw.set_temperature(temp) for sw in self.switches]
@@ -318,4 +319,106 @@ class ConfigurableSwitchedResidualGenerator3(nn.Module):
         for i in range(len(means)):
             val["switch_%i_specificity" % (i,)] = means[i]
             val["switch_%i_histogram" % (i,)] = hists[i]
-        return val
\ No newline at end of file
+        return val
+
+
+class DualOutputSRG(nn.Module):
+    def __init__(self, switch_depth, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
+                 trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
+                 heightened_final_step=50000, upsample_factor=1,
+                 add_scalable_noise_to_transforms=False):
+        super(DualOutputSRG, self).__init__()
+        switches = []
+        self.initial_conv = ConvBnLelu(3, transformation_filters, norm=False, activation=False, bias=True)
+
+        self.fea_upconv1 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
+        self.fea_upconv2 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
+        self.fea_hr_conv = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
+        self.fea_final_conv = ConvBnLelu(transformation_filters, 3, norm=False, activation=False, bias=True)
+
+        self.upconv1 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
+        self.upconv2 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
+        self.hr_conv = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
+        self.final_conv = ConvBnLelu(transformation_filters, 3, norm=False, activation=False, bias=True)
+
+        for _ in range(switch_depth):
+            multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions, switch_processing_layers, trans_counts)
+            pretransform_fn = functools.partial(ConvBnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1)
+            transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5), transformation_filters, kernel_size=trans_kernel_sizes, depth=trans_layers, weight_init_factor=.1)
+            switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
+                                                       pre_transform_block=pretransform_fn, transform_block=transform_fn,
+                                                       transform_count=trans_counts, init_temp=initial_temp,
+                                                       add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
+
+        self.switches = nn.ModuleList(switches)
+        self.transformation_counts = trans_counts
+        self.init_temperature = initial_temp
+        self.final_temperature_step = final_temperature_step
+        self.heightened_temp_min = heightened_temp_min
+        self.heightened_final_step = heightened_final_step
+        self.attentions = None
+        self.upsample_factor = upsample_factor
+        assert self.upsample_factor == 2 or self.upsample_factor == 4
+
+    def forward(self, x):
+        x = self.initial_conv(x)
+
+        self.attentions = []
+        for i, sw in enumerate(self.switches):
+            x, att = sw.forward(x, True)
+            self.attentions.append(att)
+
+            if i == len(self.switches)-2:  # Split off the "fea" output branch before the final switch.
+                fea = self.fea_upconv1(F.interpolate(x, scale_factor=2, mode="nearest"))
+                if self.upsample_factor > 2:
+                    fea = F.interpolate(fea, scale_factor=2, mode="nearest")
+                fea = self.fea_upconv2(fea)
+                fea = self.fea_final_conv(self.fea_hr_conv(fea))
+
+        x = self.upconv1(F.interpolate(x, scale_factor=2, mode="nearest"))
+        if self.upsample_factor > 2:
+            x = F.interpolate(x, scale_factor=2, mode="nearest")
+        x = self.upconv2(x)
+        return fea, self.final_conv(self.hr_conv(x))
+
+    def set_temperature(self, temp):
+        [sw.set_temperature(temp) for sw in self.switches]
+
+    def update_for_step(self, step, experiments_path='.'):
+        if self.attentions:
+            temp = max(1, int(self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step))
+            if temp == 1 and self.heightened_final_step and self.heightened_final_step != 1:
+                # Once the temperature passes (1), it enters an inverted curve to match the linear curve from above.
+                # Without this, the attention specificity "spikes" incredibly fast in the last few iterations.
+                h_steps_total = self.heightened_final_step - self.final_temperature_step
+                h_steps_current = max(min(step - self.final_temperature_step, h_steps_total), 1)
+                # The "gap" will represent the steps that need to be traveled as a linear function.
+                h_gap = 1 / self.heightened_temp_min
+                temp = h_gap * h_steps_current / h_steps_total
+                # Invert temperature to represent reality on this side of the curve
+                temp = 1 / temp
+            self.set_temperature(temp)
+            if step % 50 == 0:
+                [save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts, step, "a%i" % (i+1,)) for i in range(len(self.switches))]
+
+    def get_debug_values(self, step):
+        temp = self.switches[0].switch.temperature
+        mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
+        means = [i[0] for i in mean_hists]
+        hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
+        val = {"switch_temperature": temp}
+        for i in range(len(means)):
+            val["switch_%i_specificity" % (i,)] = means[i]
+            val["switch_%i_histogram" % (i,)] = hists[i]
+        return val
+
+
+    def load_state_dict(self, state_dict, strict=True):
+        # Support backwards compatibility with older checkpoints whose AttentionNorm layers lack the accumulator, accumulator_index and accumulator_filled buffers in their state_dict.
+        t_state = self.state_dict()
+        if 'switches.0.switch.attention_norm.accumulator_index' not in state_dict.keys():
+            for i in range(4):
+                state_dict['switches.%i.switch.attention_norm.accumulator' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator' % (i,)]
+                state_dict['switches.%i.switch.attention_norm.accumulator_index' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator_index' % (i,)]
+                state_dict['switches.%i.switch.attention_norm.accumulator_filled' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator_filled' % (i,)]
+        super(DualOutputSRG, self).load_state_dict(state_dict, strict)
\ No newline at end of file
diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py
index 559c2f16..e20f7e9f 100644
--- a/codes/models/archs/discriminator_vgg_arch.py
+++ b/codes/models/archs/discriminator_vgg_arch.py
@@ -51,7 +51,6 @@ class Discriminator_VGG_128(nn.Module):
         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
 
     def forward(self, x):
-        x = x[0]
         fea = self.lrelu(self.conv0_0(x))
         fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
 
@@ -127,7 +126,6 @@ class Discriminator_VGG_PixLoss(nn.Module):
         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
 
     def forward(self, x, flatten=True):
-        x = x[0]
         fea0 = self.lrelu(self.conv0_0(x))
         fea0 = self.lrelu(self.bn0_1(self.conv0_1(fea0)))
 
@@ -205,7 +203,6 @@ class Discriminator_UNet(nn.Module):
         self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)
 
     def forward(self, x, flatten=True):
-        x = x[0]
         fea0 = self.conv0_0(x)
         fea0 = self.conv0_1(fea0)
 
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 7507add2..2dfe0af0 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -78,6 +78,15 @@ def define_G(opt, net_key='network_G'):
                                                                       initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
                                                                       heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
                                                                       upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
+    elif which_model == "DualOutputSRG":
+        netG = SwitchedGen_arch.DualOutputSRG(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
+                                                                      switch_reductions=opt_net['switch_reductions'],
+                                                                      switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
+                                                                      trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
+                                                                      transformation_filters=opt_net['transformation_filters'],
+                                                                      initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
+                                                                      heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
+                                                                      upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
 
     # image corruption
     elif which_model == 'HighToLowResNet':
diff --git a/codes/train.py b/codes/train.py
index 4744ce26..65395714 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg2.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_dual_srg.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
diff --git a/codes/utils/numeric_stability.py b/codes/utils/numeric_stability.py
index bfcedc70..dda9f79b 100644
--- a/codes/utils/numeric_stability.py
+++ b/codes/utils/numeric_stability.py
@@ -93,7 +93,7 @@ if __name__ == "__main__":
                    torch.randn(1, 3, 64, 64),
                    device='cuda')
     '''
-    test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator2,
+    test_stability(functools.partial(srg.DualOutputSRG,
                                      switch_depth=4,
                                      switch_filters=64,
                                      switch_reductions=4,