Update pixgan swap algorithm

- Swap multiple blocks in the image instead of just one. The discriminator was clearly
  learning that most blocks have one region that needs to be fixed.
- Relax block size constraints. These were in place to guarantee that the discriminator
  signal was clean. Instead, just downsample the "loss image" with bilinear interpolation.
  The result is noisier, but this is actually probably healthy for the discriminator.
This commit is contained in:
James Betker 2020-07-10 15:56:14 -06:00
parent 33ca3832e1
commit 812c684f7d
2 changed files with 21 additions and 14 deletions

View File

@ -338,31 +338,38 @@ class SRGANModel(BaseModel):
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled: with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
l_d_fake_scaled.backward() l_d_fake_scaled.backward()
if self.opt['train']['gan_type'] == 'pixgan': if self.opt['train']['gan_type'] == 'pixgan':
# randomly determine portions of the image to swap to keep the discriminator honest.
# We're making some assumptions about the underlying pixel-discriminator here. This is a # We're making some assumptions about the underlying pixel-discriminator here. This is a
# necessary evil for now, but if this turns out well we might want to make this configurable. # necessary evil for now, but if this turns out well we might want to make this configurable.
PIXDISC_CHANNELS = 3 PIXDISC_CHANNELS = 3
PIXDISC_OUTPUT_REDUCTION = 8 PIXDISC_OUTPUT_REDUCTION = 8
PIXDISC_MAX_REDUCTION = 32
disc_output_shape = (var_ref[0].shape[0], PIXDISC_CHANNELS, var_ref[0].shape[2] // PIXDISC_OUTPUT_REDUCTION, var_ref[0].shape[3] // PIXDISC_OUTPUT_REDUCTION) disc_output_shape = (var_ref[0].shape[0], PIXDISC_CHANNELS, var_ref[0].shape[2] // PIXDISC_OUTPUT_REDUCTION, var_ref[0].shape[3] // PIXDISC_OUTPUT_REDUCTION)
real = torch.ones(disc_output_shape, device=var_ref[0].device) b, _, w, h = var_ref[0].shape
fake = torch.zeros(disc_output_shape, device=var_ref[0].device) real = torch.ones((b, PIXDISC_CHANNELS, w, h), device=var_ref[0].device)
fake = torch.zeros((b, PIXDISC_CHANNELS, w, h), device=var_ref[0].device)
# randomly determine portions of the image to swap to keep the discriminator honest. SWAP_MAX_DIM = w // 4
if random.random() > .25: SWAP_MIN_DIM = 16
assert SWAP_MAX_DIM > 0
random_swap_count = random.randint(0, 4)
for i in range(random_swap_count):
# Make the swap across fake_H and var_ref # Make the swap across fake_H and var_ref
SWAP_MAX_DIM = var_ref[0].shape[2] // (2 * PIXDISC_MAX_REDUCTION) swap_x, swap_y = random.randint(0, w - SWAP_MIN_DIM), random.randint(0, h - SWAP_MIN_DIM)
assert SWAP_MAX_DIM > 0 swap_w, swap_h = random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM), random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM)
swap_x, swap_y = random.randint(0, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(0, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION if swap_x + swap_w > w:
swap_w, swap_h = random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION swap_w = w - swap_x
if swap_y + swap_h > h:
swap_h = h - swap_y
t = fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone() t = fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone()
fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)]
var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t
# Swap the expectation matrix too.
swap_x, swap_y, swap_w, swap_h = swap_x // PIXDISC_OUTPUT_REDUCTION, swap_y // PIXDISC_OUTPUT_REDUCTION, swap_w // PIXDISC_OUTPUT_REDUCTION, swap_h // PIXDISC_OUTPUT_REDUCTION
real[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 0.0 real[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 0.0
fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0 fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0
# Interpolate down to the dimensionality that the discriminator uses.
real = F.interpolate(real, size=disc_output_shape[2:], mode="bilinear")
fake = F.interpolate(fake, size=disc_output_shape[2:], mode="bilinear")
# We're also assuming that this is exactly how the flattened discriminator output is generated. # We're also assuming that this is exactly how the flattened discriminator output is generated.
real = real.view(-1, 1) real = real.view(-1, 1)
fake = fake.view(-1, 1) fake = fake.view(-1, 1)

View File

@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
def main(): def main():
#### options #### options
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_div2k_pixgan_srg2.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg2.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher') help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)