Misc changes

This commit is contained in:
James Betker 2020-07-10 09:45:34 -06:00
parent 5f2c722a10
commit 5e8b52f34c
4 changed files with 18 additions and 6 deletions

View File

@ -307,6 +307,8 @@ class SRGANModel(BaseModel):
noise = torch.randn_like(var_ref[0]) * noise_theta noise = torch.randn_like(var_ref[0]) * noise_theta
noise.to(self.device) noise.to(self.device)
self.optimizer_D.zero_grad() self.optimizer_D.zero_grad()
real_disc_images = []
fake_disc_images = []
for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix): for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
# Re-compute generator outputs (post-update). # Re-compute generator outputs (post-update).
with torch.no_grad(): with torch.no_grad():
@ -347,7 +349,6 @@ class SRGANModel(BaseModel):
# randomly determine portions of the image to swap to keep the discriminator honest. # randomly determine portions of the image to swap to keep the discriminator honest.
if random.random() > .25: if random.random() > .25:
# Make the swap across fake_H and var_ref # Make the swap across fake_H and var_ref
SWAP_MAX_DIM = var_ref[0].shape[2] // (2 * PIXDISC_MAX_REDUCTION) SWAP_MAX_DIM = var_ref[0].shape[2] // (2 * PIXDISC_MAX_REDUCTION)
assert SWAP_MAX_DIM > 0 assert SWAP_MAX_DIM > 0
@ -378,6 +379,14 @@ class SRGANModel(BaseModel):
l_d_fake_log = l_d_fake * self.mega_batch_factor l_d_fake_log = l_d_fake * self.mega_batch_factor
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled: with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
l_d_fake_scaled.backward() l_d_fake_scaled.backward()
pdr = pred_d_real.detach() + torch.abs(torch.min(pred_d_real))
pdr = pdr / torch.max(pdr)
real_disc_images.append(pdr.view(disc_output_shape))
pdf = pred_d_fake.detach() + torch.abs(torch.min(pred_d_fake))
pdf = pdf / torch.max(pdf)
fake_disc_images.append(pdf.view(disc_output_shape))
elif self.opt['train']['gan_type'] == 'ragan': elif self.opt['train']['gan_type'] == 'ragan':
pred_d_fake = self.netD(fake_H).detach() pred_d_fake = self.netD(fake_H).detach()
pred_d_real = self.netD(var_ref) pred_d_real = self.netD(var_ref)
@ -423,6 +432,7 @@ class SRGANModel(BaseModel):
os.makedirs(os.path.join(sample_save_path, "gen"), exist_ok=True) os.makedirs(os.path.join(sample_save_path, "gen"), exist_ok=True)
os.makedirs(os.path.join(sample_save_path, "disc_fake"), exist_ok=True) os.makedirs(os.path.join(sample_save_path, "disc_fake"), exist_ok=True)
os.makedirs(os.path.join(sample_save_path, "pix"), exist_ok=True) os.makedirs(os.path.join(sample_save_path, "pix"), exist_ok=True)
os.makedirs(os.path.join(sample_save_path, "disc"), exist_ok=True)
multi_gen = False multi_gen = False
if isinstance(self.fake_GenOut[0], tuple): if isinstance(self.fake_GenOut[0], tuple):
os.makedirs(os.path.join(sample_save_path, "ref"), exist_ok=True) os.makedirs(os.path.join(sample_save_path, "ref"), exist_ok=True)
@ -435,9 +445,11 @@ class SRGANModel(BaseModel):
utils.save_image(self.pix[i].cpu(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i))) utils.save_image(self.pix[i].cpu(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i)))
if multi_gen: if multi_gen:
utils.save_image(self.fake_GenOut[i][0].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i))) utils.save_image(self.fake_GenOut[i][0].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
if self.l_gan_w > 0 and step > self.G_warmup: if self.l_gan_w > 0 and step > self.G_warmup and self.opt['train']['gan_type'] == 'pixgan':
utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i))) utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i)))
utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "%05i_%02i.png" % (step, i))) utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "fake%05i_%02i.png" % (step, i)))
utils.save_image(F.interpolate(fake_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "fake%05i_%02i.png" % (step, i)))
utils.save_image(F.interpolate(real_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "real%05i_%02i.png" % (step, i)))
else: else:
utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i))) utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))

View File

@ -230,7 +230,7 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
temp = 1 / temp temp = 1 / temp
self.set_temperature(temp) self.set_temperature(temp)
if step % 50 == 0: if step % 50 == 0:
[save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts[i], step, "a%i" % (i+1,)) for i in range(len(self.switches))] [save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts, step, "a%i" % (i+1,)) for i in range(len(self.switches))]
def get_debug_values(self, step): def get_debug_values(self, step):
temp = self.switches[0].switch.temperature temp = self.switches[0].switch.temperature

View File

@ -92,14 +92,14 @@ class TorchCustomTrace:
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/debug.yml') parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_div2k_pixgan_srg2.yml')
opt = option.parse(parser.parse_args().opt, is_train=False) opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt) opt = option.dict_to_nonedict(opt)
netG = define_G(opt) netG = define_G(opt)
dummyInput = torch.rand(1,3,32,32) dummyInput = torch.rand(1,3,32,32)
mode = 'memtrace' mode = 'onnx'
if mode == 'torchscript': if mode == 'torchscript':
print("Tracing generator network..") print("Tracing generator network..")
traced_netG = torch.jit.trace(netG, dummyInput) traced_netG = torch.jit.trace(netG, dummyInput)