Bug fixes and a new GAN mechanism

- Removed a bunch of unnecessary image loggers. These were just consuming space and never being viewed.
- Removed support for artificial var_ref skips. The new pixdisc is what I wanted to implement back then - it's much better.
- Added the pixgan GAN mechanism. It is purpose-built for the pixdisc and is intended to promote a healthy discriminator.
- mega_batch_factor was applied twice to the logged metrics; fixed that (a small sketch of the intended scaling pattern follows this list).
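For context, the mega-batch pattern divides each chunk's loss by mega_batch_factor before backward() (so gradients accumulated over all chunks sum to one real batch's worth) and multiplies it back exactly once for logging. A minimal sketch of the intended arithmetic, with illustrative values:

import torch

mega_batch_factor = 2
raw_loss = torch.tensor(0.5)                 # loss computed on one chunk
l_d_real = raw_loss / mega_batch_factor      # scaled copy used for backward()
l_d_real_log = l_d_real * mega_batch_factor  # un-scaled once, for logging only
# The bug: the logging call multiplied by mega_batch_factor a second time,
# reporting 1.0 here instead of 0.5.
assert l_d_real_log.item() == raw_loss.item()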

Adds pixgan (untested), which swaps a random portion of the fake and real images with each other, then expects the discriminator to correctly classify the swapped regions.
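For illustration only, here is a minimal, self-contained sketch of that swap idea. The function name swap_regions, the reduction factors, and the shapes are assumptions made for this sketch, not this repo's API; it also assumes square inputs at least 4x the swap granularity, mirroring the assert in the trainer below.

import random
import torch

def swap_regions(fake, real, out_reduction=8, granularity=32):
    # fake/real: (B, C, H, W) image batches. Swaps one aligned patch between
    # them and returns per-pixel target maps at the discriminator's output
    # resolution, so the labels follow the swapped pixels.
    b, _, h, w = real.shape
    real_tgt = torch.ones(b, 1, h // out_reduction, w // out_reduction, device=real.device)
    fake_tgt = torch.zeros_like(real_tgt)
    max_dim = h // (2 * granularity) - 1
    x = random.randint(0, max_dim) * granularity
    y = random.randint(0, max_dim) * granularity
    sw = random.randint(1, max_dim) * granularity
    sh = random.randint(1, max_dim) * granularity
    # Swap the image patch between the two batches.
    t = fake[:, :, y:y + sh, x:x + sw].clone()
    fake[:, :, y:y + sh, x:x + sw] = real[:, :, y:y + sh, x:x + sw]
    real[:, :, y:y + sh, x:x + sw] = t
    # Flip the matching cells of the target maps.
    x, y, sw, sh = x // out_reduction, y // out_reduction, sw // out_reduction, sh // out_reduction
    real_tgt[:, :, y:y + sh, x:x + sw] = 0.0
    fake_tgt[:, :, y:y + sh, x:x + sw] = 1.0
    return fake, real, fake_tgt, real_tgt

Because the labels move with the pixels, the discriminator can no longer score a whole image as real or fake; it has to localize its judgment. The target maps can then be flattened with .view(-1, 1) to match a discriminator that flattens its per-pixel logits the same way.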
James Betker 2020-07-08 17:40:26 -06:00
parent 4305be97b4
commit 26a4a66d1c
5 changed files with 77 additions and 55 deletions

View File

@@ -20,8 +20,8 @@ def main():
     # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
     # compression time. If read raw images during training, use 0 for faster IO speed.
     if mode == 'single':
-        opt['input_folder'] = 'F:\\4k6k\\datasets\\div2k\\DIV2K_train_HR'
-        opt['save_folder'] = 'F:\\4k6k\\datasets\\div2k\\tiled1024'
+        opt['input_folder'] = 'F:\\4k6k\\datasets\\flickr2k\\Flickr2K_HR'
+        opt['save_folder'] = 'F:\\4k6k\\datasets\\flickr2k\\1024px'
         opt['crop_sz'] = 1024  # the size of each sub-image
         opt['step'] = 880  # step of the sliding crop window
         opt['thres_sz'] = 240  # size threshold

View File

@@ -227,14 +227,10 @@ class SRGANModel(BaseModel):
             _t = time()

         self.fake_GenOut = []
+        self.fake_H = []
         var_ref_skips = []
         for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
-            #from utils import gpu_mem_track
-            #import inspect
-            #gpu_tracker = gpu_mem_track.MemTracker(inspect.currentframe())
-            #gpu_tracker.track()
             fake_GenOut = self.netG(var_L)
-            #gpu_tracker.track()

             if _profile:
                 print("Gen forward %f" % (time() - _t,))
@@ -246,11 +242,11 @@ class SRGANModel(BaseModel):
                 gen_img = fake_GenOut[0]
                 # The following line detaches all generator outputs that are not None.
                 self.fake_GenOut.append(tuple([(x.detach() if x is not None else None) for x in list(fake_GenOut)]))
-                var_ref = (var_ref,) + self.create_artificial_skips(var_H)
+                var_ref = (var_ref,)  # This is a tuple for legacy reasons.
             else:
                 gen_img = fake_GenOut
                 self.fake_GenOut.append(fake_GenOut.detach())
-            var_ref_skips.append(var_ref)
+            var_ref_skips.append(var_ref[0].detach())
             l_g_total = 0
             if step % self.D_update_ratio == 0 and step > self.D_init_iters:
@@ -324,8 +320,9 @@ class SRGANModel(BaseModel):
                     _t = time()

                 # Apply noise to the inputs to slow discriminator convergence.
-                var_ref = (var_ref[0] + noise,) + var_ref[1:]
+                var_ref = (var_ref[0] + noise,)
                 fake_H = (fake_H[0] + noise,) + fake_H[1:]
+                self.fake_H.append(fake_H[0].detach())
                 if self.opt['train']['gan_type'] == 'gan':
                     # need to forward and backward separately, since batch norm statistics differ
                     # real
@@ -340,13 +337,50 @@ class SRGANModel(BaseModel):
                     l_d_fake_log = l_d_fake * self.mega_batch_factor
                     with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                         l_d_fake_scaled.backward()

+                if self.opt['train']['gan_type'] == 'pixgan':
+                    # We're making some assumptions about the underlying pixel-discriminator here. This is a
+                    # necessary evil for now, but if this turns out well we might want to make this configurable.
+                    PIXDISC_CHANNELS = 3
+                    PIXDISC_OUTPUT_REDUCTION = 8
+                    PIXDISC_MAX_REDUCTION = 32
+                    disc_output_shape = (var_ref[0].shape[0], PIXDISC_CHANNELS, var_ref[0].shape[2] // PIXDISC_OUTPUT_REDUCTION, var_ref[0].shape[3] // PIXDISC_OUTPUT_REDUCTION)
+                    real = torch.ones(disc_output_shape)
+                    fake = torch.zeros(disc_output_shape)
+
+                    # randomly determine portions of the image to swap to keep the discriminator honest.
+                    if random.random() > .25:
+                        # Make the swap across fake_H and var_ref
+                        SWAP_MAX_DIM = var_ref[0].shape[2] // (2 * PIXDISC_MAX_REDUCTION) - 1
+                        assert SWAP_MAX_DIM > 0
+                        swap_x, swap_y = random.randint(0, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(0, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION
+                        swap_w, swap_h = random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION
+                        t = fake_H[:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h].clone()
+                        fake_H[:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h] = var_ref[0][:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h]
+                        var_ref[0][:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h] = t
+
+                        # Swap the expectation matrix too.
+                        swap_x, swap_y, swap_w, swap_h = swap_x // PIXDISC_OUTPUT_REDUCTION, swap_y // PIXDISC_OUTPUT_REDUCTION, swap_w // PIXDISC_OUTPUT_REDUCTION, swap_h // PIXDISC_OUTPUT_REDUCTION
+                        real[:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h] = 0.0
+                        fake[:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h] = 1.0
+
+                    # We're also assuming that this is exactly how the flattened discriminator output is generated.
+                    real = real.view(-1, 1)
+                    fake = fake.view(-1, 1)
+
+                    # real
+                    pred_d_real = self.netD(var_ref)
+                    l_d_real = self.cri_gan(pred_d_real, real) / self.mega_batch_factor
+                    l_d_real_log = l_d_real * self.mega_batch_factor
+                    with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
+                        l_d_real_scaled.backward()
+
+                    # fake
+                    pred_d_fake = self.netD(fake_H)
+                    l_d_fake = self.cri_gan(pred_d_fake, fake) / self.mega_batch_factor
+                    l_d_fake_log = l_d_fake * self.mega_batch_factor
+                    with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
+                        l_d_fake_scaled.backward()
                 elif self.opt['train']['gan_type'] == 'ragan':
-                    # pred_d_real = self.netD(var_ref)
-                    # pred_d_fake = self.netD(fake_H.detach())  # detach to avoid BP to G
-                    # l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
-                    # l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
-                    # l_d_total = (l_d_real + l_d_fake) / 2
-                    # l_d_total.backward()
                     pred_d_fake = self.netD(fake_H).detach()
                     pred_d_real = self.netD(var_ref)
@@ -383,34 +417,26 @@ class SRGANModel(BaseModel):
             sample_save_path = os.path.join(self.opt['path']['models'], "..", "temp")
             os.makedirs(os.path.join(sample_save_path, "hr"), exist_ok=True)
             os.makedirs(os.path.join(sample_save_path, "lr"), exist_ok=True)
-            os.makedirs(os.path.join(sample_save_path, "lr_precorrupt"), exist_ok=True)
             os.makedirs(os.path.join(sample_save_path, "gen"), exist_ok=True)
+            os.makedirs(os.path.join(sample_save_path, "disc_fake"), exist_ok=True)
             os.makedirs(os.path.join(sample_save_path, "pix"), exist_ok=True)
             multi_gen = False
             if isinstance(self.fake_GenOut[0], tuple):
-                os.makedirs(os.path.join(sample_save_path, "genlr"), exist_ok=True)
-                os.makedirs(os.path.join(sample_save_path, "genmr"), exist_ok=True)
                 os.makedirs(os.path.join(sample_save_path, "ref"), exist_ok=True)
                 multi_gen = True

-            # fed_LQ is not chunked.
-            utils.save_image(self.fed_LQ.cpu().detach(), os.path.join(sample_save_path, "lr_precorrupt", "%05i.png" % (step,)))
             for i in range(self.mega_batch_factor):
-                utils.save_image(self.var_H[i].cpu().detach(), os.path.join(sample_save_path, "hr", "%05i_%02i.png" % (step, i)))
-                utils.save_image(self.var_L[i].cpu().detach(), os.path.join(sample_save_path, "lr", "%05i_%02i.png" % (step, i)))
-                utils.save_image(self.pix[i].cpu().detach(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i)))
+                utils.save_image(self.var_H[i].cpu(), os.path.join(sample_save_path, "hr", "%05i_%02i.png" % (step, i)))
+                utils.save_image(self.var_L[i].cpu(), os.path.join(sample_save_path, "lr", "%05i_%02i.png" % (step, i)))
+                utils.save_image(self.pix[i].cpu(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i)))
                 if multi_gen:
-                    utils.save_image(self.fake_GenOut[i][0].cpu().detach(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
-                    if len(self.fake_GenOut[i]) > 1:
-                        if self.fake_GenOut[i][1] is not None:
-                            utils.save_image(self.fake_GenOut[i][1].cpu().detach(), os.path.join(sample_save_path, "genmr", "%05i_%02i.png" % (step, i)))
-                        if self.fake_GenOut[i][2] is not None:
-                            utils.save_image(self.fake_GenOut[i][2].cpu().detach(), os.path.join(sample_save_path, "genlr", "%05i_%02i.png" % (step, i)))
-                    utils.save_image(var_ref_skips[i][0].cpu().detach(), os.path.join(sample_save_path, "ref", "hi_%05i_%02i.png" % (step, i)))
-                    utils.save_image(var_ref_skips[i][1].cpu().detach(), os.path.join(sample_save_path, "ref", "med_%05i_%02i.png" % (step, i)))
-                    utils.save_image(var_ref_skips[i][2].cpu().detach(), os.path.join(sample_save_path, "ref", "low_%05i_%02i.png" % (step, i)))
+                    utils.save_image(self.fake_GenOut[i][0].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
+                    utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i)))
+                    if step % self.D_update_ratio == 0 and step > self.D_init_iters:
+                        utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "%05i_%02i.png" % (step, i)))
                 else:
-                    utils.save_image(self.fake_GenOut[i].cpu().detach(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
+                    utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))

         # Log metrics
         if step % self.D_update_ratio == 0 and step > self.D_init_iters:
@@ -421,10 +447,10 @@ class SRGANModel(BaseModel):
                 self.add_log_entry('l_g_fea', l_g_fea_log.item())
                 if self.l_gan_w > 0:
                     self.add_log_entry('l_g_gan', l_g_gan_log.item())
-            self.add_log_entry('l_g_total', l_g_total_log.item() * self.mega_batch_factor)
+            self.add_log_entry('l_g_total', l_g_total_log.item())

             if self.l_gan_w > 0 and step > self.G_warmup:
-                self.add_log_entry('l_d_real', l_d_real_log.item() * self.mega_batch_factor)
-                self.add_log_entry('l_d_fake', l_d_fake_log.item() * self.mega_batch_factor)
+                self.add_log_entry('l_d_real', l_d_real_log.item())
+                self.add_log_entry('l_d_fake', l_d_fake_log.item())
                 self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))
                 self.add_log_entry('D_diff', torch.mean(pred_d_fake) - torch.mean(pred_d_real))
@@ -444,12 +470,6 @@ class SRGANModel(BaseModel):
             self.log_dict[key][self.log_dict[key_it] % log_rotating_buffer_size] = value
             self.log_dict[key_it] += 1

-    def create_artificial_skips(self, truth_img):
-        med_skip = F.interpolate(truth_img, scale_factor=.5)
-        lo_skip = F.interpolate(truth_img, scale_factor=.25)
-        return med_skip, lo_skip
-
     def pick_rand_prev_model(self, model_suffix):
         previous_models = glob.glob(os.path.join(self.opt['path']['models'], "*_%s.pth" % (model_suffix,)))
         if len(previous_models) <= 1:

View File

@@ -126,7 +126,7 @@ class Discriminator_VGG_PixLoss(nn.Module):
         # activation function
         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

-    def forward(self, x):
+    def forward(self, x, flatten=True):
         x = x[0]
         fea0 = self.lrelu(self.conv0_0(x))
         fea0 = self.lrelu(self.bn0_1(self.conv0_1(fea0)))
@@ -144,9 +144,9 @@ class Discriminator_VGG_PixLoss(nn.Module):
         fea4 = self.lrelu(self.bn4_1(self.conv4_1(fea4)))

         loss = self.reduce_1(fea4)
-        # Compress all of the loss values into the batch dimension. The actual loss attached to this output will
-        # then know how to handle them.
-        loss = self.pix_loss_collapse(loss).view(-1, 1)
+        # "Weight" all losses the same by interpolating them to the highest dimension.
+        loss = self.pix_loss_collapse(loss)
+        loss = F.interpolate(loss, scale_factor=4, mode="nearest")

         # And the pyramid network!
         dec3 = self.up3_decimate(F.interpolate(fea4, scale_factor=2, mode="nearest"))
@@ -154,17 +154,17 @@ class Discriminator_VGG_PixLoss(nn.Module):
         dec3 = self.up3_converge(dec3)
         dec3 = self.up3_proc(dec3)
         loss3 = self.up3_reduce(dec3)
-        loss3 = self.up3_pix(loss3).view(-1, 1)
+        loss3 = self.up3_pix(loss3)
+        loss3 = F.interpolate(loss3, scale_factor=2, mode="nearest")

         dec2 = self.up2_decimate(F.interpolate(dec3, scale_factor=2, mode="nearest"))
         dec2 = torch.cat([dec2, fea2], dim=1)
         dec2 = self.up2_converge(dec2)
         dec2 = self.up2_proc(dec2)
         loss2 = self.up2_reduce(dec2)
-        loss2 = self.up2_pix(loss2).view(-1, 1)
+        loss2 = self.up2_pix(loss2)

-        # "Weight" all losses the same by repeating the LR losses to the HR dim.
-        combined_losses = torch.cat([loss.repeat((16, 1)), loss3.repeat((4, 1)), loss2])
+        # Compress all of the loss values into the batch dimension. The actual loss attached to this output will
+        # then know how to handle them.
+        combined_losses = torch.cat([loss, loss3, loss2], dim=1)
         return combined_losses.view(-1, 1)
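To see why the interpolate-and-cat change equalizes how the three heads are weighted, here is an illustrative shape walk-through. It assumes a 128px input and one channel per collapsed loss head; both are assumptions for this sketch, though one channel per head is consistent with the PIXDISC_CHANNELS = 3 constant assumed in the trainer above.

import torch
import torch.nn.functional as F

B = 4  # batch size; all sizes here are illustrative
loss  = torch.randn(B, 1, 4, 4)    # deepest head, input/32
loss3 = torch.randn(B, 1, 8, 8)    # pyramid head, input/16
loss2 = torch.randn(B, 1, 16, 16)  # pyramid head, input/8

loss  = F.interpolate(loss,  scale_factor=4, mode="nearest")   # (B, 1, 16, 16)
loss3 = F.interpolate(loss3, scale_factor=2, mode="nearest")   # (B, 1, 16, 16)
# Each head now contributes the same number of per-pixel logits, instead of
# the deepest head being up-weighted 16x and the middle head 4x by .repeat().
combined = torch.cat([loss, loss3, loss2], dim=1)              # (B, 3, 16, 16)
flat = combined.view(-1, 1)                                    # (B*3*16*16, 1)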

View File

@@ -23,7 +23,7 @@ class GANLoss(nn.Module):
         self.real_label_val = real_label_val
         self.fake_label_val = fake_label_val

-        if self.gan_type == 'gan' or self.gan_type == 'ragan':
+        if self.gan_type == 'gan' or self.gan_type == 'ragan' or self.gan_type == 'pixgan':
             self.loss = nn.BCEWithLogitsLoss()
         elif self.gan_type == 'lsgan':
             self.loss = nn.MSELoss()
@@ -46,7 +46,10 @@ class GANLoss(nn.Module):
             return torch.empty_like(input).fill_(self.fake_label_val)

     def forward(self, input, target_is_real):
-        target_label = self.get_target_label(input, target_is_real)
+        if self.gan_type == 'pixgan':
+            target_label = target_is_real
+        else:
+            target_label = self.get_target_label(input, target_is_real)
         loss = self.loss(input, target_label)
         return loss
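The net effect: for 'pixgan', target_is_real is no longer a bool but the dense per-pixel label tensor built by the trainer's swap logic, passed straight to BCEWithLogitsLoss. A toy demonstration (shapes and values illustrative):

import torch
import torch.nn as nn

bce = nn.BCEWithLogitsLoss()
pred = torch.randn(6, 1)   # flattened per-pixel discriminator logits
target = torch.ones(6, 1)  # 'real' everywhere...
target[2:4] = 0.0          # ...except cells covering a swapped-in fake patch
print(bce(pred, target))   # elementwise BCE, averaged over all cells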

View File

@@ -29,11 +29,10 @@ def init_dist(backend='nccl', **kwargs):
     torch.cuda.set_device(rank % num_gpus)
     dist.init_process_group(backend=backend, **kwargs)

 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_div2k_srg3.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_div2k_rrdb_pixgan_normal_gan.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)