From dbf61475040310645579c1b7953feb21d3a2ae3e Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 22 Jul 2020 20:52:59 -0600 Subject: [PATCH] Add switched discriminator The logic is that the discriminator may be incapable of providing a truly targeted loss for all image regions since it has to be too generic (basically the same argument for the switched generator). So add some switches in! See how it works! --- codes/models/SRGAN_model.py | 16 +- codes/models/archs/discriminator_vgg_arch.py | 149 +++++++++++++++++++ codes/models/networks.py | 3 + codes/train.py | 4 +- codes/utils/convert_model.py | 2 - 5 files changed, 168 insertions(+), 6 deletions(-) diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 60baf25b..a1b2d5d4 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -116,9 +116,16 @@ class SRGANModel(BaseModel): weight_decay=wd_G, betas=(train_opt['beta1_G'], train_opt['beta2_G'])) self.optimizers.append(self.optimizer_G) + optim_params = [] + for k, v in self.netD.named_parameters(): # can optimize for a part of the model + if v.requires_grad: + optim_params.append(v) + else: + if self.rank <= 0: + logger.warning('Params [{:s}] will not optimize.'.format(k)) # D wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0 - self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], + self.optimizer_D = torch.optim.Adam(optim_params, lr=train_opt['lr_D'], weight_decay=wd_D, betas=(train_opt['beta1_D'], train_opt['beta2_D'])) self.optimizers.append(self.optimizer_D) @@ -219,6 +226,8 @@ class SRGANModel(BaseModel): # Some generators have variants depending on the current step. if hasattr(self.netG.module, "update_for_step"): self.netG.module.update_for_step(step, os.path.join(self.opt['path']['models'], "..")) + if hasattr(self.netD.module, "update_for_step"): + self.netD.module.update_for_step(step, os.path.join(self.opt['path']['models'], "..")) # G for p in self.netD.parameters(): @@ -323,7 +332,8 @@ class SRGANModel(BaseModel): # D if self.l_gan_w > 0 and step > self.G_warmup: for p in self.netD.parameters(): - p.requires_grad = True + if p.dtype != torch.int64 and p.dtype != torch.bool: + p.requires_grad = True noise = torch.randn_like(var_ref) * noise_theta noise.to(self.device) @@ -610,6 +620,8 @@ class SRGANModel(BaseModel): # Some generators can do their own metric logging. if hasattr(self.netG.module, "get_debug_values"): return_log.update(self.netG.module.get_debug_values(step)) + if hasattr(self.netD.module, "get_debug_values"): + return_log.update(self.netD.module.get_debug_values(step)) return return_log diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py index 5dc8df23..04c9fc4d 100644 --- a/codes/models/archs/discriminator_vgg_arch.py +++ b/codes/models/archs/discriminator_vgg_arch.py @@ -238,6 +238,155 @@ class Discriminator_UNet(nn.Module): return 3, 4 +import functools +from models.archs.SwitchedResidualGenerator_arch import MultiConvBlock, ConfigurableSwitchComputer, BareConvSwitch +from switched_conv_util import save_attention_to_image +from switched_conv import compute_attention_specificity, AttentionNorm + + +class ExpandAndCollapse(nn.Module): + def __init__(self, nf, nf_out, num_channels): + super(ExpandAndCollapse, self).__init__() + self.expand = ExpansionBlock(nf, nf_out, block=ConvGnLelu) + self.collapse = ConvGnLelu(nf_out, num_channels, norm=False, bias=False, activation=False) + + def forward(self, x, passthrough): + x = self.expand(x, passthrough) + return self.collapse(x) + + +# Differs from ConfigurableSwitchComputer in that the connections are not residual and the multiplexer is fed directly in. +class ConfigurableLinearSwitchComputer(nn.Module): + def __init__(self, out_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, attention_norm, + init_temp=20, add_scalable_noise_to_transforms=False): + super(ConfigurableLinearSwitchComputer, self).__init__() + + self.multiplexer = multiplexer_net + self.pre_transform = pre_transform_block + self.transforms = nn.ModuleList([transform_block() for _ in range(transform_count)]) + self.add_noise = add_scalable_noise_to_transforms + self.noise_scale = nn.Parameter(torch.full((1,), float(1e-3))) + + # And the switch itself, including learned scalars + self.switch = BareConvSwitch(initial_temperature=init_temp, attention_norm=AttentionNorm(transform_count, accumulator_size=16 * transform_count) if attention_norm else None) + self.post_switch_conv = ConvBnLelu(out_filters, out_filters, norm=False, bias=True) + # The post_switch_conv gets a low scale initially. The network can decide to magnify it (or not) + # depending on its needs. + self.psc_scale = nn.Parameter(torch.full((1,), float(.1))) + + def forward(self, x, passthrough, output_attention_weights=False, extra_arg=None): + identity = x + if self.add_noise: + rand_feature = torch.randn_like(x) * self.noise_scale + x = x + rand_feature + + x = self.pre_transform(x) + xformed = [t.forward(x, passthrough) for t in self.transforms] + m = self.multiplexer(identity, passthrough) + + + outputs, attention = self.switch(xformed, m, True) + outputs = self.post_switch_conv(outputs) + if output_attention_weights: + return outputs, attention + else: + return outputs + + def set_temperature(self, temp): + self.switch.set_attention_temperature(temp) + + +def create_switched_upsampler(nf, nf_out, num_channels, initial_temp=10): + multiplx = ExpandAndCollapse(nf, nf_out, num_channels) + pretransform = ConvGnLelu(nf, nf, norm=True, bias=False) + transform_fn = functools.partial(ExpansionBlock, nf, nf_out, block=ConvGnLelu) + return ConfigurableLinearSwitchComputer(nf_out, multiplx, + pre_transform_block=pretransform, transform_block=transform_fn, + attention_norm=True, + transform_count=num_channels, init_temp=initial_temp, + add_scalable_noise_to_transforms=False) + + +class Discriminator_switched(nn.Module): + def __init__(self, in_nc, nf, initial_temp=10, final_temperature_step=50000): + super(Discriminator_switched, self).__init__() + # [64, 128, 128] + self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, activation=False) + self.conv0_1 = ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False) + # [64, 64, 64] + self.conv1_0 = ConvGnLelu(nf, nf * 2, kernel_size=3, bias=False) + self.conv1_1 = ConvGnLelu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False) + # [128, 32, 32] + self.conv2_0 = ConvGnLelu(nf * 2, nf * 4, kernel_size=3, bias=False) + self.conv2_1 = ConvGnLelu(nf * 4, nf * 4, kernel_size=3, stride=2, bias=False) + # [256, 16, 16] + self.conv3_0 = ConvGnLelu(nf * 4, nf * 8, kernel_size=3, bias=False) + self.conv3_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False) + # [512, 8, 8] + self.conv4_0 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=False) + self.conv4_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False) + + self.exp1 = ExpansionBlock(nf * 8, nf * 8, block=ConvGnLelu) + self.upsw2 = create_switched_upsampler(nf * 8, nf * 4, 8) + self.upsw3 = create_switched_upsampler(nf * 4, nf * 2, 8) + self.switches = [self.upsw2, self.upsw3] + self.proc3 = ConvGnLelu(nf * 2, nf * 2, bias=False) + self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False) + + self.init_temperature = initial_temp + self.final_temperature_step = final_temperature_step + self.attentions = None + + def forward(self, x, flatten=True): + fea0 = self.conv0_0(x) + fea0 = self.conv0_1(fea0) + + fea1 = self.conv1_0(fea0) + fea1 = self.conv1_1(fea1) + + fea2 = self.conv2_0(fea1) + fea2 = self.conv2_1(fea2) + + fea3 = self.conv3_0(fea2) + fea3 = self.conv3_1(fea3) + + fea4 = self.conv4_0(fea3) + fea4 = self.conv4_1(fea4) + + u1 = self.exp1(fea4, fea3) + u2, a1 = self.upsw2(u1, fea2, output_attention_weights=True) + u3, a2 = self.upsw3(u2, fea1, output_attention_weights=True) + self.attentions = [a1, a2] + loss3 = self.collapse3(self.proc3(u3)) + return loss3.view(-1, 1) + + def pixgan_parameters(self): + return 1, 4 + + def set_temperature(self, temp): + [sw.set_temperature(temp) for sw in self.switches] + + def update_for_step(self, step, experiments_path='.'): + if self.attentions: + for i, sw in enumerate(self.switches): + temp_loss_per_step = (self.init_temperature - 1) / self.final_temperature_step + sw.set_temperature(min(self.init_temperature, + max(self.init_temperature - temp_loss_per_step * step, 1))) + if step % 50 == 0: + [save_attention_to_image(experiments_path, self.attentions[i], 8, step, "disc_a%i" % (i+1,), l_mult=10) for i in range(len(self.attentions))] + + def get_debug_values(self, step): + temp = self.switches[0].switch.temperature + mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions] + means = [i[0] for i in mean_hists] + hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists] + val = {"disc_switch_temperature": temp} + for i in range(len(means)): + val["disc_switch_%i_specificity" % (i,)] = means[i] + val["disc_switch_%i_histogram" % (i,)] = hists[i] + return val + + class Discriminator_UNet_FeaOut(nn.Module): def __init__(self, in_nc, nf, feature_mode=False): super(Discriminator_UNet_FeaOut, self).__init__() diff --git a/codes/models/networks.py b/codes/models/networks.py index 84243f7e..d666cc94 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -124,6 +124,9 @@ def define_D(opt): netD = SRGAN_arch.Discriminator_UNet(in_nc=opt_net['in_nc'], nf=opt_net['nf']) elif which_model == "discriminator_unet_fea": netD = SRGAN_arch.Discriminator_UNet_FeaOut(in_nc=opt_net['in_nc'], nf=opt_net['nf'], feature_mode=opt_net['feature_mode']) + elif which_model == "discriminator_switched": + netD = SRGAN_arch.Discriminator_switched(in_nc=opt_net['in_nc'], nf=opt_net['nf'], initial_temp=opt_net['initial_temp'], + final_temperature_step=opt_net['final_temperature_step']) else: raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) return netD diff --git a/codes/train.py b/codes/train.py index f879a0af..ca12db0f 100644 --- a/codes/train.py +++ b/codes/train.py @@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_progressive_srg2.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg2_switched_disc.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) @@ -161,7 +161,7 @@ def main(): current_step = resume_state['iter'] model.resume_training(resume_state) # handle optimizers and schedulers else: - current_step = 0 + current_step = -1 start_epoch = 0 #### training diff --git a/codes/utils/convert_model.py b/codes/utils/convert_model.py index 6d31aac6..2f98ba8a 100644 --- a/codes/utils/convert_model.py +++ b/codes/utils/convert_model.py @@ -42,8 +42,6 @@ def copy_state_dict(dict_from, dict_to): if __name__ == "__main__": os.chdir("..") - torch.backends.cudnn.benchmark = True - want_just_images = True model_from, opt_from = get_model_for_opt_file("../options/train_imgset_pixgan_progressive_srg2.yml") model_to, _ = get_model_for_opt_file("../options/train_imgset_pixgan_progressive_srg2_.yml")