diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index ad151c88..b97460a8 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -214,10 +214,8 @@ class SRGANModel(BaseModel):
             # the first element of the tuple.
             if isinstance(fake_GenOut, tuple):
                 gen_img = fake_GenOut[0]
-                # TODO: Fix this.
-                self.fake_GenOut.append((fake_GenOut[0].detach(),
-                                         fake_GenOut[1].detach(),
-                                         fake_GenOut[2].detach()))
+                # The following line detaches all generator outputs that are not None.
+                self.fake_GenOut.append(tuple([(x.detach() if x is not None else None) for x in list(fake_GenOut)]))
                 var_ref = (var_ref,) + self.create_artificial_skips(var_H)
             else:
                 gen_img = fake_GenOut
@@ -269,7 +267,8 @@ class SRGANModel(BaseModel):
             # Re-compute generator outputs (post-update).
             with torch.no_grad():
                 fake_H = self.netG(var_L)
-                fake_H = (fake_H[0].detach(), fake_H[1].detach(), fake_H[2].detach())
+                # The following line detaches all generator outputs that are not None.
+                fake_H = tuple([(x.detach() if x is not None else None) for x in list(fake_H)])
 
             # Apply noise to the inputs to slow discriminator convergence.
             var_ref = (var_ref[0] + noise,) + var_ref[1:]
@@ -306,35 +305,38 @@ class SRGANModel(BaseModel):
 
         # Log sample images from first microbatch.
         if step % 50 == 0:
-            os.makedirs("temp/hr", exist_ok=True)
-            os.makedirs("temp/lr", exist_ok=True)
-            os.makedirs("temp/lr_precorrupt", exist_ok=True)
-            os.makedirs("temp/gen", exist_ok=True)
-            os.makedirs("temp/pix", exist_ok=True)
+            sample_save_path = os.path.join(self.opt['path']['models'], "..", "temp")
+            os.makedirs(os.path.join(sample_save_path, "hr"), exist_ok=True)
+            os.makedirs(os.path.join(sample_save_path, "lr"), exist_ok=True)
+            os.makedirs(os.path.join(sample_save_path, "lr_precorrupt"), exist_ok=True)
+            os.makedirs(os.path.join(sample_save_path, "gen"), exist_ok=True)
+            os.makedirs(os.path.join(sample_save_path, "pix"), exist_ok=True)
             multi_gen = False
             if isinstance(self.fake_GenOut[0], tuple):
-                os.makedirs("temp/genlr", exist_ok=True)
-                os.makedirs("temp/genmr", exist_ok=True)
-                os.makedirs("temp/ref", exist_ok=True)
+                os.makedirs(os.path.join(sample_save_path, "genlr"), exist_ok=True)
+                os.makedirs(os.path.join(sample_save_path, "genmr"), exist_ok=True)
+                os.makedirs(os.path.join(sample_save_path, "ref"), exist_ok=True)
                 multi_gen = True
 
             # fed_LQ is not chunked.
-            utils.save_image(self.fed_LQ.cpu().detach(), os.path.join("temp/lr_precorrupt", "%05i.png" % (step,)))
+            utils.save_image(self.fed_LQ.cpu().detach(), os.path.join(sample_save_path, "lr_precorrupt", "%05i.png" % (step,)))
             for i in range(self.mega_batch_factor):
-                utils.save_image(self.var_H[i].cpu().detach(), os.path.join("temp/hr", "%05i_%02i.png" % (step, i)))
-                utils.save_image(self.var_L[i].cpu().detach(), os.path.join("temp/lr", "%05i_%02i.png" % (step, i)))
-                utils.save_image(self.pix[i].cpu().detach(), os.path.join("temp/pix", "%05i_%02i.png" % (step, i)))
+                utils.save_image(self.var_H[i].cpu().detach(), os.path.join(sample_save_path, "hr", "%05i_%02i.png" % (step, i)))
+                utils.save_image(self.var_L[i].cpu().detach(), os.path.join(sample_save_path, "lr", "%05i_%02i.png" % (step, i)))
+                utils.save_image(self.pix[i].cpu().detach(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i)))
                 if multi_gen:
