forked from mrq/DL-Art-School
More fixes
This commit is contained in:
parent
b2507be13c
commit
7d6eb28b87
|
@ -246,7 +246,6 @@ class SRGANModel(BaseModel):
|
|||
else:
|
||||
gen_img = fake_GenOut
|
||||
self.fake_GenOut.append(fake_GenOut.detach())
|
||||
var_ref_skips.append(var_ref[0].detach())
|
||||
|
||||
l_g_total = 0
|
||||
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
|
||||
|
@ -308,7 +307,7 @@ class SRGANModel(BaseModel):
|
|||
noise = torch.randn_like(var_ref[0]) * noise_theta
|
||||
noise.to(self.device)
|
||||
self.optimizer_D.zero_grad()
|
||||
for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, var_ref_skips, self.pix):
|
||||
for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
|
||||
# Re-compute generator outputs (post-update).
|
||||
with torch.no_grad():
|
||||
fake_H = self.netG(var_L)
|
||||
|
@ -320,9 +319,8 @@ class SRGANModel(BaseModel):
|
|||
_t = time()
|
||||
|
||||
# Apply noise to the inputs to slow discriminator convergence.
|
||||
var_ref = (var_ref[0] + noise,)
|
||||
var_ref = (var_ref + noise,)
|
||||
fake_H = (fake_H[0] + noise,) + fake_H[1:]
|
||||
self.fake_H.append(fake_H[0].detach())
|
||||
if self.opt['train']['gan_type'] == 'gan':
|
||||
# need to forward and backward separately, since batch norm statistics differ
|
||||
# real
|
||||
|
@ -351,9 +349,9 @@ class SRGANModel(BaseModel):
|
|||
if random.random() > .25:
|
||||
|
||||
# Make the swap across fake_H and var_ref
|
||||
SWAP_MAX_DIM = var_ref[0].shape[2] // (2 * PIXDISC_MAX_REDUCTION) - 1
|
||||
SWAP_MAX_DIM = var_ref[0].shape[2] // (2 * PIXDISC_MAX_REDUCTION)
|
||||
assert SWAP_MAX_DIM > 0
|
||||
swap_x, swap_y = random.randint(0, SWAP_MAX_DIM+1) * PIXDISC_MAX_REDUCTION, random.randint(0, SWAP_MAX_DIM+1) * PIXDISC_MAX_REDUCTION
|
||||
swap_x, swap_y = random.randint(0, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(0, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION
|
||||
swap_w, swap_h = random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION
|
||||
t = fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone()
|
||||
fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)]
|
||||
|
@ -406,8 +404,13 @@ class SRGANModel(BaseModel):
|
|||
if _profile:
|
||||
print("Disc forward/backward 2 (RAGAN) %f" % (time() - _t,))
|
||||
_t = time()
|
||||
|
||||
# Append var_ref here, so that we can inspect the alterations the disc made if pixgan
|
||||
var_ref_skips.append(var_ref[0].detach())
|
||||
self.fake_H.append(fake_H[0].detach())
|
||||
self.optimizer_D.step()
|
||||
|
||||
|
||||
if _profile:
|
||||
print("Disc step %f" % (time() - _t,))
|
||||
_t = time()
|
||||
|
@ -432,8 +435,8 @@ class SRGANModel(BaseModel):
|
|||
utils.save_image(self.pix[i].cpu(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i)))
|
||||
if multi_gen:
|
||||
utils.save_image(self.fake_GenOut[i][0].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i)))
|
||||
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
|
||||
utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "%05i_%02i.png" % (step, i)))
|
||||
else:
|
||||
utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
|
||||
|
|
Loading…
Reference in New Issue
Block a user