More fixes

This commit is contained in:
James Betker 2020-07-08 22:00:57 -06:00
parent b2507be13c
commit 7d6eb28b87

View File

@ -246,7 +246,6 @@ class SRGANModel(BaseModel):
else: else:
gen_img = fake_GenOut gen_img = fake_GenOut
self.fake_GenOut.append(fake_GenOut.detach()) self.fake_GenOut.append(fake_GenOut.detach())
var_ref_skips.append(var_ref[0].detach())
l_g_total = 0 l_g_total = 0
if step % self.D_update_ratio == 0 and step > self.D_init_iters: if step % self.D_update_ratio == 0 and step > self.D_init_iters:
@ -308,7 +307,7 @@ class SRGANModel(BaseModel):
noise = torch.randn_like(var_ref[0]) * noise_theta noise = torch.randn_like(var_ref[0]) * noise_theta
noise.to(self.device) noise.to(self.device)
self.optimizer_D.zero_grad() self.optimizer_D.zero_grad()
for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, var_ref_skips, self.pix): for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
# Re-compute generator outputs (post-update). # Re-compute generator outputs (post-update).
with torch.no_grad(): with torch.no_grad():
fake_H = self.netG(var_L) fake_H = self.netG(var_L)
@ -320,9 +319,8 @@ class SRGANModel(BaseModel):
_t = time() _t = time()
# Apply noise to the inputs to slow discriminator convergence. # Apply noise to the inputs to slow discriminator convergence.
var_ref = (var_ref[0] + noise,) var_ref = (var_ref + noise,)
fake_H = (fake_H[0] + noise,) + fake_H[1:] fake_H = (fake_H[0] + noise,) + fake_H[1:]
self.fake_H.append(fake_H[0].detach())
if self.opt['train']['gan_type'] == 'gan': if self.opt['train']['gan_type'] == 'gan':
# need to forward and backward separately, since batch norm statistics differ # need to forward and backward separately, since batch norm statistics differ
# real # real
@ -351,9 +349,9 @@ class SRGANModel(BaseModel):
if random.random() > .25: if random.random() > .25:
# Make the swap across fake_H and var_ref # Make the swap across fake_H and var_ref
SWAP_MAX_DIM = var_ref[0].shape[2] // (2 * PIXDISC_MAX_REDUCTION) - 1 SWAP_MAX_DIM = var_ref[0].shape[2] // (2 * PIXDISC_MAX_REDUCTION)
assert SWAP_MAX_DIM > 0 assert SWAP_MAX_DIM > 0
swap_x, swap_y = random.randint(0, SWAP_MAX_DIM+1) * PIXDISC_MAX_REDUCTION, random.randint(0, SWAP_MAX_DIM+1) * PIXDISC_MAX_REDUCTION swap_x, swap_y = random.randint(0, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(0, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION
swap_w, swap_h = random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION swap_w, swap_h = random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION
t = fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone() t = fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone()
fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)]
@ -406,8 +404,13 @@ class SRGANModel(BaseModel):
if _profile: if _profile:
print("Disc forward/backward 2 (RAGAN) %f" % (time() - _t,)) print("Disc forward/backward 2 (RAGAN) %f" % (time() - _t,))
_t = time() _t = time()
# Append var_ref here, so that we can inspect the alterations the disc made if pixgan
var_ref_skips.append(var_ref[0].detach())
self.fake_H.append(fake_H[0].detach())
self.optimizer_D.step() self.optimizer_D.step()
if _profile: if _profile:
print("Disc step %f" % (time() - _t,)) print("Disc step %f" % (time() - _t,))
_t = time() _t = time()
@ -432,8 +435,8 @@ class SRGANModel(BaseModel):
utils.save_image(self.pix[i].cpu(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i))) utils.save_image(self.pix[i].cpu(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i)))
if multi_gen: if multi_gen:
utils.save_image(self.fake_GenOut[i][0].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i))) utils.save_image(self.fake_GenOut[i][0].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i)))
if step % self.D_update_ratio == 0 and step > self.D_init_iters: if step % self.D_update_ratio == 0 and step > self.D_init_iters:
utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i)))
utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "%05i_%02i.png" % (step, i))) utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "%05i_%02i.png" % (step, i)))
else: else:
utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i))) utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))