forked from mrq/DL-Art-School
Fix skips & images samples
- Makes skip connections between the generator and discriminator more extensible by adding additional configuration options for them and supporting 1 and 0 skips. - Places the temp/ directory with sample images from the training process appear in the training directory instead of the codes/ directory.
This commit is contained in:
parent
cdf641e3e2
commit
a33ec3e22b
|
@ -214,10 +214,8 @@ class SRGANModel(BaseModel):
|
|||
# the first element of the tuple.
|
||||
if isinstance(fake_GenOut, tuple):
|
||||
gen_img = fake_GenOut[0]
|
||||
# TODO: Fix this.
|
||||
self.fake_GenOut.append((fake_GenOut[0].detach(),
|
||||
fake_GenOut[1].detach(),
|
||||
fake_GenOut[2].detach()))
|
||||
# The following line detaches all generator outputs that are not None.
|
||||
self.fake_GenOut.append(tuple([(x.detach() if x is not None else None) for x in list(fake_GenOut)]))
|
||||
var_ref = (var_ref,) + self.create_artificial_skips(var_H)
|
||||
else:
|
||||
gen_img = fake_GenOut
|
||||
|
@ -269,7 +267,8 @@ class SRGANModel(BaseModel):
|
|||
# Re-compute generator outputs (post-update).
|
||||
with torch.no_grad():
|
||||
fake_H = self.netG(var_L)
|
||||
fake_H = (fake_H[0].detach(), fake_H[1].detach(), fake_H[2].detach())
|
||||
# The following line detaches all generator outputs that are not None.
|
||||
fake_H = tuple([(x.detach() if x is not None else None) for x in list(fake_H)])
|
||||
|
||||
# Apply noise to the inputs to slow discriminator convergence.
|
||||
var_ref = (var_ref[0] + noise,) + var_ref[1:]
|
||||
|
@ -306,35 +305,38 @@ class SRGANModel(BaseModel):
|
|||
|
||||
# Log sample images from first microbatch.
|
||||
if step % 50 == 0:
|
||||
os.makedirs("temp/hr", exist_ok=True)
|
||||
os.makedirs("temp/lr", exist_ok=True)
|
||||
os.makedirs("temp/lr_precorrupt", exist_ok=True)
|
||||
os.makedirs("temp/gen", exist_ok=True)
|
||||
os.makedirs("temp/pix", exist_ok=True)
|
||||
sample_save_path = os.path.join(self.opt['path']['models'], "..", "temp")
|
||||
os.makedirs(os.path.join(sample_save_path, "hr"), exist_ok=True)
|
||||
os.makedirs(os.path.join(sample_save_path, "lr"), exist_ok=True)
|
||||
os.makedirs(os.path.join(sample_save_path, "lr_precorrupt"), exist_ok=True)
|
||||
os.makedirs(os.path.join(sample_save_path, "gen"), exist_ok=True)
|
||||
os.makedirs(os.path.join(sample_save_path, "pix"), exist_ok=True)
|
||||
multi_gen = False
|
||||
if isinstance(self.fake_GenOut[0], tuple):
|
||||
os.makedirs("temp/genlr", exist_ok=True)
|
||||
os.makedirs("temp/genmr", exist_ok=True)
|
||||
os.makedirs("temp/ref", exist_ok=True)
|
||||
os.makedirs(os.path.join(sample_save_path, "genlr"), exist_ok=True)
|
||||
os.makedirs(os.path.join(sample_save_path, "genmr"), exist_ok=True)
|
||||
os.makedirs(os.path.join(sample_save_path, "ref"), exist_ok=True)
|
||||
multi_gen = True
|
||||
|
||||
# fed_LQ is not chunked.
|
||||
utils.save_image(self.fed_LQ.cpu().detach(), os.path.join("temp/lr_precorrupt", "%05i.png" % (step,)))
|
||||
utils.save_image(self.fed_LQ.cpu().detach(), os.path.join(sample_save_path, "lr_precorrupt", "%05i.png" % (step,)))
|
||||
for i in range(self.mega_batch_factor):
|
||||
utils.save_image(self.var_H[i].cpu().detach(), os.path.join("temp/hr", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.var_L[i].cpu().detach(), os.path.join("temp/lr", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.pix[i].cpu().detach(), os.path.join("temp/pix", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.var_H[i].cpu().detach(), os.path.join(sample_save_path, "hr", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.var_L[i].cpu().detach(), os.path.join(sample_save_path, "lr", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.pix[i].cpu().detach(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i)))
|
||||
if multi_gen:
|
||||
utils.save_image(self.fake_GenOut[i][0].cpu().detach(), os.path.join("temp/gen", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.fake_GenOut[i][1].cpu().detach(), os.path.join("temp/genmr", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.fake_GenOut[i][2].cpu().detach(), os.path.join("temp/genlr", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(var_ref_skips[i][0].cpu().detach(), os.path.join("temp/ref", "hi_%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(var_ref_skips[i][1].cpu().detach(), os.path.join("temp/ref", "med_%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(var_ref_skips[i][2].cpu().detach(), os.path.join("temp/ref", "low_%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.fake_GenOut[i][0].cpu().detach(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
|
||||
if self.fake_GenOut[i][1] is not None:
|
||||
utils.save_image(self.fake_GenOut[i][1].cpu().detach(), os.path.join(sample_save_path, "genmr", "%05i_%02i.png" % (step, i)))
|
||||
if self.fake_GenOut[i][2] is not None:
|
||||
utils.save_image(self.fake_GenOut[i][2].cpu().detach(), os.path.join(sample_save_path, "genlr", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(var_ref_skips[i][0].cpu().detach(), os.path.join(sample_save_path, "ref", "hi_%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(var_ref_skips[i][1].cpu().detach(), os.path.join(sample_save_path, "ref", "med_%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(var_ref_skips[i][2].cpu().detach(), os.path.join(sample_save_path, "ref", "low_%05i_%02i.png" % (step, i)))
|
||||
else:
|
||||
utils.save_image(self.fake_GenOut[i].cpu().detach(), os.path.join("temp/gen", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.fake_GenOut[i].cpu().detach(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
|
||||
|
||||
# set log TODO(handle mega-batches?)
|
||||
# Log metrics
|
||||
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
|
||||
if self.cri_pix:
|
||||
self.add_log_entry('l_g_pix', l_g_pix.item())
|
||||
|
@ -346,6 +348,7 @@ class SRGANModel(BaseModel):
|
|||
self.add_log_entry('l_d_real', l_d_real.item() * self.mega_batch_factor)
|
||||
self.add_log_entry('l_d_fake', l_d_fake.item() * self.mega_batch_factor)
|
||||
self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))
|
||||
self.add_log_entry('D_diff', torch.mean(pred_d_fake) - torch.mean(pred_d_real))
|
||||
self.add_log_entry('noise_theta', noise_theta)
|
||||
|
||||
if step % self.corruptor_swapout_steps == 0 and step > 0:
|
||||
|
@ -398,13 +401,13 @@ class SRGANModel(BaseModel):
|
|||
self.swapout_G_duration -= 1
|
||||
if self.swapout_G_duration == 0:
|
||||
# Swap back.
|
||||
print("Swapping back to current G model: %s" % (self.stashed_G,))
|
||||
logger.info("Swapping back to current G model: %s" % (self.stashed_G,))
|
||||
self.load_network(self.stashed_G, self.netG, self.opt['path']['strict_load'])
|
||||
self.stashed_G = None
|
||||
elif self.swapout_G_freq != 0 and step % self.swapout_G_freq == 0:
|
||||
swapped_model = self.pick_rand_prev_model('G')
|
||||
if swapped_model is not None:
|
||||
print("Swapping to previous G model: %s" % (swapped_model,))
|
||||
logger.info("Swapping to previous G model: %s" % (swapped_model,))
|
||||
self.stashed_G = self.save_network(self.netG, 'G', 'swap_model')
|
||||
self.load_network(swapped_model, self.netG, self.opt['path']['strict_load'])
|
||||
self.swapout_G_duration = self.swapout_duration
|
||||
|
|
|
@ -107,15 +107,18 @@ class FixupBottleneck(nn.Module):
|
|||
|
||||
class FixupResNet(nn.Module):
|
||||
|
||||
def __init__(self, block, layers, num_filters=64, num_classes=1000, input_img_size=64, use_bn=False):
|
||||
def __init__(self, block, layers, num_filters=64, num_classes=1000, input_img_size=64, number_skips=2, use_bn=False):
|
||||
super(FixupResNet, self).__init__()
|
||||
self.num_layers = sum(layers)
|
||||
self.inplanes = 3
|
||||
self.number_skips = number_skips
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
self.layer0 = self._make_layer(block, num_filters*2, layers[0], stride=2, use_bn=use_bn, conv_type=conv5x5)
|
||||
self.inplanes = self.inplanes + 3 # Accomodate a skip connection from the generator.
|
||||
if number_skips > 0:
|
||||
self.inplanes = self.inplanes + 3 # Accomodate a skip connection from the generator.
|
||||
self.layer1 = self._make_layer(block, num_filters*4, layers[1], stride=2, use_bn=use_bn, conv_type=conv5x5)
|
||||
self.inplanes = self.inplanes + 3 # Accomodate a skip connection from the generator.
|
||||
if number_skips > 1:
|
||||
self.inplanes = self.inplanes + 3 # Accomodate a second skip connection from the generator.
|
||||
self.layer2 = self._make_layer(block, num_filters*8, layers[2], stride=2, use_bn=use_bn)
|
||||
# SRGAN already has a feature loss tied to a separate VGG discriminator. We really don't care about features.
|
||||
# Therefore, level off the filter count from this block forwards.
|
||||
|
@ -157,9 +160,11 @@ class FixupResNet(nn.Module):
|
|||
x, med_skip, lo_skip = x
|
||||
|
||||
x = self.layer0(x)
|
||||
x = torch.cat([x, med_skip], dim=1)
|
||||
if self.number_skips > 0:
|
||||
x = torch.cat([x, med_skip], dim=1)
|
||||
x = self.layer1(x)
|
||||
x = torch.cat([x, lo_skip], dim=1)
|
||||
if self.number_skips > 1:
|
||||
x = torch.cat([x, lo_skip], dim=1)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
|
|
@ -177,13 +177,23 @@ class FixupResNetV2(FixupResNet):
|
|||
if self.upscale_applications > 0:
|
||||
x = F.interpolate(x, scale_factor=2.0, mode='nearest')
|
||||
x = self.layer2(x)
|
||||
skip_med = self.filter_to_image(x)
|
||||
|
||||
skip_med = self.filter_to_image(x)
|
||||
if self.upscale_applications > 1:
|
||||
x = F.interpolate(x, scale_factor=2.0, mode='nearest')
|
||||
x = self.layer2(x)
|
||||
|
||||
x = self.filter_to_image(x)
|
||||
if self.upscale_applications == 2:
|
||||
x = self.filter_to_image(x)
|
||||
elif self.upscale_applications == 1:
|
||||
x = skip_med
|
||||
skip_med = skip_lo
|
||||
skip_lo = None
|
||||
elif self.upscale_applications == 0:
|
||||
x = skip_lo
|
||||
skip_lo = None
|
||||
skip_med = None
|
||||
|
||||
return x, skip_med, skip_lo
|
||||
|
||||
def fixup_resnet34(nb_denoiser=20, nb_upsampler=10, **kwargs):
|
||||
|
|
|
@ -75,7 +75,8 @@ def define_D(opt):
|
|||
elif which_model == 'discriminator_resnet':
|
||||
netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
|
||||
elif which_model == 'discriminator_resnet_passthrough':
|
||||
netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz, use_bn=True)
|
||||
netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz,
|
||||
number_skips=opt_net['number_skips'], use_bn=True)
|
||||
else:
|
||||
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
||||
return netD
|
||||
|
|
|
@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
|
|||
def main():
|
||||
#### options
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/train_vix_resgenv2.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vix_resgenv2.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||
help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
|
@ -201,6 +201,7 @@ def main():
|
|||
model.test()
|
||||
|
||||
visuals = model.get_current_visuals()
|
||||
|
||||
sr_img = util.tensor2img(visuals['rlt']) # uint8
|
||||
gt_img = util.tensor2img(visuals['GT']) # uint8
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user