Misc changes
This commit is contained in:
parent
5f2c722a10
commit
5e8b52f34c
|
@ -307,6 +307,8 @@ class SRGANModel(BaseModel):
|
||||||
noise = torch.randn_like(var_ref[0]) * noise_theta
|
noise = torch.randn_like(var_ref[0]) * noise_theta
|
||||||
noise.to(self.device)
|
noise.to(self.device)
|
||||||
self.optimizer_D.zero_grad()
|
self.optimizer_D.zero_grad()
|
||||||
|
real_disc_images = []
|
||||||
|
fake_disc_images = []
|
||||||
for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
|
for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
|
||||||
# Re-compute generator outputs (post-update).
|
# Re-compute generator outputs (post-update).
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
@ -347,7 +349,6 @@ class SRGANModel(BaseModel):
|
||||||
|
|
||||||
# randomly determine portions of the image to swap to keep the discriminator honest.
|
# randomly determine portions of the image to swap to keep the discriminator honest.
|
||||||
if random.random() > .25:
|
if random.random() > .25:
|
||||||
|
|
||||||
# Make the swap across fake_H and var_ref
|
# Make the swap across fake_H and var_ref
|
||||||
SWAP_MAX_DIM = var_ref[0].shape[2] // (2 * PIXDISC_MAX_REDUCTION)
|
SWAP_MAX_DIM = var_ref[0].shape[2] // (2 * PIXDISC_MAX_REDUCTION)
|
||||||
assert SWAP_MAX_DIM > 0
|
assert SWAP_MAX_DIM > 0
|
||||||
|
@ -378,6 +379,14 @@ class SRGANModel(BaseModel):
|
||||||
l_d_fake_log = l_d_fake * self.mega_batch_factor
|
l_d_fake_log = l_d_fake * self.mega_batch_factor
|
||||||
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
|
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
|
||||||
l_d_fake_scaled.backward()
|
l_d_fake_scaled.backward()
|
||||||
|
|
||||||
|
pdr = pred_d_real.detach() + torch.abs(torch.min(pred_d_real))
|
||||||
|
pdr = pdr / torch.max(pdr)
|
||||||
|
real_disc_images.append(pdr.view(disc_output_shape))
|
||||||
|
pdf = pred_d_fake.detach() + torch.abs(torch.min(pred_d_fake))
|
||||||
|
pdf = pdf / torch.max(pdf)
|
||||||
|
fake_disc_images.append(pdf.view(disc_output_shape))
|
||||||
|
|
||||||
elif self.opt['train']['gan_type'] == 'ragan':
|
elif self.opt['train']['gan_type'] == 'ragan':
|
||||||
pred_d_fake = self.netD(fake_H).detach()
|
pred_d_fake = self.netD(fake_H).detach()
|
||||||
pred_d_real = self.netD(var_ref)
|
pred_d_real = self.netD(var_ref)
|
||||||
|
@ -423,6 +432,7 @@ class SRGANModel(BaseModel):
|
||||||
os.makedirs(os.path.join(sample_save_path, "gen"), exist_ok=True)
|
os.makedirs(os.path.join(sample_save_path, "gen"), exist_ok=True)
|
||||||
os.makedirs(os.path.join(sample_save_path, "disc_fake"), exist_ok=True)
|
os.makedirs(os.path.join(sample_save_path, "disc_fake"), exist_ok=True)
|
||||||
os.makedirs(os.path.join(sample_save_path, "pix"), exist_ok=True)
|
os.makedirs(os.path.join(sample_save_path, "pix"), exist_ok=True)
|
||||||
|
os.makedirs(os.path.join(sample_save_path, "disc"), exist_ok=True)
|
||||||
multi_gen = False
|
multi_gen = False
|
||||||
if isinstance(self.fake_GenOut[0], tuple):
|
if isinstance(self.fake_GenOut[0], tuple):
|
||||||
os.makedirs(os.path.join(sample_save_path, "ref"), exist_ok=True)
|
os.makedirs(os.path.join(sample_save_path, "ref"), exist_ok=True)
|
||||||
|
@ -435,9 +445,11 @@ class SRGANModel(BaseModel):
|
||||||
utils.save_image(self.pix[i].cpu(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i)))
|
utils.save_image(self.pix[i].cpu(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i)))
|
||||||
if multi_gen:
|
if multi_gen:
|
||||||
utils.save_image(self.fake_GenOut[i][0].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
|
utils.save_image(self.fake_GenOut[i][0].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
|
||||||
if self.l_gan_w > 0 and step > self.G_warmup:
|
if self.l_gan_w > 0 and step > self.G_warmup and self.opt['train']['gan_type'] == 'pixgan':
|
||||||
utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i)))
|
utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i)))
|
||||||
utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "%05i_%02i.png" % (step, i)))
|
utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "fake%05i_%02i.png" % (step, i)))
|
||||||
|
utils.save_image(F.interpolate(fake_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "fake%05i_%02i.png" % (step, i)))
|
||||||
|
utils.save_image(F.interpolate(real_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "real%05i_%02i.png" % (step, i)))
|
||||||
else:
|
else:
|
||||||
utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
|
utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
|
||||||
|
|
||||||
|
|
|
@ -230,7 +230,7 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
|
||||||
temp = 1 / temp
|
temp = 1 / temp
|
||||||
self.set_temperature(temp)
|
self.set_temperature(temp)
|
||||||
if step % 50 == 0:
|
if step % 50 == 0:
|
||||||
[save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts[i], step, "a%i" % (i+1,)) for i in range(len(self.switches))]
|
[save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts, step, "a%i" % (i+1,)) for i in range(len(self.switches))]
|
||||||
|
|
||||||
def get_debug_values(self, step):
|
def get_debug_values(self, step):
|
||||||
temp = self.switches[0].switch.temperature
|
temp = self.switches[0].switch.temperature
|
||||||
|
|
|
@ -92,14 +92,14 @@ class TorchCustomTrace:
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/debug.yml')
|
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_div2k_pixgan_srg2.yml')
|
||||||
opt = option.parse(parser.parse_args().opt, is_train=False)
|
opt = option.parse(parser.parse_args().opt, is_train=False)
|
||||||
opt = option.dict_to_nonedict(opt)
|
opt = option.dict_to_nonedict(opt)
|
||||||
|
|
||||||
netG = define_G(opt)
|
netG = define_G(opt)
|
||||||
dummyInput = torch.rand(1,3,32,32)
|
dummyInput = torch.rand(1,3,32,32)
|
||||||
|
|
||||||
mode = 'memtrace'
|
mode = 'onnx'
|
||||||
if mode == 'torchscript':
|
if mode == 'torchscript':
|
||||||
print("Tracing generator network..")
|
print("Tracing generator network..")
|
||||||
traced_netG = torch.jit.trace(netG, dummyInput)
|
traced_netG = torch.jit.trace(netG, dummyInput)
|
Loading…
Reference in New Issue
Block a user