forked from mrq/DL-Art-School
More fixes
This commit is contained in:
parent
b2507be13c
commit
7d6eb28b87
|
@ -246,7 +246,6 @@ class SRGANModel(BaseModel):
|
||||||
else:
|
else:
|
||||||
gen_img = fake_GenOut
|
gen_img = fake_GenOut
|
||||||
self.fake_GenOut.append(fake_GenOut.detach())
|
self.fake_GenOut.append(fake_GenOut.detach())
|
||||||
var_ref_skips.append(var_ref[0].detach())
|
|
||||||
|
|
||||||
l_g_total = 0
|
l_g_total = 0
|
||||||
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
|
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
|
||||||
|
@ -308,7 +307,7 @@ class SRGANModel(BaseModel):
|
||||||
noise = torch.randn_like(var_ref[0]) * noise_theta
|
noise = torch.randn_like(var_ref[0]) * noise_theta
|
||||||
noise.to(self.device)
|
noise.to(self.device)
|
||||||
self.optimizer_D.zero_grad()
|
self.optimizer_D.zero_grad()
|
||||||
for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, var_ref_skips, self.pix):
|
for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
|
||||||
# Re-compute generator outputs (post-update).
|
# Re-compute generator outputs (post-update).
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
fake_H = self.netG(var_L)
|
fake_H = self.netG(var_L)
|
||||||
|
@ -320,9 +319,8 @@ class SRGANModel(BaseModel):
|
||||||
_t = time()
|
_t = time()
|
||||||
|
|
||||||
# Apply noise to the inputs to slow discriminator convergence.
|
# Apply noise to the inputs to slow discriminator convergence.
|
||||||
var_ref = (var_ref[0] + noise,)
|
var_ref = (var_ref + noise,)
|
||||||
fake_H = (fake_H[0] + noise,) + fake_H[1:]
|
fake_H = (fake_H[0] + noise,) + fake_H[1:]
|
||||||
self.fake_H.append(fake_H[0].detach())
|
|
||||||
if self.opt['train']['gan_type'] == 'gan':
|
if self.opt['train']['gan_type'] == 'gan':
|
||||||
# need to forward and backward separately, since batch norm statistics differ
|
# need to forward and backward separately, since batch norm statistics differ
|
||||||
# real
|
# real
|
||||||
|
@ -351,9 +349,9 @@ class SRGANModel(BaseModel):
|
||||||
if random.random() > .25:
|
if random.random() > .25:
|
||||||
|
|
||||||
# Make the swap across fake_H and var_ref
|
# Make the swap across fake_H and var_ref
|
||||||
SWAP_MAX_DIM = var_ref[0].shape[2] // (2 * PIXDISC_MAX_REDUCTION) - 1
|
SWAP_MAX_DIM = var_ref[0].shape[2] // (2 * PIXDISC_MAX_REDUCTION)
|
||||||
assert SWAP_MAX_DIM > 0
|
assert SWAP_MAX_DIM > 0
|
||||||
swap_x, swap_y = random.randint(0, SWAP_MAX_DIM+1) * PIXDISC_MAX_REDUCTION, random.randint(0, SWAP_MAX_DIM+1) * PIXDISC_MAX_REDUCTION
|
swap_x, swap_y = random.randint(0, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(0, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION
|
||||||
swap_w, swap_h = random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION
|
swap_w, swap_h = random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION
|
||||||
t = fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone()
|
t = fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone()
|
||||||
fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)]
|
fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)]
|
||||||
|
@ -406,8 +404,13 @@ class SRGANModel(BaseModel):
|
||||||
if _profile:
|
if _profile:
|
||||||
print("Disc forward/backward 2 (RAGAN) %f" % (time() - _t,))
|
print("Disc forward/backward 2 (RAGAN) %f" % (time() - _t,))
|
||||||
_t = time()
|
_t = time()
|
||||||
|
|
||||||
|
# Append var_ref here, so that we can inspect the alterations the disc made if pixgan
|
||||||
|
var_ref_skips.append(var_ref[0].detach())
|
||||||
|
self.fake_H.append(fake_H[0].detach())
|
||||||
self.optimizer_D.step()
|
self.optimizer_D.step()
|
||||||
|
|
||||||
|
|
||||||
if _profile:
|
if _profile:
|
||||||
print("Disc step %f" % (time() - _t,))
|
print("Disc step %f" % (time() - _t,))
|
||||||
_t = time()
|
_t = time()
|
||||||
|
@ -432,8 +435,8 @@ class SRGANModel(BaseModel):
|
||||||
utils.save_image(self.pix[i].cpu(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i)))
|
utils.save_image(self.pix[i].cpu(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i)))
|
||||||
if multi_gen:
|
if multi_gen:
|
||||||
utils.save_image(self.fake_GenOut[i][0].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
|
utils.save_image(self.fake_GenOut[i][0].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
|
||||||
utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i)))
|
|
||||||
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
|
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
|
||||||
|
utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i)))
|
||||||
utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "%05i_%02i.png" % (step, i)))
|
utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "%05i_%02i.png" % (step, i)))
|
||||||
else:
|
else:
|
||||||
utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
|
utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user