Misc changes
This commit is contained in:
parent
5f2c722a10
commit
5e8b52f34c
|
@ -307,6 +307,8 @@ class SRGANModel(BaseModel):
|
|||
noise = torch.randn_like(var_ref[0]) * noise_theta
|
||||
noise.to(self.device)
|
||||
self.optimizer_D.zero_grad()
|
||||
real_disc_images = []
|
||||
fake_disc_images = []
|
||||
for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
|
||||
# Re-compute generator outputs (post-update).
|
||||
with torch.no_grad():
|
||||
|
@ -347,7 +349,6 @@ class SRGANModel(BaseModel):
|
|||
|
||||
# randomly determine portions of the image to swap to keep the discriminator honest.
|
||||
if random.random() > .25:
|
||||
|
||||
# Make the swap across fake_H and var_ref
|
||||
SWAP_MAX_DIM = var_ref[0].shape[2] // (2 * PIXDISC_MAX_REDUCTION)
|
||||
assert SWAP_MAX_DIM > 0
|
||||
|
@ -378,6 +379,14 @@ class SRGANModel(BaseModel):
|
|||
l_d_fake_log = l_d_fake * self.mega_batch_factor
|
||||
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
|
||||
l_d_fake_scaled.backward()
|
||||
|
||||
pdr = pred_d_real.detach() + torch.abs(torch.min(pred_d_real))
|
||||
pdr = pdr / torch.max(pdr)
|
||||
real_disc_images.append(pdr.view(disc_output_shape))
|
||||
pdf = pred_d_fake.detach() + torch.abs(torch.min(pred_d_fake))
|
||||
pdf = pdf / torch.max(pdf)
|
||||
fake_disc_images.append(pdf.view(disc_output_shape))
|
||||
|
||||
elif self.opt['train']['gan_type'] == 'ragan':
|
||||
pred_d_fake = self.netD(fake_H).detach()
|
||||
pred_d_real = self.netD(var_ref)
|
||||
|
@ -423,6 +432,7 @@ class SRGANModel(BaseModel):
|
|||
os.makedirs(os.path.join(sample_save_path, "gen"), exist_ok=True)
|
||||
os.makedirs(os.path.join(sample_save_path, "disc_fake"), exist_ok=True)
|
||||
os.makedirs(os.path.join(sample_save_path, "pix"), exist_ok=True)
|
||||
os.makedirs(os.path.join(sample_save_path, "disc"), exist_ok=True)
|
||||
multi_gen = False
|
||||
if isinstance(self.fake_GenOut[0], tuple):
|
||||
os.makedirs(os.path.join(sample_save_path, "ref"), exist_ok=True)
|
||||
|
@ -435,9 +445,11 @@ class SRGANModel(BaseModel):
|
|||
utils.save_image(self.pix[i].cpu(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i)))
|
||||
if multi_gen:
|
||||
utils.save_image(self.fake_GenOut[i][0].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
|
||||
if self.l_gan_w > 0 and step > self.G_warmup:
|
||||
if self.l_gan_w > 0 and step > self.G_warmup and self.opt['train']['gan_type'] == 'pixgan':
|
||||
utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "fake%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(F.interpolate(fake_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "fake%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(F.interpolate(real_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "real%05i_%02i.png" % (step, i)))
|
||||
else:
|
||||
utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
|
||||
|
||||
|
|
|
@ -230,7 +230,7 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
|
|||
temp = 1 / temp
|
||||
self.set_temperature(temp)
|
||||
if step % 50 == 0:
|
||||
[save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts[i], step, "a%i" % (i+1,)) for i in range(len(self.switches))]
|
||||
[save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts, step, "a%i" % (i+1,)) for i in range(len(self.switches))]
|
||||
|
||||
def get_debug_values(self, step):
|
||||
temp = self.switches[0].switch.temperature
|
||||
|
|
|
@ -92,14 +92,14 @@ class TorchCustomTrace:
|
|||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/debug.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_div2k_pixgan_srg2.yml')
|
||||
opt = option.parse(parser.parse_args().opt, is_train=False)
|
||||
opt = option.dict_to_nonedict(opt)
|
||||
|
||||
netG = define_G(opt)
|
||||
dummyInput = torch.rand(1,3,32,32)
|
||||
|
||||
mode = 'memtrace'
|
||||
mode = 'onnx'
|
||||
if mode == 'torchscript':
|
||||
print("Tracing generator network..")
|
||||
traced_netG = torch.jit.trace(netG, dummyInput)
|
Loading…
Reference in New Issue
Block a user