diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index 60baf25b..a1b2d5d4 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -116,9 +116,16 @@ class SRGANModel(BaseModel):
                                                 weight_decay=wd_G,
                                                 betas=(train_opt['beta1_G'], train_opt['beta2_G']))
             self.optimizers.append(self.optimizer_G)
+            optim_params = []
+            for k, v in self.netD.named_parameters():  # can optimize for a part of the model
+                if v.requires_grad:
+                    optim_params.append(v)
+                else:
+                    if self.rank <= 0:
+                        logger.warning('Params [{:s}] will not optimize.'.format(k))
 
             # D
             wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
-            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'],
+            self.optimizer_D = torch.optim.Adam(optim_params, lr=train_opt['lr_D'],
                                                 weight_decay=wd_D,
                                                 betas=(train_opt['beta1_D'], train_opt['beta2_D']))
             self.optimizers.append(self.optimizer_D)
@@ -219,6 +226,8 @@ class SRGANModel(BaseModel):
         # Some generators have variants depending on the current step.
         if hasattr(self.netG.module, "update_for_step"):
             self.netG.module.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))
+        if hasattr(self.netD.module, "update_for_step"):
+            self.netD.module.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))
 
         # G
         for p in self.netD.parameters():
@@ -323,7 +332,8 @@ class SRGANModel(BaseModel):
         # D
         if self.l_gan_w > 0 and step > self.G_warmup:
             for p in self.netD.parameters():
-                p.requires_grad = True
+                if p.dtype != torch.int64 and p.dtype != torch.bool:
+                    p.requires_grad = True
 
             noise = torch.randn_like(var_ref) * noise_theta
             noise.to(self.device)
@@ -610,6 +620,8 @@ class SRGANModel(BaseModel):
         # Some generators can do their own metric logging.
         if hasattr(self.netG.module, "get_debug_values"):
             return_log.update(self.netG.module.get_debug_values(step))
+        if hasattr(self.netD.module, "get_debug_values"):
+            return_log.update(self.netD.module.get_debug_values(step))
 
         return return_log
 
diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py
index 5dc8df23..04c9fc4d 100644
--- a/codes/models/archs/discriminator_vgg_arch.py
+++ b/codes/models/archs/discriminator_vgg_arch.py
@@ -238,6 +238,155 @@ class Discriminator_UNet(nn.Module):
         return 3, 4
 
 
+import functools
+from models.archs.SwitchedResidualGenerator_arch import MultiConvBlock, ConfigurableSwitchComputer, BareConvSwitch
+from switched_conv_util import save_attention_to_image
+from switched_conv import compute_attention_specificity, AttentionNorm
+
+
+class ExpandAndCollapse(nn.Module):
+    def __init__(self, nf, nf_out, num_channels):
+        super(ExpandAndCollapse, self).__init__()
+        self.expand = ExpansionBlock(nf, nf_out, block=ConvGnLelu)
+        self.collapse = ConvGnLelu(nf_out, num_channels, norm=False, bias=False, activation=False)
+
+    def forward(self, x, passthrough):
+        x = self.expand(x, passthrough)
+        return self.collapse(x)
+
+
+# Differs from ConfigurableSwitchComputer in that the connections are not residual and the multiplexer is fed directly in.
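+# The input is run through pre_transform and then through every transform block in parallel; the multiplexer sees the
+# untransformed input together with the passthrough features and emits the attention map that BareConvSwitch uses to
+# blend the transform outputs.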
+class ConfigurableLinearSwitchComputer(nn.Module):
+    def __init__(self, out_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, attention_norm,
+                 init_temp=20, add_scalable_noise_to_transforms=False):
+        super(ConfigurableLinearSwitchComputer, self).__init__()
+
+        self.multiplexer = multiplexer_net
+        self.pre_transform = pre_transform_block
+        self.transforms = nn.ModuleList([transform_block() for _ in range(transform_count)])
+        self.add_noise = add_scalable_noise_to_transforms
+        self.noise_scale = nn.Parameter(torch.full((1,), float(1e-3)))
+
+        # And the switch itself, including learned scalars
+        self.switch = BareConvSwitch(initial_temperature=init_temp, attention_norm=AttentionNorm(transform_count, accumulator_size=16 * transform_count) if attention_norm else None)
+        self.post_switch_conv = ConvBnLelu(out_filters, out_filters, norm=False, bias=True)
+        # The post_switch_conv gets a low scale initially. The network can decide to magnify it (or not)
+        # depending on its needs.
+        self.psc_scale = nn.Parameter(torch.full((1,), float(.1)))
+
+    def forward(self, x, passthrough, output_attention_weights=False, extra_arg=None):
+        identity = x
+        if self.add_noise:
+            rand_feature = torch.randn_like(x) * self.noise_scale
+            x = x + rand_feature
+
+        x = self.pre_transform(x)
+        xformed = [t.forward(x, passthrough) for t in self.transforms]
+        m = self.multiplexer(identity, passthrough)
+
+        outputs, attention = self.switch(xformed, m, True)
+        outputs = self.post_switch_conv(outputs)
+        if output_attention_weights:
+            return outputs, attention
+        else:
+            return outputs
+
+    def set_temperature(self, temp):
+        self.switch.set_attention_temperature(temp)
+
+
+def create_switched_upsampler(nf, nf_out, num_channels, initial_temp=10):
+    multiplx = ExpandAndCollapse(nf, nf_out, num_channels)
+    pretransform = ConvGnLelu(nf, nf, norm=True, bias=False)
+    transform_fn = functools.partial(ExpansionBlock, nf, nf_out, block=ConvGnLelu)
+    return ConfigurableLinearSwitchComputer(nf_out, multiplx,
+                                            pre_transform_block=pretransform, transform_block=transform_fn,
+                                            attention_norm=True,
+                                            transform_count=num_channels, init_temp=initial_temp,
+                                            add_scalable_noise_to_transforms=False)
+
+
+class Discriminator_switched(nn.Module):
+    def __init__(self, in_nc, nf, initial_temp=10, final_temperature_step=50000):
+        super(Discriminator_switched, self).__init__()
+        # [64, 128, 128]
+        self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, activation=False)
+        self.conv0_1 = ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False)
+        # [64, 64, 64]
+        self.conv1_0 = ConvGnLelu(nf, nf * 2, kernel_size=3, bias=False)
+        self.conv1_1 = ConvGnLelu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False)
+        # [128, 32, 32]
+        self.conv2_0 = ConvGnLelu(nf * 2, nf * 4, kernel_size=3, bias=False)
+        self.conv2_1 = ConvGnLelu(nf * 4, nf * 4, kernel_size=3, stride=2, bias=False)
+        # [256, 16, 16]
+        self.conv3_0 = ConvGnLelu(nf * 4, nf * 8, kernel_size=3, bias=False)
+        self.conv3_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
+        # [512, 8, 8]
+        self.conv4_0 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=False)
+        self.conv4_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
+
+        self.exp1 = ExpansionBlock(nf * 8, nf * 8, block=ConvGnLelu)
+        self.upsw2 = create_switched_upsampler(nf * 8, nf * 4, 8)
+        self.upsw3 = create_switched_upsampler(nf * 4, nf * 2, 8)
+        self.switches = [self.upsw2, self.upsw3]
+        self.proc3 = ConvGnLelu(nf * 2, nf * 2, bias=False)
+        self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)
+
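+        # Attention temperature annealing state: update_for_step() decays the switch softmax temperature
+        # from initial_temp down to 1 over final_temperature_step training steps.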
+        self.init_temperature = initial_temp
+        self.final_temperature_step = final_temperature_step
+        self.attentions = None
+
+    def forward(self, x, flatten=True):
+        fea0 = self.conv0_0(x)
+        fea0 = self.conv0_1(fea0)
+
+        fea1 = self.conv1_0(fea0)
+        fea1 = self.conv1_1(fea1)
+
+        fea2 = self.conv2_0(fea1)
+        fea2 = self.conv2_1(fea2)
+
+        fea3 = self.conv3_0(fea2)
+        fea3 = self.conv3_1(fea3)
+
+        fea4 = self.conv4_0(fea3)
+        fea4 = self.conv4_1(fea4)
+
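+        # Decode back up with U-Net style skip connections; the two switched upsamplers also record their
+        # attention maps for logging. collapse3 reduces to a single channel of per-pixel predictions.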
+        u1 = self.exp1(fea4, fea3)
+        u2, a1 = self.upsw2(u1, fea2, output_attention_weights=True)
+        u3, a2 = self.upsw3(u2, fea1, output_attention_weights=True)
+        self.attentions = [a1, a2]
+        loss3 = self.collapse3(self.proc3(u3))
+        return loss3.view(-1, 1)
+
+    def pixgan_parameters(self):
+        return 1, 4
+
+    def set_temperature(self, temp):
+        [sw.set_temperature(temp) for sw in self.switches]
+
+    def update_for_step(self, step, experiments_path='.'):
+        if self.attentions:
+            for i, sw in enumerate(self.switches):
+                temp_loss_per_step = (self.init_temperature - 1) / self.final_temperature_step
+                sw.set_temperature(min(self.init_temperature,
+                                       max(self.init_temperature - temp_loss_per_step * step, 1)))
+            if step % 50 == 0:
+                [save_attention_to_image(experiments_path, self.attentions[i], 8, step, "disc_a%i" % (i+1,), l_mult=10) for i in range(len(self.attentions))]
+
+    def get_debug_values(self, step):
+        temp = self.switches[0].switch.temperature
+        mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
+        means = [i[0] for i in mean_hists]
+        hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
+        val = {"disc_switch_temperature": temp}
+        for i in range(len(means)):
+            val["disc_switch_%i_specificity" % (i,)] = means[i]
+            val["disc_switch_%i_histogram" % (i,)] = hists[i]
+        return val
+
+
 class Discriminator_UNet_FeaOut(nn.Module):
     def __init__(self, in_nc, nf, feature_mode=False):
         super(Discriminator_UNet_FeaOut, self).__init__()
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 84243f7e..d666cc94 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -124,6 +124,9 @@ def define_D(opt):
         netD = SRGAN_arch.Discriminator_UNet(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
     elif which_model == "discriminator_unet_fea":
         netD = SRGAN_arch.Discriminator_UNet_FeaOut(in_nc=opt_net['in_nc'], nf=opt_net['nf'], feature_mode=opt_net['feature_mode'])
+    elif which_model == "discriminator_switched":
+        netD = SRGAN_arch.Discriminator_switched(in_nc=opt_net['in_nc'], nf=opt_net['nf'], initial_temp=opt_net['initial_temp'],
+                                                 final_temperature_step=opt_net['final_temperature_step'])
     else:
         raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
     return netD
diff --git a/codes/train.py b/codes/train.py
index f879a0af..ca12db0f 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_progressive_srg2.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg2_switched_disc.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
@@ -161,7 +161,7 @@ def main():
         current_step = resume_state['iter']
         model.resume_training(resume_state)  # handle optimizers and schedulers
     else:
-        current_step = 0
+        current_step = -1
         start_epoch = 0
 
     #### training
diff --git a/codes/utils/convert_model.py b/codes/utils/convert_model.py
index 6d31aac6..2f98ba8a 100644
--- a/codes/utils/convert_model.py
+++ b/codes/utils/convert_model.py
@@ -42,8 +42,6 @@ def copy_state_dict(dict_from, dict_to):
 
 if __name__ == "__main__":
     os.chdir("..")
-    torch.backends.cudnn.benchmark = True
-    want_just_images = True
     model_from, opt_from = get_model_for_opt_file("../options/train_imgset_pixgan_progressive_srg2.yml")
     model_to, _ = get_model_for_opt_file("../options/train_imgset_pixgan_progressive_srg2_.yml")
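Note: the new "discriminator_switched" branch in define_D pulls four values out of the discriminator's
network options. A minimal sketch of the matching option block follows (hypothetical; only the key names
in_nc, nf, initial_temp and final_temperature_step come from the patch, while network_D and which_model_D
are assumed to be the standard section/key that define_D dispatches on):

    network_D:
      which_model_D: discriminator_switched
      in_nc: 3                        # input image channels
      nf: 64                          # base filter count
      initial_temp: 10                # starting switch softmax temperature
      final_temperature_step: 50000   # step at which the temperature reaches 1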