diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index 60baf25b..a1b2d5d4 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -116,9 +116,16 @@ class SRGANModel(BaseModel):
                                                 weight_decay=wd_G,
                                                 betas=(train_opt['beta1_G'], train_opt['beta2_G']))
             self.optimizers.append(self.optimizer_G)
             # D
+            optim_params = []
+            for k, v in self.netD.named_parameters():  # can optimize for a part of the model
+                if v.requires_grad:
+                    optim_params.append(v)
+                else:
+                    if self.rank <= 0:
+                        logger.warning('Params [{:s}] will not optimize.'.format(k))
             wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
-            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'],
+            self.optimizer_D = torch.optim.Adam(optim_params, lr=train_opt['lr_D'],
                                                 weight_decay=wd_D,
                                                 betas=(train_opt['beta1_D'], train_opt['beta2_D']))
             self.optimizers.append(self.optimizer_D)
@@ -219,6 +226,8 @@ class SRGANModel(BaseModel):
         # Some generators have variants depending on the current step.
         if hasattr(self.netG.module, "update_for_step"):
             self.netG.module.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))
+        if hasattr(self.netD.module, "update_for_step"):
+            self.netD.module.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))
 
         # G
         for p in self.netD.parameters():
@@ -323,7 +332,8 @@ class SRGANModel(BaseModel):
         # D
         if self.l_gan_w > 0 and step > self.G_warmup:
             for p in self.netD.parameters():
-                p.requires_grad = True
+                if p.dtype != torch.int64 and p.dtype != torch.bool:  # requires_grad can only be enabled on floating point tensors
+                    p.requires_grad = True
 
             noise = torch.randn_like(var_ref) * noise_theta
             noise.to(self.device)
@@ -610,6 +620,8 @@ class SRGANModel(BaseModel):
         # Some generators can do their own metric logging.
         if hasattr(self.netG.module, "get_debug_values"):
             return_log.update(self.netG.module.get_debug_values(step))
+        if hasattr(self.netD.module, "get_debug_values"):
+            return_log.update(self.netD.module.get_debug_values(step))
 
         return return_log
 
diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py
index 5dc8df23..04c9fc4d 100644
--- a/codes/models/archs/discriminator_vgg_arch.py
+++ b/codes/models/archs/discriminator_vgg_arch.py
@@ -238,6 +238,165 @@ class Discriminator_UNet(nn.Module):
         return 3, 4
 
 
+import functools
+from models.archs.SwitchedResidualGenerator_arch import BareConvSwitch
+from switched_conv_util import save_attention_to_image
+from switched_conv import compute_attention_specificity, AttentionNorm
+
+
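+# Multiplexer head: upsamples the input features, merges in the skip connection, then
+# collapses them to one selection logit per transform (consumed by the switch below).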
+class ExpandAndCollapse(nn.Module):
+    def __init__(self, nf, nf_out, num_channels):
+        super(ExpandAndCollapse, self).__init__()
+        self.expand = ExpansionBlock(nf, nf_out, block=ConvGnLelu)
+        self.collapse = ConvGnLelu(nf_out, num_channels, norm=False, bias=False, activation=False)
+
+    def forward(self, x, passthrough):
+        x = self.expand(x, passthrough)
+        return self.collapse(x)
+
+
+# Differs from ConfigurableSwitchComputer in that the connections are not residual and the multiplexer is fed directly in.
+class ConfigurableLinearSwitchComputer(nn.Module):
+    def __init__(self, out_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, attention_norm,
+                 init_temp=20, add_scalable_noise_to_transforms=False):
+        super(ConfigurableLinearSwitchComputer, self).__init__()
+
+        self.multiplexer = multiplexer_net
+        self.pre_transform = pre_transform_block
+        self.transforms = nn.ModuleList([transform_block() for _ in range(transform_count)])
+        self.add_noise = add_scalable_noise_to_transforms
+        self.noise_scale = nn.Parameter(torch.full((1,), float(1e-3)))
+
+        # And the switch itself, including learned scalars
+        self.switch = BareConvSwitch(initial_temperature=init_temp, attention_norm=AttentionNorm(transform_count, accumulator_size=16 * transform_count) if attention_norm else None)
+        self.post_switch_conv = ConvBnLelu(out_filters, out_filters, norm=False, bias=True)
+        # The post_switch_conv gets a low scale initially. The network can decide to magnify it (or not)
+        # depending on its needs.
+        self.psc_scale = nn.Parameter(torch.full((1,), float(.1)))
+
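+    # Selection weights come from the multiplexer applied to the un-noised input, while the
+    # transforms operate on the (optionally noised) pre-transformed features.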
+    def forward(self, x, passthrough, output_attention_weights=False, extra_arg=None):
+        identity = x
+        if self.add_noise:
+            rand_feature = torch.randn_like(x) * self.noise_scale
+            x = x + rand_feature
+
+        x = self.pre_transform(x)
+        xformed = [t.forward(x, passthrough) for t in self.transforms]
+        m = self.multiplexer(identity, passthrough)
+
+        outputs, attention = self.switch(xformed, m, True)
+        outputs = self.post_switch_conv(outputs)
+        if output_attention_weights:
+            return outputs, attention
+        else:
+            return outputs
+
+    def set_temperature(self, temp):
+        self.switch.set_attention_temperature(temp)
+
+
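+# Builds one switched upsampling stage: `num_channels` parallel ExpansionBlock transforms whose
+# outputs are blended per-position by the ExpandAndCollapse multiplexer above.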
+def create_switched_upsampler(nf, nf_out, num_channels, initial_temp=10):
+    multiplx = ExpandAndCollapse(nf, nf_out, num_channels)
+    pretransform = ConvGnLelu(nf, nf, norm=True, bias=False)
+    transform_fn = functools.partial(ExpansionBlock, nf, nf_out, block=ConvGnLelu)
+    return ConfigurableLinearSwitchComputer(nf_out, multiplx,
+                                       pre_transform_block=pretransform, transform_block=transform_fn,
+                                       attention_norm=True,
+                                       transform_count=num_channels, init_temp=initial_temp,
+                                       add_scalable_noise_to_transforms=False)
+
+
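+# U-Net style discriminator: a VGG-like downsampling trunk followed by switched upsampling
+# stages, producing a per-pixel prediction at 1/4 of the input resolution.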
+class Discriminator_switched(nn.Module):
+    def __init__(self, in_nc, nf, initial_temp=10, final_temperature_step=50000):
+        super(Discriminator_switched, self).__init__()
+        # [64, 128, 128]
+        self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, activation=False)
+        self.conv0_1 = ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False)
+        # [64, 64, 64]
+        self.conv1_0 = ConvGnLelu(nf, nf * 2, kernel_size=3, bias=False)
+        self.conv1_1 = ConvGnLelu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False)
+        # [128, 32, 32]
+        self.conv2_0 = ConvGnLelu(nf * 2, nf * 4, kernel_size=3, bias=False)
+        self.conv2_1 = ConvGnLelu(nf * 4, nf * 4, kernel_size=3, stride=2, bias=False)
+        # [256, 16, 16]
+        self.conv3_0 = ConvGnLelu(nf * 4, nf * 8, kernel_size=3, bias=False)
+        self.conv3_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
+        # [512, 8, 8]
+        self.conv4_0 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=False)
+        self.conv4_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
+
+        self.exp1 = ExpansionBlock(nf * 8, nf * 8, block=ConvGnLelu)
+        self.upsw2 = create_switched_upsampler(nf * 8, nf * 4, 8)
+        self.upsw3 = create_switched_upsampler(nf * 4, nf * 2, 8)
+        self.switches = [self.upsw2, self.upsw3]
+        self.proc3 = ConvGnLelu(nf * 2, nf * 2, bias=False)
+        self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)
+
+        self.init_temperature = initial_temp
+        self.final_temperature_step = final_temperature_step
+        self.attentions = None
+
+    def forward(self, x, flatten=True):
+        fea0 = self.conv0_0(x)
+        fea0 = self.conv0_1(fea0)
+
+        fea1 = self.conv1_0(fea0)
+        fea1 = self.conv1_1(fea1)
+
+        fea2 = self.conv2_0(fea1)
+        fea2 = self.conv2_1(fea2)
+
+        fea3 = self.conv3_0(fea2)
+        fea3 = self.conv3_1(fea3)
+
+        fea4 = self.conv4_0(fea3)
+        fea4 = self.conv4_1(fea4)
+
+        u1 = self.exp1(fea4, fea3)
+        u2, a1 = self.upsw2(u1, fea2, output_attention_weights=True)
+        u3, a2 = self.upsw3(u2, fea1, output_attention_weights=True)
+        self.attentions = [a1, a2]
+        loss3 = self.collapse3(self.proc3(u3))
+        return loss3.view(-1, 1)
+
+    def pixgan_parameters(self):
+        return 1, 4
+
+    def set_temperature(self, temp):
+        [sw.set_temperature(temp) for sw in self.switches]
+
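+    # Linearly anneals the switch temperature from `initial_temp` down to 1 over
+    # `final_temperature_step` steps, and periodically saves the attention maps for inspection.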
+    def update_for_step(self, step, experiments_path='.'):
+        if self.attentions:
+            for i, sw in enumerate(self.switches):
+                temp_loss_per_step = (self.init_temperature - 1) / self.final_temperature_step
+                sw.set_temperature(min(self.init_temperature,
+                                       max(self.init_temperature - temp_loss_per_step * step, 1)))
+            if step % 50 == 0:
+                [save_attention_to_image(experiments_path, self.attentions[i], 8, step, "disc_a%i" % (i+1,), l_mult=10) for i in range(len(self.attentions))]
+
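+    # Logs the current switch temperature and the attention specificity stats for each switch.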
+    def get_debug_values(self, step):
+        temp = self.switches[0].switch.temperature
+        mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
+        means = [i[0] for i in mean_hists]
+        hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
+        val = {"disc_switch_temperature": temp}
+        for i in range(len(means)):
+            val["disc_switch_%i_specificity" % (i,)] = means[i]
+            val["disc_switch_%i_histogram" % (i,)] = hists[i]
+        return val
+
+
 class Discriminator_UNet_FeaOut(nn.Module):
     def __init__(self, in_nc, nf, feature_mode=False):
         super(Discriminator_UNet_FeaOut, self).__init__()
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 84243f7e..d666cc94 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -124,6 +124,9 @@ def define_D(opt):
         netD = SRGAN_arch.Discriminator_UNet(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
     elif which_model == "discriminator_unet_fea":
         netD = SRGAN_arch.Discriminator_UNet_FeaOut(in_nc=opt_net['in_nc'], nf=opt_net['nf'], feature_mode=opt_net['feature_mode'])
+    elif which_model == "discriminator_switched":
+        netD = SRGAN_arch.Discriminator_switched(in_nc=opt_net['in_nc'], nf=opt_net['nf'], initial_temp=opt_net['initial_temp'],
+                                                    final_temperature_step=opt_net['final_temperature_step'])
     else:
         raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
     return netD
diff --git a/codes/train.py b/codes/train.py
index f879a0af..ca12db0f 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_progressive_srg2.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg2_switched_disc.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
@@ -161,7 +161,7 @@ def main():
         current_step = resume_state['iter']
         model.resume_training(resume_state)  # handle optimizers and schedulers
     else:
-        current_step = 0
+        current_step = -1
         start_epoch = 0
 
     #### training
diff --git a/codes/utils/convert_model.py b/codes/utils/convert_model.py
index 6d31aac6..2f98ba8a 100644
--- a/codes/utils/convert_model.py
+++ b/codes/utils/convert_model.py
@@ -42,8 +42,6 @@ def copy_state_dict(dict_from, dict_to):
 
 if __name__ == "__main__":
     os.chdir("..")
-    torch.backends.cudnn.benchmark = True
-    want_just_images = True
     model_from, opt_from = get_model_for_opt_file("../options/train_imgset_pixgan_progressive_srg2.yml")
     model_to, _ = get_model_for_opt_file("../options/train_imgset_pixgan_progressive_srg2_.yml")