diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index cc0094c6..2bfdbda3 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -338,31 +338,38 @@ class SRGANModel(BaseModel): with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled: l_d_fake_scaled.backward() if self.opt['train']['gan_type'] == 'pixgan': + # randomly determine portions of the image to swap to keep the discriminator honest. + # We're making some assumptions about the underlying pixel-discriminator here. This is a # necessary evil for now, but if this turns out well we might want to make this configurable. PIXDISC_CHANNELS = 3 PIXDISC_OUTPUT_REDUCTION = 8 - PIXDISC_MAX_REDUCTION = 32 disc_output_shape = (var_ref[0].shape[0], PIXDISC_CHANNELS, var_ref[0].shape[2] // PIXDISC_OUTPUT_REDUCTION, var_ref[0].shape[3] // PIXDISC_OUTPUT_REDUCTION) - real = torch.ones(disc_output_shape, device=var_ref[0].device) - fake = torch.zeros(disc_output_shape, device=var_ref[0].device) - - # randomly determine portions of the image to swap to keep the discriminator honest. - if random.random() > .25: + b, _, w, h = var_ref[0].shape + real = torch.ones((b, PIXDISC_CHANNELS, w, h), device=var_ref[0].device) + fake = torch.zeros((b, PIXDISC_CHANNELS, w, h), device=var_ref[0].device) + SWAP_MAX_DIM = w // 4 + SWAP_MIN_DIM = 16 + assert SWAP_MAX_DIM > 0 + random_swap_count = random.randint(0, 4) + for i in range(random_swap_count): # Make the swap across fake_H and var_ref - SWAP_MAX_DIM = var_ref[0].shape[2] // (2 * PIXDISC_MAX_REDUCTION) - assert SWAP_MAX_DIM > 0 - swap_x, swap_y = random.randint(0, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(0, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION - swap_w, swap_h = random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION + swap_x, swap_y = random.randint(0, w - SWAP_MIN_DIM), random.randint(0, h - SWAP_MIN_DIM) + swap_w, swap_h = random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM), random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM) + if swap_x + swap_w > w: + swap_w = w - swap_x + if swap_y + swap_h > h: + swap_h = h - swap_y t = fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone() fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t - - # Swap the expectation matrix too. - swap_x, swap_y, swap_w, swap_h = swap_x // PIXDISC_OUTPUT_REDUCTION, swap_y // PIXDISC_OUTPUT_REDUCTION, swap_w // PIXDISC_OUTPUT_REDUCTION, swap_h // PIXDISC_OUTPUT_REDUCTION real[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 0.0 fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0 + # Interpolate down to the dimensionality that the discriminator uses. + real = F.interpolate(real, size=disc_output_shape[2:], mode="bilinear") + fake = F.interpolate(fake, size=disc_output_shape[2:], mode="bilinear") + # We're also assuming that this is exactly how the flattened discriminator output is generated. real = real.view(-1, 1) fake = fake.view(-1, 1) diff --git a/codes/train.py b/codes/train.py index 9259c31a..4744ce26 100644 --- a/codes/train.py +++ b/codes/train.py @@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_div2k_pixgan_srg2.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg2.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0)