From 3e7a83896be1b637bce422580d37e92f6953e341 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 16 Jul 2020 11:45:19 -0600 Subject: [PATCH] Fix pixgan debugging issues --- codes/models/SRGAN_model.py | 2 +- codes/models/archs/SwitchedResidualGenerator_arch.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index a031e12d..309e46bf 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -464,7 +464,7 @@ class SRGANModel(BaseModel): utils.save_image(self.pix[i].cpu(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i))) utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i))) utils.save_image(self.fea_GenOut[i].cpu(), os.path.join(sample_save_path, "gen_fea", "%05i_%02i.png" % (step, i))) - if self.l_gan_w > 0 and step > self.G_warmup and self.opt['train']['gan_type'] == 'pixgan': + if self.l_gan_w > 0 and step > self.G_warmup and 'pixgan' in self.opt['train']['gan_type']: utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i))) utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "fake%05i_%02i.png" % (step, i))) utils.save_image(F.interpolate(fake_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "fake%05i_%02i.png" % (step, i))) diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index f545e3df..c9d13715 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -219,7 +219,7 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module): temp = 1 / temp self.set_temperature(temp) if step % 50 == 0: - save_attention_to_image(experiments_path, self.attentions[0], self.transformation_counts, step, "a%i" % (1,), l_mult=10) + [save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts, step, "a%i" % (i+1,), l_mult=10) for i in range(len(self.attentions))] def get_debug_values(self, step): temp = self.switches[0].switch.temperature