Fix skips & image samples

- Makes skip connections between the generator and discriminator more
  extensible by adding a configuration option (number_skips) for them and
  supporting 0 or 1 skips in addition to the default of 2; see the sketch
  below.
- Makes the temp/ directory with sample images from the training process
  appear in the training directory instead of the codes/ directory.
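- Example: a rough sketch of how the new option could be set in a training
  options YAML file. Only the number_skips and nf keys and the
  discriminator_resnet_passthrough model name are confirmed by this commit;
  the surrounding layout and key names are assumptions.

    network_D:
      which_model_D: discriminator_resnet_passthrough
      nf: 64
      number_skips: 1  # accept 0, 1, or 2 skip connections from the generator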
James Betker 2020-05-15 13:50:49 -06:00
parent cdf641e3e2
commit a33ec3e22b
5 changed files with 56 additions and 36 deletions

View File

@@ -214,10 +214,8 @@ class SRGANModel(BaseModel):
# the first element of the tuple.
if isinstance(fake_GenOut, tuple):
gen_img = fake_GenOut[0]
# TODO: Fix this.
self.fake_GenOut.append((fake_GenOut[0].detach(),
fake_GenOut[1].detach(),
fake_GenOut[2].detach()))
# The following line detaches all generator outputs that are not None.
self.fake_GenOut.append(tuple([(x.detach() if x is not None else None) for x in list(fake_GenOut)]))
var_ref = (var_ref,) + self.create_artificial_skips(var_H)
else:
gen_img = fake_GenOut
@@ -269,7 +267,8 @@ class SRGANModel(BaseModel):
# Re-compute generator outputs (post-update).
with torch.no_grad():
fake_H = self.netG(var_L)
fake_H = (fake_H[0].detach(), fake_H[1].detach(), fake_H[2].detach())
# The following line detaches all generator outputs that are not None.
fake_H = tuple([(x.detach() if x is not None else None) for x in list(fake_H)])
# Apply noise to the inputs to slow discriminator convergence.
var_ref = (var_ref[0] + noise,) + var_ref[1:]
@@ -306,35 +305,38 @@ class SRGANModel(BaseModel):
# Log sample images from first microbatch.
if step % 50 == 0:
os.makedirs("temp/hr", exist_ok=True)
os.makedirs("temp/lr", exist_ok=True)
os.makedirs("temp/lr_precorrupt", exist_ok=True)
os.makedirs("temp/gen", exist_ok=True)
os.makedirs("temp/pix", exist_ok=True)
sample_save_path = os.path.join(self.opt['path']['models'], "..", "temp")
os.makedirs(os.path.join(sample_save_path, "hr"), exist_ok=True)
os.makedirs(os.path.join(sample_save_path, "lr"), exist_ok=True)
os.makedirs(os.path.join(sample_save_path, "lr_precorrupt"), exist_ok=True)
os.makedirs(os.path.join(sample_save_path, "gen"), exist_ok=True)
os.makedirs(os.path.join(sample_save_path, "pix"), exist_ok=True)
multi_gen = False
if isinstance(self.fake_GenOut[0], tuple):
os.makedirs("temp/genlr", exist_ok=True)
os.makedirs("temp/genmr", exist_ok=True)
os.makedirs("temp/ref", exist_ok=True)
os.makedirs(os.path.join(sample_save_path, "genlr"), exist_ok=True)
os.makedirs(os.path.join(sample_save_path, "genmr"), exist_ok=True)
os.makedirs(os.path.join(sample_save_path, "ref"), exist_ok=True)
multi_gen = True
# fed_LQ is not chunked.
utils.save_image(self.fed_LQ.cpu().detach(), os.path.join("temp/lr_precorrupt", "%05i.png" % (step,)))
utils.save_image(self.fed_LQ.cpu().detach(), os.path.join(sample_save_path, "lr_precorrupt", "%05i.png" % (step,)))
for i in range(self.mega_batch_factor):
utils.save_image(self.var_H[i].cpu().detach(), os.path.join("temp/hr", "%05i_%02i.png" % (step, i)))
utils.save_image(self.var_L[i].cpu().detach(), os.path.join("temp/lr", "%05i_%02i.png" % (step, i)))
utils.save_image(self.pix[i].cpu().detach(), os.path.join("temp/pix", "%05i_%02i.png" % (step, i)))
utils.save_image(self.var_H[i].cpu().detach(), os.path.join(sample_save_path, "hr", "%05i_%02i.png" % (step, i)))
utils.save_image(self.var_L[i].cpu().detach(), os.path.join(sample_save_path, "lr", "%05i_%02i.png" % (step, i)))
utils.save_image(self.pix[i].cpu().detach(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i)))
if multi_gen:
utils.save_image(self.fake_GenOut[i][0].cpu().detach(), os.path.join("temp/gen", "%05i_%02i.png" % (step, i)))
utils.save_image(self.fake_GenOut[i][1].cpu().detach(), os.path.join("temp/genmr", "%05i_%02i.png" % (step, i)))
utils.save_image(self.fake_GenOut[i][2].cpu().detach(), os.path.join("temp/genlr", "%05i_%02i.png" % (step, i)))
utils.save_image(var_ref_skips[i][0].cpu().detach(), os.path.join("temp/ref", "hi_%05i_%02i.png" % (step, i)))
utils.save_image(var_ref_skips[i][1].cpu().detach(), os.path.join("temp/ref", "med_%05i_%02i.png" % (step, i)))
utils.save_image(var_ref_skips[i][2].cpu().detach(), os.path.join("temp/ref", "low_%05i_%02i.png" % (step, i)))
utils.save_image(self.fake_GenOut[i][0].cpu().detach(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
if self.fake_GenOut[i][1] is not None:
utils.save_image(self.fake_GenOut[i][1].cpu().detach(), os.path.join(sample_save_path, "genmr", "%05i_%02i.png" % (step, i)))
if self.fake_GenOut[i][2] is not None:
utils.save_image(self.fake_GenOut[i][2].cpu().detach(), os.path.join(sample_save_path, "genlr", "%05i_%02i.png" % (step, i)))
utils.save_image(var_ref_skips[i][0].cpu().detach(), os.path.join(sample_save_path, "ref", "hi_%05i_%02i.png" % (step, i)))
utils.save_image(var_ref_skips[i][1].cpu().detach(), os.path.join(sample_save_path, "ref", "med_%05i_%02i.png" % (step, i)))
utils.save_image(var_ref_skips[i][2].cpu().detach(), os.path.join(sample_save_path, "ref", "low_%05i_%02i.png" % (step, i)))
else:
utils.save_image(self.fake_GenOut[i].cpu().detach(), os.path.join("temp/gen", "%05i_%02i.png" % (step, i)))
utils.save_image(self.fake_GenOut[i].cpu().detach(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
# set log TODO(handle mega-batches?)
# Log metrics
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
if self.cri_pix:
self.add_log_entry('l_g_pix', l_g_pix.item())
@@ -346,6 +348,7 @@ class SRGANModel(BaseModel):
self.add_log_entry('l_d_real', l_d_real.item() * self.mega_batch_factor)
self.add_log_entry('l_d_fake', l_d_fake.item() * self.mega_batch_factor)
self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))
self.add_log_entry('D_diff', torch.mean(pred_d_fake) - torch.mean(pred_d_real))
self.add_log_entry('noise_theta', noise_theta)
if step % self.corruptor_swapout_steps == 0 and step > 0:
@@ -398,13 +401,13 @@ class SRGANModel(BaseModel):
self.swapout_G_duration -= 1
if self.swapout_G_duration == 0:
# Swap back.
print("Swapping back to current G model: %s" % (self.stashed_G,))
logger.info("Swapping back to current G model: %s" % (self.stashed_G,))
self.load_network(self.stashed_G, self.netG, self.opt['path']['strict_load'])
self.stashed_G = None
elif self.swapout_G_freq != 0 and step % self.swapout_G_freq == 0:
swapped_model = self.pick_rand_prev_model('G')
if swapped_model is not None:
print("Swapping to previous G model: %s" % (swapped_model,))
logger.info("Swapping to previous G model: %s" % (swapped_model,))
self.stashed_G = self.save_network(self.netG, 'G', 'swap_model')
self.load_network(swapped_model, self.netG, self.opt['path']['strict_load'])
self.swapout_G_duration = self.swapout_duration

View File

@@ -107,15 +107,18 @@ class FixupBottleneck(nn.Module):
class FixupResNet(nn.Module):
def __init__(self, block, layers, num_filters=64, num_classes=1000, input_img_size=64, use_bn=False):
def __init__(self, block, layers, num_filters=64, num_classes=1000, input_img_size=64, number_skips=2, use_bn=False):
super(FixupResNet, self).__init__()
self.num_layers = sum(layers)
self.inplanes = 3
self.number_skips = number_skips
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.layer0 = self._make_layer(block, num_filters*2, layers[0], stride=2, use_bn=use_bn, conv_type=conv5x5)
self.inplanes = self.inplanes + 3 # Accommodate a skip connection from the generator.
if number_skips > 0:
self.inplanes = self.inplanes + 3 # Accommodate a skip connection from the generator.
self.layer1 = self._make_layer(block, num_filters*4, layers[1], stride=2, use_bn=use_bn, conv_type=conv5x5)
self.inplanes = self.inplanes + 3 # Accommodate a skip connection from the generator.
if number_skips > 1:
self.inplanes = self.inplanes + 3 # Accommodate a second skip connection from the generator.
self.layer2 = self._make_layer(block, num_filters*8, layers[2], stride=2, use_bn=use_bn)
# SRGAN already has a feature loss tied to a separate VGG discriminator. We really don't care about features.
# Therefore, level off the filter count from this block forwards.
@@ -157,9 +160,11 @@ class FixupResNet(nn.Module):
x, med_skip, lo_skip = x
x = self.layer0(x)
x = torch.cat([x, med_skip], dim=1)
if self.number_skips > 0:
x = torch.cat([x, med_skip], dim=1)
x = self.layer1(x)
x = torch.cat([x, lo_skip], dim=1)
if self.number_skips > 1:
x = torch.cat([x, lo_skip], dim=1)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)

View File

@@ -177,13 +177,23 @@ class FixupResNetV2(FixupResNet):
if self.upscale_applications > 0:
x = F.interpolate(x, scale_factor=2.0, mode='nearest')
x = self.layer2(x)
skip_med = self.filter_to_image(x)
skip_med = self.filter_to_image(x)
if self.upscale_applications > 1:
x = F.interpolate(x, scale_factor=2.0, mode='nearest')
x = self.layer2(x)
x = self.filter_to_image(x)
if self.upscale_applications == 2:
x = self.filter_to_image(x)
elif self.upscale_applications == 1:
x = skip_med
skip_med = skip_lo
skip_lo = None
elif self.upscale_applications == 0:
x = skip_lo
skip_lo = None
skip_med = None
return x, skip_med, skip_lo
def fixup_resnet34(nb_denoiser=20, nb_upsampler=10, **kwargs):

View File

@@ -75,7 +75,8 @@ def define_D(opt):
elif which_model == 'discriminator_resnet':
netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
elif which_model == 'discriminator_resnet_passthrough':
netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz, use_bn=True)
netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz,
number_skips=opt_net['number_skips'], use_bn=True)
else:
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
return netD

View File

@@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
def main():
#### options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/train_vix_resgenv2.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vix_resgenv2.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
@@ -201,6 +201,7 @@ def main():
model.test()
visuals = model.get_current_visuals()
sr_img = util.tensor2img(visuals['rlt']) # uint8
gt_img = util.tensor2img(visuals['GT']) # uint8