forked from mrq/DL-Art-School
Update pixgan swap algorithm
- Swap multiple blocks in the image instead of just one. The discriminator was clearly learning that most blocks have one region that needs to be fixed.
- Relax the block size constraints. These were in place to guarantee that the discriminator signal was clean. Instead, just downsample the "loss image" with bilinear interpolation. The result is noisier, but that is probably healthy for the discriminator (rough sketches of both changes are included below).
parent 33ca3832e1
commit 812c684f7d
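Taken together, the two changes boil down to: pick up to four random rectangles, swap their pixels between the generated batch and the reference batch, and flip the swapped regions in full-resolution target maps. Below is a minimal sketch of that scheme, assuming NCHW tensors; swap_random_blocks_ is a hypothetical helper, and only the 16px minimum, the w // 4 maximum and the up-to-four-swaps count are taken from the diff that follows.

import random
import torch

def swap_random_blocks_(fake_H, var_ref, swap_min=16, max_swaps=4):
    # Swap up to `max_swaps` random rectangles between fake_H and var_ref (both
    # NCHW tensors, modified in place) and return full-resolution "loss images"
    # marking which pixels now come from the other batch.
    b, _, w, h = var_ref.shape
    real = torch.ones((b, 3, w, h), device=var_ref.device)   # targets for the reference batch
    fake = torch.zeros((b, 3, w, h), device=var_ref.device)  # targets for the generated batch
    swap_max = w // 4
    assert swap_max >= swap_min  # assumes training crops of at least 64px
    for _ in range(random.randint(0, max_swaps)):
        # Pick a random rectangle and clamp it so it stays inside the image.
        x, y = random.randint(0, w - swap_min), random.randint(0, h - swap_min)
        bw = min(random.randint(swap_min, swap_max), w - x)
        bh = min(random.randint(swap_min, swap_max), h - y)
        # Swap the pixel block across the two batches...
        t = fake_H[:, :, x:x+bw, y:y+bh].clone()
        fake_H[:, :, x:x+bw, y:y+bh] = var_ref[:, :, x:x+bw, y:y+bh]
        var_ref[:, :, x:x+bw, y:y+bh] = t
        # ...and flip the target maps over the same region.
        real[:, :, x:x+bw, y:y+bh] = 0.0
        fake[:, :, x:x+bw, y:y+bh] = 1.0
    return real, fake

Nothing snaps the rectangles to the discriminator's output grid any more; that is what the bilinear downsampling of the target maps (see the sketch after the first hunk) compensates for.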
@@ -338,31 +338,38 @@ class SRGANModel(BaseModel):
             with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                 l_d_fake_scaled.backward()
             if self.opt['train']['gan_type'] == 'pixgan':
+                # randomly determine portions of the image to swap to keep the discriminator honest.
                 # We're making some assumptions about the underlying pixel-discriminator here. This is a
                 # necessary evil for now, but if this turns out well we might want to make this configurable.
                 PIXDISC_CHANNELS = 3
                 PIXDISC_OUTPUT_REDUCTION = 8
-                PIXDISC_MAX_REDUCTION = 32
                 disc_output_shape = (var_ref[0].shape[0], PIXDISC_CHANNELS, var_ref[0].shape[2] // PIXDISC_OUTPUT_REDUCTION, var_ref[0].shape[3] // PIXDISC_OUTPUT_REDUCTION)
-                real = torch.ones(disc_output_shape, device=var_ref[0].device)
-                fake = torch.zeros(disc_output_shape, device=var_ref[0].device)
-
-                # randomly determine portions of the image to swap to keep the discriminator honest.
-                if random.random() > .25:
-                    # Make the swap across fake_H and var_ref
-                    SWAP_MAX_DIM = var_ref[0].shape[2] // (2 * PIXDISC_MAX_REDUCTION)
-                    assert SWAP_MAX_DIM > 0
-                    swap_x, swap_y = random.randint(0, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(0, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION
-                    swap_w, swap_h = random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION
+                b, _, w, h = var_ref[0].shape
+                real = torch.ones((b, PIXDISC_CHANNELS, w, h), device=var_ref[0].device)
+                fake = torch.zeros((b, PIXDISC_CHANNELS, w, h), device=var_ref[0].device)
+                SWAP_MAX_DIM = w // 4
+                SWAP_MIN_DIM = 16
+                assert SWAP_MAX_DIM > 0
+                random_swap_count = random.randint(0, 4)
+                for i in range(random_swap_count):
+                    # Make the swap across fake_H and var_ref
+                    swap_x, swap_y = random.randint(0, w - SWAP_MIN_DIM), random.randint(0, h - SWAP_MIN_DIM)
+                    swap_w, swap_h = random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM), random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM)
+                    if swap_x + swap_w > w:
+                        swap_w = w - swap_x
+                    if swap_y + swap_h > h:
+                        swap_h = h - swap_y
                     t = fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone()
                     fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)]
                     var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t
 
-                    # Swap the expectation matrix too.
-                    swap_x, swap_y, swap_w, swap_h = swap_x // PIXDISC_OUTPUT_REDUCTION, swap_y // PIXDISC_OUTPUT_REDUCTION, swap_w // PIXDISC_OUTPUT_REDUCTION, swap_h // PIXDISC_OUTPUT_REDUCTION
                     real[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 0.0
                     fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0
 
+                # Interpolate down to the dimensionality that the discriminator uses.
+                real = F.interpolate(real, size=disc_output_shape[2:], mode="bilinear")
+                fake = F.interpolate(fake, size=disc_output_shape[2:], mode="bilinear")
+
                 # We're also assuming that this is exactly how the flattened discriminator output is generated.
                 real = real.view(-1, 1)
                 fake = fake.view(-1, 1)
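The relaxed block-size constraint works because the full-resolution target maps are simply interpolated down to whatever resolution the discriminator emits, rather than being built block-aligned at that resolution. A toy example of just that step, with made-up sizes; the 8x reduction, the bilinear mode and the final (-1, 1) flattening mirror the hunk above.

import torch
import torch.nn.functional as F

b, w, h, reduction = 4, 128, 128, 8
real = torch.ones((b, 3, w, h))
real[:, :, 40:70, 10:30] = 0.0  # a region that was swapped with generated content

# Downsample to the discriminator's output resolution. Pixels on block borders land
# between 0 and 1, so the label map is noisier than the old block-aligned version,
# but it no longer restricts where the swapped blocks sit or how large they are.
real = F.interpolate(real, size=(w // reduction, h // reduction), mode="bilinear")
real = real.view(-1, 1)  # one soft label per discriminator output element
print(real.shape)        # torch.Size([3072, 1])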
@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_div2k_pixgan_srg2.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg2.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)