From 1b1431133bf397ee9a2b21373995dcdd26eb7533 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 14 Jul 2020 09:28:24 -0600 Subject: [PATCH] Add DualOutputSRG Also removes the old multi-return mechanism that Generators support. Also fixes AttentionNorm. --- codes/models/SRGAN_model.py | 93 ++++++++------- .../archs/SwitchedResidualGenerator_arch.py | 107 +++++++++++++++++- codes/models/archs/discriminator_vgg_arch.py | 3 - codes/models/networks.py | 9 ++ codes/train.py | 2 +- codes/utils/numeric_stability.py | 2 +- 6 files changed, 161 insertions(+), 55 deletions(-) diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index f95bbfdb..a47d007d 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -70,9 +70,11 @@ class SRGANModel(BaseModel): else: raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type)) self.l_fea_w = train_opt['feature_weight'] - self.l_fea_w_decay = train_opt['feature_weight_decay'] + self.l_fea_w_decay_start = train_opt['feature_weight_decay_start'] self.l_fea_w_decay_steps = train_opt['feature_weight_decay_steps'] self.l_fea_w_minimum = train_opt['feature_weight_minimum'] + if self.l_fea_w_decay_start: + self.l_fea_w_decay_step_size = (self.l_fea_w - self.l_fea_w_minimum) / (self.l_fea_w_decay_steps) else: logger.info('Remove feature loss.') self.cri_fea = None @@ -202,16 +204,17 @@ class SRGANModel(BaseModel): for p in self.netD.parameters(): p.requires_grad = False - if step > self.D_init_iters: + if step >= self.D_init_iters: self.optimizer_G.zero_grad() self.swapout_D(step) self.swapout_G(step) # Turning off G-grad is required to enable mega-batching and D_update_ratio to work together for some reason. - if step % self.D_update_ratio == 0 and step > self.D_init_iters: + if step % self.D_update_ratio == 0 and step >= self.D_init_iters: for p in self.netG.parameters(): - p.requires_grad = True + if p.dtype != torch.int64 and p.dtype != torch.bool: + p.requires_grad = True else: for p in self.netG.parameters(): p.requires_grad = False @@ -227,35 +230,28 @@ class SRGANModel(BaseModel): _t = time() self.fake_GenOut = [] + self.fea_GenOut = [] self.fake_H = [] var_ref_skips = [] for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix): - fake_GenOut = self.netG(var_L) + fea_GenOut, fake_GenOut = self.netG(var_L) if _profile: print("Gen forward %f" % (time() - _t,)) _t = time() - # Extract the image output. For generators that output skip-through connections, the master output is always - # the first element of the tuple. - if isinstance(fake_GenOut, tuple): - gen_img = fake_GenOut[0] - # The following line detaches all generator outputs that are not None. - self.fake_GenOut.append(tuple([(x.detach() if x is not None else None) for x in list(fake_GenOut)])) - var_ref = (var_ref,) # This is a tuple for legacy reasons. - else: - gen_img = fake_GenOut - self.fake_GenOut.append(fake_GenOut.detach()) + self.fake_GenOut.append(fake_GenOut.detach()) + self.fea_GenOut.append(fea_GenOut.detach()) l_g_total = 0 - if step % self.D_update_ratio == 0 and step > self.D_init_iters: + if step % self.D_update_ratio == 0 and step >= self.D_init_iters: if self.cri_pix: # pixel loss - l_g_pix = self.l_pix_w * self.cri_pix(gen_img, pix) + l_g_pix = self.l_pix_w * self.cri_pix(fea_GenOut, pix) l_g_pix_log = l_g_pix / self.l_pix_w l_g_total += l_g_pix if self.cri_fea: # feature loss real_fea = self.netF(pix).detach() - fake_fea = self.netF(gen_img) + fake_fea = self.netF(fea_GenOut) l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea) l_g_fea_log = l_g_fea / self.l_fea_w l_g_total += l_g_fea @@ -266,8 +262,13 @@ class SRGANModel(BaseModel): # Decay the influence of the feature loss. As the model trains, the GAN will play a stronger role # in the resultant image. - if step % self.l_fea_w_decay_steps == 0: - self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay) + if self.l_fea_w_decay_start and step > self.l_fea_w_decay_start: + self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w - self.l_fea_w_decay_step_size * (step - self.l_fea_w_decay_start)) + + # Note to future self: The BCELoss(0, 1) and BCELoss(0, 0) = .6931 + # Effectively this means that the generator has only completely "won" when l_d_real and l_d_fake is + # equal to this value. If I ever come up with an algorithm that tunes fea/gan weights automatically, + # it should target this value. if self.l_gan_w > 0: if self.opt['train']['gan_type'] == 'gan' or self.opt['train']['gan_type'] == 'pixgan': @@ -304,7 +305,7 @@ class SRGANModel(BaseModel): for p in self.netD.parameters(): p.requires_grad = True - noise = torch.randn_like(var_ref[0]) * noise_theta + noise = torch.randn_like(var_ref) * noise_theta noise.to(self.device) self.optimizer_D.zero_grad() real_disc_images = [] @@ -312,17 +313,17 @@ class SRGANModel(BaseModel): for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix): # Re-compute generator outputs (post-update). with torch.no_grad(): - fake_H = self.netG(var_L) + _, fake_H = self.netG(var_L) # The following line detaches all generator outputs that are not None. - fake_H = tuple([(x.detach() if x is not None else None) for x in list(fake_H)]) + fake_H = fake_H.detach() if _profile: print("Gen forward for disc %f" % (time() - _t,)) _t = time() # Apply noise to the inputs to slow discriminator convergence. - var_ref = (var_ref + noise,) - fake_H = (fake_H[0] + noise,) + fake_H[1:] + var_ref = var_ref + noise + fake_H = fake_H + noise if self.opt['train']['gan_type'] == 'gan': # need to forward and backward separately, since batch norm statistics differ # real @@ -340,10 +341,10 @@ class SRGANModel(BaseModel): if self.opt['train']['gan_type'] == 'pixgan': # randomly determine portions of the image to swap to keep the discriminator honest. pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters() - disc_output_shape = (var_ref[0].shape[0], pixdisc_channels, var_ref[0].shape[2] // pixdisc_output_reduction, var_ref[0].shape[3] // pixdisc_output_reduction) - b, _, w, h = var_ref[0].shape - real = torch.ones((b, pixdisc_channels, w, h), device=var_ref[0].device) - fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref[0].device) + disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction) + b, _, w, h = var_ref.shape + real = torch.ones((b, pixdisc_channels, w, h), device=var_ref.device) + fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref.device) SWAP_MAX_DIM = w // 4 SWAP_MIN_DIM = 16 assert SWAP_MAX_DIM > 0 @@ -360,9 +361,9 @@ class SRGANModel(BaseModel): swap_w = w - swap_x if swap_y + swap_h > h: swap_h = h - swap_y - t = fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone() - fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] - var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t + t = fake_H[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone() + fake_H[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] + var_ref[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t real[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 0.0 fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0 @@ -422,8 +423,8 @@ class SRGANModel(BaseModel): _t = time() # Append var_ref here, so that we can inspect the alterations the disc made if pixgan - var_ref_skips.append(var_ref[0].detach()) - self.fake_H.append(fake_H[0].detach()) + var_ref_skips.append(var_ref.detach()) + self.fake_H.append(fake_H.detach()) self.optimizer_D.step() @@ -436,32 +437,28 @@ class SRGANModel(BaseModel): sample_save_path = os.path.join(self.opt['path']['models'], "..", "temp") os.makedirs(os.path.join(sample_save_path, "hr"), exist_ok=True) os.makedirs(os.path.join(sample_save_path, "lr"), exist_ok=True) + os.makedirs(os.path.join(sample_save_path, "gen_fea"), exist_ok=True) os.makedirs(os.path.join(sample_save_path, "gen"), exist_ok=True) os.makedirs(os.path.join(sample_save_path, "disc_fake"), exist_ok=True) os.makedirs(os.path.join(sample_save_path, "pix"), exist_ok=True) os.makedirs(os.path.join(sample_save_path, "disc"), exist_ok=True) - multi_gen = False - if isinstance(self.fake_GenOut[0], tuple): - os.makedirs(os.path.join(sample_save_path, "ref"), exist_ok=True) - multi_gen = True + os.makedirs(os.path.join(sample_save_path, "ref"), exist_ok=True) # fed_LQ is not chunked. for i in range(self.mega_batch_factor): utils.save_image(self.var_H[i].cpu(), os.path.join(sample_save_path, "hr", "%05i_%02i.png" % (step, i))) utils.save_image(self.var_L[i].cpu(), os.path.join(sample_save_path, "lr", "%05i_%02i.png" % (step, i))) utils.save_image(self.pix[i].cpu(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i))) - if multi_gen: - utils.save_image(self.fake_GenOut[i][0].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i))) - if self.l_gan_w > 0 and step > self.G_warmup and self.opt['train']['gan_type'] == 'pixgan': - utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i))) - utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "fake%05i_%02i.png" % (step, i))) - utils.save_image(F.interpolate(fake_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "fake%05i_%02i.png" % (step, i))) - utils.save_image(F.interpolate(real_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "real%05i_%02i.png" % (step, i))) - else: - utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i))) + utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i))) + utils.save_image(self.fea_GenOut[i].cpu(), os.path.join(sample_save_path, "gen_fea", "%05i_%02i.png" % (step, i))) + if self.l_gan_w > 0 and step > self.G_warmup and self.opt['train']['gan_type'] == 'pixgan': + utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i))) + utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "fake%05i_%02i.png" % (step, i))) + utils.save_image(F.interpolate(fake_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "fake%05i_%02i.png" % (step, i))) + utils.save_image(F.interpolate(real_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "real%05i_%02i.png" % (step, i))) # Log metrics - if step % self.D_update_ratio == 0 and step > self.D_init_iters: + if step % self.D_update_ratio == 0 and step >= self.D_init_iters: if self.cri_pix: self.add_log_entry('l_g_pix', l_g_pix_log.item()) if self.cri_fea: diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index 0cde385e..e42ec2a0 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -196,7 +196,8 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module): if self.upsample_factor > 2: x = F.interpolate(x, scale_factor=2, mode="nearest") x = self.upconv2(x) - return self.final_conv(self.hr_conv(x)), + x = self.final_conv(self.hr_conv(x)) + return x, x def set_temperature(self, temp): [sw.set_temperature(temp) for sw in self.switches] @@ -318,4 +319,106 @@ class ConfigurableSwitchedResidualGenerator3(nn.Module): for i in range(len(means)): val["switch_%i_specificity" % (i,)] = means[i] val["switch_%i_histogram" % (i,)] = hists[i] - return val \ No newline at end of file + return val + + +class DualOutputSRG(nn.Module): + def __init__(self, switch_depth, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, + trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1, + heightened_final_step=50000, upsample_factor=1, + add_scalable_noise_to_transforms=False): + super(DualOutputSRG, self).__init__() + switches = [] + self.initial_conv = ConvBnLelu(3, transformation_filters, norm=False, activation=False, bias=True) + + self.fea_upconv1 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True) + self.fea_upconv2 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True) + self.fea_hr_conv = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True) + self.fea_final_conv = ConvBnLelu(transformation_filters, 3, norm=False, activation=False, bias=True) + + self.upconv1 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True) + self.upconv2 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True) + self.hr_conv = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True) + self.final_conv = ConvBnLelu(transformation_filters, 3, norm=False, activation=False, bias=True) + + for _ in range(switch_depth): + multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions, switch_processing_layers, trans_counts) + pretransform_fn = functools.partial(ConvBnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1) + transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5), transformation_filters, kernel_size=trans_kernel_sizes, depth=trans_layers, weight_init_factor=.1) + switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn, + pre_transform_block=pretransform_fn, transform_block=transform_fn, + transform_count=trans_counts, init_temp=initial_temp, + add_scalable_noise_to_transforms=add_scalable_noise_to_transforms)) + + self.switches = nn.ModuleList(switches) + self.transformation_counts = trans_counts + self.init_temperature = initial_temp + self.final_temperature_step = final_temperature_step + self.heightened_temp_min = heightened_temp_min + self.heightened_final_step = heightened_final_step + self.attentions = None + self.upsample_factor = upsample_factor + assert self.upsample_factor == 2 or self.upsample_factor == 4 + + def forward(self, x): + x = self.initial_conv(x) + + self.attentions = [] + for i, sw in enumerate(self.switches): + x, att = sw.forward(x, True) + self.attentions.append(att) + + if i == len(self.switches)-2: + fea = self.fea_upconv1(F.interpolate(x, scale_factor=2, mode="nearest")) + if self.upsample_factor > 2: + fea = F.interpolate(fea, scale_factor=2, mode="nearest") + fea = self.fea_upconv2(fea) + fea = self.fea_final_conv(self.hr_conv(fea)) + + x = self.upconv1(F.interpolate(x, scale_factor=2, mode="nearest")) + if self.upsample_factor > 2: + x = F.interpolate(x, scale_factor=2, mode="nearest") + x = self.upconv2(x) + return fea, self.final_conv(self.hr_conv(x)) + + def set_temperature(self, temp): + [sw.set_temperature(temp) for sw in self.switches] + + def update_for_step(self, step, experiments_path='.'): + if self.attentions: + temp = max(1, int(self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step)) + if temp == 1 and self.heightened_final_step and self.heightened_final_step != 1: + # Once the temperature passes (1) it enters an inverted curve to match the linear curve from above. + # without this, the attention specificity "spikes" incredibly fast in the last few iterations. + h_steps_total = self.heightened_final_step - self.final_temperature_step + h_steps_current = max(min(step - self.final_temperature_step, h_steps_total), 1) + # The "gap" will represent the steps that need to be traveled as a linear function. + h_gap = 1 / self.heightened_temp_min + temp = h_gap * h_steps_current / h_steps_total + # Invert temperature to represent reality on this side of the curve + temp = 1 / temp + self.set_temperature(temp) + if step % 50 == 0: + [save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts, step, "a%i" % (i+1,)) for i in range(len(self.switches))] + + def get_debug_values(self, step): + temp = self.switches[0].switch.temperature + mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions] + means = [i[0] for i in mean_hists] + hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists] + val = {"switch_temperature": temp} + for i in range(len(means)): + val["switch_%i_specificity" % (i,)] = means[i] + val["switch_%i_histogram" % (i,)] = hists[i] + return val + + + def load_state_dict(self, state_dict, strict=True): + # Support backwards compatibility where accumulator_index and accumulator_filled are not in this state_dict + t_state = self.state_dict() + if 'switches.0.switch.attention_norm.accumulator_index' not in state_dict.keys(): + for i in range(4): + state_dict['switches.%i.switch.attention_norm.accumulator' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator' % (i,)] + state_dict['switches.%i.switch.attention_norm.accumulator_index' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator_index' % (i,)] + state_dict['switches.%i.switch.attention_norm.accumulator_filled' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator_filled' % (i,)] + super(DualOutputSRG, self).load_state_dict(state_dict, strict) \ No newline at end of file diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py index 559c2f16..e20f7e9f 100644 --- a/codes/models/archs/discriminator_vgg_arch.py +++ b/codes/models/archs/discriminator_vgg_arch.py @@ -51,7 +51,6 @@ class Discriminator_VGG_128(nn.Module): self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) def forward(self, x): - x = x[0] fea = self.lrelu(self.conv0_0(x)) fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) @@ -127,7 +126,6 @@ class Discriminator_VGG_PixLoss(nn.Module): self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) def forward(self, x, flatten=True): - x = x[0] fea0 = self.lrelu(self.conv0_0(x)) fea0 = self.lrelu(self.bn0_1(self.conv0_1(fea0))) @@ -205,7 +203,6 @@ class Discriminator_UNet(nn.Module): self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False) def forward(self, x, flatten=True): - x = x[0] fea0 = self.conv0_0(x) fea0 = self.conv0_1(fea0) diff --git a/codes/models/networks.py b/codes/models/networks.py index 7507add2..2dfe0af0 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -78,6 +78,15 @@ def define_G(opt, net_key='network_G'): initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'], heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'], upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise']) + elif which_model == "DualOutputSRG": + netG = SwitchedGen_arch.DualOutputSRG(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'], + switch_reductions=opt_net['switch_reductions'], + switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'], + trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'], + transformation_filters=opt_net['transformation_filters'], + initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'], + heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'], + upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise']) # image corruption elif which_model == 'HighToLowResNet': diff --git a/codes/train.py b/codes/train.py index 4744ce26..65395714 100644 --- a/codes/train.py +++ b/codes/train.py @@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg2.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_dual_srg.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) diff --git a/codes/utils/numeric_stability.py b/codes/utils/numeric_stability.py index bfcedc70..dda9f79b 100644 --- a/codes/utils/numeric_stability.py +++ b/codes/utils/numeric_stability.py @@ -93,7 +93,7 @@ if __name__ == "__main__": torch.randn(1, 3, 64, 64), device='cuda') ''' - test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator2, + test_stability(functools.partial(srg.DualOutputSRG, switch_depth=4, switch_filters=64, switch_reductions=4,