-                    utils.save_image(self.fake_GenOut[i][0].cpu().detach(), os.path.join("temp/gen", "%05i_%02i.png" % (step, i)))
-                    utils.save_image(self.fake_GenOut[i][1].cpu().detach(), os.path.join("temp/genmr", "%05i_%02i.png" % (step, i)))
-                    utils.save_image(self.fake_GenOut[i][2].cpu().detach(), os.path.join("temp/genlr", "%05i_%02i.png" % (step, i)))
-                    utils.save_image(var_ref_skips[i][0].cpu().detach(), os.path.join("temp/ref", "hi_%05i_%02i.png" % (step, i)))
-                    utils.save_image(var_ref_skips[i][1].cpu().detach(), os.path.join("temp/ref", "med_%05i_%02i.png" % (step, i)))
-                    utils.save_image(var_ref_skips[i][2].cpu().detach(), os.path.join("temp/ref", "low_%05i_%02i.png" % (step, i)))
+                    utils.save_image(self.fake_GenOut[i][0].cpu().detach(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
+                    if self.fake_GenOut[i][1] is not None:
+                        utils.save_image(self.fake_GenOut[i][1].cpu().detach(), os.path.join(sample_save_path, "genmr", "%05i_%02i.png" % (step, i)))
+                    if self.fake_GenOut[i][2] is not None:
+                        utils.save_image(self.fake_GenOut[i][2].cpu().detach(), os.path.join(sample_save_path, "genlr", "%05i_%02i.png" % (step, i)))
+                    utils.save_image(var_ref_skips[i][0].cpu().detach(), os.path.join(sample_save_path, "ref", "hi_%05i_%02i.png" % (step, i)))
+                    utils.save_image(var_ref_skips[i][1].cpu().detach(), os.path.join(sample_save_path, "ref", "med_%05i_%02i.png" % (step, i)))
+                    utils.save_image(var_ref_skips[i][2].cpu().detach(), os.path.join(sample_save_path, "ref", "low_%05i_%02i.png" % (step, i)))
                 else:
-                    utils.save_image(self.fake_GenOut[i].cpu().detach(), os.path.join("temp/gen", "%05i_%02i.png" % (step, i)))
+                    utils.save_image(self.fake_GenOut[i].cpu().detach(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
 
-        # set log TODO(handle mega-batches?)
+        # Log metrics
         if step % self.D_update_ratio == 0 and step > self.D_init_iters:
             if self.cri_pix:
                 self.add_log_entry('l_g_pix', l_g_pix.item())
@@ -346,6 +348,7 @@ class SRGANModel(BaseModel):
         self.add_log_entry('l_d_real', l_d_real.item() * self.mega_batch_factor)
         self.add_log_entry('l_d_fake', l_d_fake.item() * self.mega_batch_factor)
         self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))
+        self.add_log_entry('D_diff', torch.mean(pred_d_fake) - torch.mean(pred_d_real))
         self.add_log_entry('noise_theta', noise_theta)
 
         if step % self.corruptor_swapout_steps == 0 and step > 0:
@@ -398,13 +401,13 @@ class SRGANModel(BaseModel):
             self.swapout_G_duration -= 1
             if self.swapout_G_duration == 0:
                 # Swap back.
- print("Swapping back to current G model: %s" % (self.stashed_G,)) + logger.info("Swapping back to current G model: %s" % (self.stashed_G,)) self.load_network(self.stashed_G, self.netG, self.opt['path']['strict_load']) self.stashed_G = None elif self.swapout_G_freq != 0 and step % self.swapout_G_freq == 0: swapped_model = self.pick_rand_prev_model('G') if swapped_model is not None: - print("Swapping to previous G model: %s" % (swapped_model,)) + logger.info("Swapping to previous G model: %s" % (swapped_model,)) self.stashed_G = self.save_network(self.netG, 'G', 'swap_model') self.load_network(swapped_model, self.netG, self.opt['path']['strict_load']) self.swapout_G_duration = self.swapout_duration diff --git a/codes/models/archs/DiscriminatorResnet_arch_passthrough.py b/codes/models/archs/DiscriminatorResnet_arch_passthrough.py index 34729080..521f9c8b 100644 --- a/codes/models/archs/DiscriminatorResnet_arch_passthrough.py +++ b/codes/models/archs/DiscriminatorResnet_arch_passthrough.py @@ -107,15 +107,18 @@ class FixupBottleneck(nn.Module): class FixupResNet(nn.Module): - def __init__(self, block, layers, num_filters=64, num_classes=1000, input_img_size=64, use_bn=False): + def __init__(self, block, layers, num_filters=64, num_classes=1000, input_img_size=64, number_skips=2, use_bn=False): super(FixupResNet, self).__init__() self.num_layers = sum(layers) self.inplanes = 3 + self.number_skips = number_skips self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) self.layer0 = self._make_layer(block, num_filters*2, layers[0], stride=2, use_bn=use_bn, conv_type=conv5x5) - self.inplanes = self.inplanes + 3 # Accomodate a skip connection from the generator. + if number_skips > 0: + self.inplanes = self.inplanes + 3 # Accomodate a skip connection from the generator. self.layer1 = self._make_layer(block, num_filters*4, layers[1], stride=2, use_bn=use_bn, conv_type=conv5x5) - self.inplanes = self.inplanes + 3 # Accomodate a skip connection from the generator. + if number_skips > 1: + self.inplanes = self.inplanes + 3 # Accomodate a second skip connection from the generator. self.layer2 = self._make_layer(block, num_filters*8, layers[2], stride=2, use_bn=use_bn) # SRGAN already has a feature loss tied to a separate VGG discriminator. We really don't care about features. # Therefore, level off the filter count from this block forwards. 
@@ -157,9 +160,11 @@ class FixupResNet(nn.Module):
 
         x, med_skip, lo_skip = x
         x = self.layer0(x)
-        x = torch.cat([x, med_skip], dim=1)
+        if self.number_skips > 0:
+            x = torch.cat([x, med_skip], dim=1)
         x = self.layer1(x)
-        x = torch.cat([x, lo_skip], dim=1)
+        if self.number_skips > 1:
+            x = torch.cat([x, lo_skip], dim=1)
         x = self.layer2(x)
         x = self.layer3(x)
         x = self.layer4(x)
diff --git a/codes/models/archs/ResGen_arch.py b/codes/models/archs/ResGen_arch.py
index d4352e00..ddaa3f36 100644
--- a/codes/models/archs/ResGen_arch.py
+++ b/codes/models/archs/ResGen_arch.py
@@ -177,13 +177,23 @@ class FixupResNetV2(FixupResNet):
         if self.upscale_applications > 0:
             x = F.interpolate(x, scale_factor=2.0, mode='nearest')
             x = self.layer2(x)
-        skip_med = self.filter_to_image(x)
+            skip_med = self.filter_to_image(x)
 
         if self.upscale_applications > 1:
             x = F.interpolate(x, scale_factor=2.0, mode='nearest')
             x = self.layer2(x)
-        x = self.filter_to_image(x)
+        if self.upscale_applications == 2:
+            x = self.filter_to_image(x)
+        elif self.upscale_applications == 1:
+            x = skip_med
+            skip_med = skip_lo
+            skip_lo = None
+        elif self.upscale_applications == 0:
+            x = skip_lo
+            skip_lo = None
+            skip_med = None
+
         return x, skip_med, skip_lo
 
 
 def fixup_resnet34(nb_denoiser=20, nb_upsampler=10, **kwargs):
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 1bf95aa0..efe9ed1c 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -75,7 +75,8 @@ def define_D(opt):
     elif which_model == 'discriminator_resnet':
         netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
     elif which_model == 'discriminator_resnet_passthrough':
-        netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz, use_bn=True)
+        netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz,
+                                                                   number_skips=opt_net['number_skips'], use_bn=True)
     else:
         raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
     return netD
diff --git a/codes/train.py b/codes/train.py
index 2bcfd77d..8281db8c 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/train_vix_resgenv2.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vix_resgenv2.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
@@ -201,6 +201,7 @@ def main():
                     model.test()
 
                     visuals = model.get_current_visuals()
+                    sr_img = util.tensor2img(visuals['rlt'])  # uint8
                     gt_img = util.tensor2img(visuals['GT'])  # uint8