Update pixgan swap algorithm

- Swap multiple blocks in the image instead of just one. The discriminator was clearly
  learning that most images contain exactly one swapped region that needs to be fixed.
- Relax the block size constraints. The old constraints were there to guarantee that the
  discriminator signal was clean. Instead, just downsample the "loss image" with bilinear
  interpolation. The result is noisier, but this is probably healthy for the discriminator.
  A standalone sketch of the updated scheme follows.
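
For reference, a minimal standalone sketch of the updated scheme described above. This is an approximation, not the repo's code: the function name swap_blocks and the explicit align_corners=False are illustrative assumptions; in the repository the logic lives inside SRGANModel.optimize_parameters() and operates on the fake_H / var_ref batches.

import random

import torch
import torch.nn.functional as F

PIXDISC_CHANNELS = 3
PIXDISC_OUTPUT_REDUCTION = 8  # discriminator output is 1/8th of the input resolution


def swap_blocks(fake_batch, real_batch):
    # Swap several random rectangular blocks between the generated and reference
    # batches, and build per-pixel "loss images" (1 = real, 0 = fake) to match.
    b, _, w, h = real_batch.shape
    real_labels = torch.ones((b, PIXDISC_CHANNELS, w, h), device=real_batch.device)
    fake_labels = torch.zeros((b, PIXDISC_CHANNELS, w, h), device=real_batch.device)
    swap_min, swap_max = 16, w // 4
    for _ in range(random.randint(0, 4)):
        x, y = random.randint(0, w - swap_min), random.randint(0, h - swap_min)
        sw = min(random.randint(swap_min, swap_max), w - x)
        sh = min(random.randint(swap_min, swap_max), h - y)
        # Swap the image content across the two batches...
        t = fake_batch[:, :, x:x + sw, y:y + sh].clone()
        fake_batch[:, :, x:x + sw, y:y + sh] = real_batch[:, :, x:x + sw, y:y + sh]
        real_batch[:, :, x:x + sw, y:y + sh] = t
        # ...and flip the expected labels over the swapped region.
        real_labels[:, :, x:x + sw, y:y + sh] = 0.0
        fake_labels[:, :, x:x + sw, y:y + sh] = 1.0
    # Downsample the label maps to the discriminator's output resolution and
    # flatten them the same way the discriminator output is flattened.
    out_size = (w // PIXDISC_OUTPUT_REDUCTION, h // PIXDISC_OUTPUT_REDUCTION)
    real_labels = F.interpolate(real_labels, size=out_size, mode="bilinear", align_corners=False)
    fake_labels = F.interpolate(fake_labels, size=out_size, mode="bilinear", align_corners=False)
    return fake_batch, real_batch, real_labels.view(-1, 1), fake_labels.view(-1, 1)


# Example with 128x128 crops and a batch of two.
fake, real = torch.rand(2, 3, 128, 128), torch.rand(2, 3, 128, 128)
fake, real, real_targets, fake_targets = swap_blocks(fake, real)

Note that a swap count of zero is allowed, so the discriminator still sees plain all-real / all-fake targets some of the time.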
James Betker 2020-07-10 15:56:14 -06:00
parent 33ca3832e1
commit 812c684f7d
2 changed files with 21 additions and 14 deletions


@@ -338,31 +338,38 @@ class SRGANModel(BaseModel):
                 with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                     l_d_fake_scaled.backward()
             if self.opt['train']['gan_type'] == 'pixgan':
+                # randomly determine portions of the image to swap to keep the discriminator honest.
+                # We're making some assumptions about the underlying pixel-discriminator here. This is a
+                # necessary evil for now, but if this turns out well we might want to make this configurable.
                 PIXDISC_CHANNELS = 3
                 PIXDISC_OUTPUT_REDUCTION = 8
-                PIXDISC_MAX_REDUCTION = 32
                 disc_output_shape = (var_ref[0].shape[0], PIXDISC_CHANNELS, var_ref[0].shape[2] // PIXDISC_OUTPUT_REDUCTION, var_ref[0].shape[3] // PIXDISC_OUTPUT_REDUCTION)
-                real = torch.ones(disc_output_shape, device=var_ref[0].device)
-                fake = torch.zeros(disc_output_shape, device=var_ref[0].device)
-                # randomly determine portions of the image to swap to keep the discriminator honest.
-                if random.random() > .25:
+                b, _, w, h = var_ref[0].shape
+                real = torch.ones((b, PIXDISC_CHANNELS, w, h), device=var_ref[0].device)
+                fake = torch.zeros((b, PIXDISC_CHANNELS, w, h), device=var_ref[0].device)
+                SWAP_MAX_DIM = w // 4
+                SWAP_MIN_DIM = 16
+                assert SWAP_MAX_DIM > 0
+                random_swap_count = random.randint(0, 4)
+                for i in range(random_swap_count):
                     # Make the swap across fake_H and var_ref
-                    SWAP_MAX_DIM = var_ref[0].shape[2] // (2 * PIXDISC_MAX_REDUCTION)
-                    assert SWAP_MAX_DIM > 0
-                    swap_x, swap_y = random.randint(0, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(0, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION
-                    swap_w, swap_h = random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION
+                    swap_x, swap_y = random.randint(0, w - SWAP_MIN_DIM), random.randint(0, h - SWAP_MIN_DIM)
+                    swap_w, swap_h = random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM), random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM)
+                    if swap_x + swap_w > w:
+                        swap_w = w - swap_x
+                    if swap_y + swap_h > h:
+                        swap_h = h - swap_y
                     t = fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone()
                     fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)]
                     var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t
                     # Swap the expectation matrix too.
-                    swap_x, swap_y, swap_w, swap_h = swap_x // PIXDISC_OUTPUT_REDUCTION, swap_y // PIXDISC_OUTPUT_REDUCTION, swap_w // PIXDISC_OUTPUT_REDUCTION, swap_h // PIXDISC_OUTPUT_REDUCTION
                     real[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 0.0
                     fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0
+                # Interpolate down to the dimensionality that the discriminator uses.
+                real = F.interpolate(real, size=disc_output_shape[2:], mode="bilinear")
+                fake = F.interpolate(fake, size=disc_output_shape[2:], mode="bilinear")
+                # We're also assuming that this is exactly how the flattened discriminator output is generated.
                 real = real.view(-1, 1)
                 fake = fake.view(-1, 1)
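
One consequence of the bilinear downsampling worth noting: because swap boundaries no longer align to the discriminator's reduction factor, output cells whose sampling window straddles a block edge receive fractional targets rather than hard 0/1 values, which is presumably the extra noise the commit message refers to. A tiny standalone illustration (sizes chosen arbitrarily, not taken from the repo):

import torch
import torch.nn.functional as F

# A 1x1x32x32 "real" label map with a 16x16 swapped block zeroed out.
labels = torch.ones(1, 1, 32, 32)
labels[:, :, 4:20, 4:20] = 0.0

# Downsampling 8x with bilinear interpolation: cells whose 2x2 sampling support
# straddles the block edge get fractional targets (e.g. 0.5, 0.75) instead of a
# hard 0 or 1; cells away from the edge stay exactly 0 or 1.
down = F.interpolate(labels, size=(4, 4), mode="bilinear", align_corners=False)
print(down.squeeze())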


@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_div2k_pixgan_srg2.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg2.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)