From 812c684f7d2d3c0966e7c432b4d38f18bb87af3b Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 10 Jul 2020 15:56:14 -0600
Subject: [PATCH] Update pixgan swap algorithm

- Swap multiple blocks in the image instead of just one. The discriminator
  was clearly learning that most blocks have one region that needs to be
  fixed.
- Relax block size constraints. This was in place to guarantee that the
  discriminator signal was clean. Instead, just downsample the "loss image"
  with bilinear interpolation. The result is noisier, but this is actually
  probably healthy for the discriminator.
---
 codes/models/SRGAN_model.py | 33 ++++++++++++++++++++-------------
 codes/train.py              |  2 +-
 2 files changed, 21 insertions(+), 14 deletions(-)

diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index cc0094c6..2bfdbda3 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -338,31 +338,38 @@ class SRGANModel(BaseModel):
                 with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                     l_d_fake_scaled.backward()
             if self.opt['train']['gan_type'] == 'pixgan':
+                # randomly determine portions of the image to swap to keep the discriminator honest.
+
                 # We're making some assumptions about the underlying pixel-discriminator here. This is a
                 # necessary evil for now, but if this turns out well we might want to make this configurable.
                 PIXDISC_CHANNELS = 3
                 PIXDISC_OUTPUT_REDUCTION = 8
-                PIXDISC_MAX_REDUCTION = 32
                 disc_output_shape = (var_ref[0].shape[0], PIXDISC_CHANNELS, var_ref[0].shape[2] // PIXDISC_OUTPUT_REDUCTION, var_ref[0].shape[3] // PIXDISC_OUTPUT_REDUCTION)
-                real = torch.ones(disc_output_shape, device=var_ref[0].device)
-                fake = torch.zeros(disc_output_shape, device=var_ref[0].device)
-
-                # randomly determine portions of the image to swap to keep the discriminator honest.
-                if random.random() > .25:
+                b, _, w, h = var_ref[0].shape
+                real = torch.ones((b, PIXDISC_CHANNELS, w, h), device=var_ref[0].device)
+                fake = torch.zeros((b, PIXDISC_CHANNELS, w, h), device=var_ref[0].device)
+                SWAP_MAX_DIM = w // 4
+                SWAP_MIN_DIM = 16
+                assert SWAP_MAX_DIM > 0
+                random_swap_count = random.randint(0, 4)
+                for i in range(random_swap_count):
                     # Make the swap across fake_H and var_ref
-                    SWAP_MAX_DIM = var_ref[0].shape[2] // (2 * PIXDISC_MAX_REDUCTION)
-                    assert SWAP_MAX_DIM > 0
-                    swap_x, swap_y = random.randint(0, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(0, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION
-                    swap_w, swap_h = random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION
+                    swap_x, swap_y = random.randint(0, w - SWAP_MIN_DIM), random.randint(0, h - SWAP_MIN_DIM)
+                    swap_w, swap_h = random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM), random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM)
+                    if swap_x + swap_w > w:
+                        swap_w = w - swap_x
+                    if swap_y + swap_h > h:
+                        swap_h = h - swap_y
                     t = fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone()
                     fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)]
                     var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t
-
-                    # Swap the expectation matrix too.
-                    swap_x, swap_y, swap_w, swap_h = swap_x // PIXDISC_OUTPUT_REDUCTION, swap_y // PIXDISC_OUTPUT_REDUCTION, swap_w // PIXDISC_OUTPUT_REDUCTION, swap_h // PIXDISC_OUTPUT_REDUCTION
                     real[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 0.0
                     fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0
 
+                # Interpolate down to the dimensionality that the discriminator uses.
+                real = F.interpolate(real, size=disc_output_shape[2:], mode="bilinear")
+                fake = F.interpolate(fake, size=disc_output_shape[2:], mode="bilinear")
+
                 # We're also assuming that this is exactly how the flattened discriminator output is generated.
                 real = real.view(-1, 1)
                 fake = fake.view(-1, 1)
diff --git a/codes/train.py b/codes/train.py
index 9259c31a..4744ce26 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_div2k_pixgan_srg2.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg2.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
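
For readers who want the new swap scheme in isolation, below is a minimal standalone
sketch of the logic the SRGAN_model.py hunk applies (the patch itself operates on
fake_H[0] and var_ref[0] inside SRGANModel). It assumes NCHW torch tensors fake_H and
var_ref of identical shape and a pixel discriminator whose output grid is 8x smaller
than its input; the function name pixgan_swap and its parameters are illustrative and
not part of the repository.

    import random

    import torch
    import torch.nn.functional as F

    def pixgan_swap(fake_H, var_ref, output_reduction=8, label_channels=3):
        # Full-resolution "loss images": 1 marks real content, 0 marks fake content.
        b, _, w, h = var_ref.shape
        real = torch.ones((b, label_channels, w, h), device=var_ref.device)
        fake = torch.zeros((b, label_channels, w, h), device=var_ref.device)
        swap_min, swap_max = 16, w // 4
        assert swap_max >= swap_min  # inputs smaller than 64px are not handled here
        # Swap up to four random blocks between the generated and reference images.
        for _ in range(random.randint(0, 4)):
            x, y = random.randint(0, w - swap_min), random.randint(0, h - swap_min)
            bw = min(random.randint(swap_min, swap_max), w - x)
            bh = min(random.randint(swap_min, swap_max), h - y)
            t = fake_H[:, :, x:x + bw, y:y + bh].clone()
            fake_H[:, :, x:x + bw, y:y + bh] = var_ref[:, :, x:x + bw, y:y + bh]
            var_ref[:, :, x:x + bw, y:y + bh] = t
            # Flip the matching region of the expectation maps.
            real[:, :, x:x + bw, y:y + bh] = 0.0
            fake[:, :, x:x + bw, y:y + bh] = 1.0
        # Downsample the label maps to the discriminator's output resolution; the soft
        # block edges this produces are the "noisier but probably healthy" signal the
        # commit message refers to.
        size = (w // output_reduction, h // output_reduction)
        real = F.interpolate(real, size=size, mode="bilinear", align_corners=False)
        fake = F.interpolate(fake, size=size, mode="bilinear", align_corners=False)
        return fake_H, var_ref, real.view(-1, 1), fake.view(-1, 1)

For example, with 128-pixel square inputs each call swaps zero to four blocks between
16 and 32 pixels on a side and yields label maps of shape (b, 3, 16, 16), flattened to
(b*3*16*16, 1) to match the discriminator's flattened output.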