diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index d7987664..cc0094c6 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -307,6 +307,8 @@ class SRGANModel(BaseModel): noise = torch.randn_like(var_ref[0]) * noise_theta noise.to(self.device) self.optimizer_D.zero_grad() + real_disc_images = [] + fake_disc_images = [] for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix): # Re-compute generator outputs (post-update). with torch.no_grad(): @@ -347,7 +349,6 @@ class SRGANModel(BaseModel): # randomly determine portions of the image to swap to keep the discriminator honest. if random.random() > .25: - # Make the swap across fake_H and var_ref SWAP_MAX_DIM = var_ref[0].shape[2] // (2 * PIXDISC_MAX_REDUCTION) assert SWAP_MAX_DIM > 0 @@ -378,6 +379,14 @@ class SRGANModel(BaseModel): l_d_fake_log = l_d_fake * self.mega_batch_factor with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled: l_d_fake_scaled.backward() + + pdr = pred_d_real.detach() + torch.abs(torch.min(pred_d_real)) + pdr = pdr / torch.max(pdr) + real_disc_images.append(pdr.view(disc_output_shape)) + pdf = pred_d_fake.detach() + torch.abs(torch.min(pred_d_fake)) + pdf = pdf / torch.max(pdf) + fake_disc_images.append(pdf.view(disc_output_shape)) + elif self.opt['train']['gan_type'] == 'ragan': pred_d_fake = self.netD(fake_H).detach() pred_d_real = self.netD(var_ref) @@ -423,6 +432,7 @@ class SRGANModel(BaseModel): os.makedirs(os.path.join(sample_save_path, "gen"), exist_ok=True) os.makedirs(os.path.join(sample_save_path, "disc_fake"), exist_ok=True) os.makedirs(os.path.join(sample_save_path, "pix"), exist_ok=True) + os.makedirs(os.path.join(sample_save_path, "disc"), exist_ok=True) multi_gen = False if isinstance(self.fake_GenOut[0], tuple): os.makedirs(os.path.join(sample_save_path, "ref"), exist_ok=True) @@ -435,9 +445,11 @@ class SRGANModel(BaseModel): utils.save_image(self.pix[i].cpu(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i))) if multi_gen: utils.save_image(self.fake_GenOut[i][0].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i))) - if self.l_gan_w > 0 and step > self.G_warmup: + if self.l_gan_w > 0 and step > self.G_warmup and self.opt['train']['gan_type'] == 'pixgan': utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i))) - utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "%05i_%02i.png" % (step, i))) + utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "fake%05i_%02i.png" % (step, i))) + utils.save_image(F.interpolate(fake_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "fake%05i_%02i.png" % (step, i))) + utils.save_image(F.interpolate(real_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "real%05i_%02i.png" % (step, i))) else: utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i))) diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index ec0f4fe1..8c89fd36 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -230,7 +230,7 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module): temp = 1 / temp self.set_temperature(temp) if step % 50 == 0: - [save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts[i], step, "a%i" % (i+1,)) for i in range(len(self.switches))] + [save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts, step, "a%i" % (i+1,)) for i in range(len(self.switches))] def get_debug_values(self, step): temp = self.switches[0].switch.temperature diff --git a/codes/distill_torchscript.py b/codes/utils/distill_torchscript.py similarity index 98% rename from codes/distill_torchscript.py rename to codes/utils/distill_torchscript.py index 7a43195d..0f3aa173 100644 --- a/codes/distill_torchscript.py +++ b/codes/utils/distill_torchscript.py @@ -92,14 +92,14 @@ class TorchCustomTrace: if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/debug.yml') + parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_div2k_pixgan_srg2.yml') opt = option.parse(parser.parse_args().opt, is_train=False) opt = option.dict_to_nonedict(opt) netG = define_G(opt) dummyInput = torch.rand(1,3,32,32) - mode = 'memtrace' + mode = 'onnx' if mode == 'torchscript': print("Tracing generator network..") traced_netG = torch.jit.trace(netG, dummyInput) diff --git a/codes/onnx_inference.py b/codes/utils/onnx_inference.py similarity index 100% rename from codes/onnx_inference.py rename to codes/utils/onnx_inference.py