From 14d23b9d2008e637dea5d4ce0520ec548b977e42 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sat, 11 Jul 2020 21:22:11 -0600
Subject: [PATCH] Fixes, do fake swaps less often in pixgan discriminator

---
 codes/models/SRGAN_model.py                  | 32 +++++++++++--------
 .../archs/SwitchedResidualGenerator_arch.py  |  8 ++---
 codes/utils/numeric_stability.py             | 10 ++----
 3 files changed, 25 insertions(+), 25 deletions(-)

diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index c729624f..f95bbfdb 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -347,20 +347,24 @@ class SRGANModel(BaseModel):
             SWAP_MAX_DIM = w // 4
             SWAP_MIN_DIM = 16
             assert SWAP_MAX_DIM > 0
-            random_swap_count = random.randint(0, 4)
-            for i in range(random_swap_count):
-                # Make the swap across fake_H and var_ref
-                swap_x, swap_y = random.randint(0, w - SWAP_MIN_DIM), random.randint(0, h - SWAP_MIN_DIM)
-                swap_w, swap_h = random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM), random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM)
-                if swap_x + swap_w > w:
-                    swap_w = w - swap_x
-                if swap_y + swap_h > h:
-                    swap_h = h - swap_y
-                t = fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone()
-                fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)]
-                var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t
-                real[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 0.0
-                fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0
+            if random.random() > .5:   # Make this only happen half the time. Earlier experiments had it happen
+                                       # more often and the model was "cheating" by using the presence of
+                                       # easily discriminated fake swaps to count the entire generated image
+                                       # as fake.
+                random_swap_count = random.randint(0, 4)
+                for i in range(random_swap_count):
+                    # Make the swap across fake_H and var_ref
+                    swap_x, swap_y = random.randint(0, w - SWAP_MIN_DIM), random.randint(0, h - SWAP_MIN_DIM)
+                    swap_w, swap_h = random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM), random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM)
+                    if swap_x + swap_w > w:
+                        swap_w = w - swap_x
+                    if swap_y + swap_h > h:
+                        swap_h = h - swap_y
+                    t = fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone()
+                    fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)]
+                    var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t
+                    real[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 0.0
+                    fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0

             # Interpolate down to the dimensionality that the discriminator uses.
             real = F.interpolate(real, size=disc_output_shape[2:], mode="bilinear")
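
The hunk above relies on a per-pixel labelling trick: random rectangles are
swapped between the generated batch (fake_H) and the reference batch (var_ref),
and the per-pixel discriminator targets are flipped in the swapped regions, so
the discriminator has to localize fakes rather than classify whole images. The
new random.random() > .5 gate keeps these artificial swaps out of half the
batches. A minimal standalone sketch of the idea follows; the function name
swap_fake_patches and the tensor shapes are illustrative, not part of the repo:

import random
import torch

def swap_fake_patches(fake, real, real_target, fake_target, min_dim=16):
    # Swap up to 4 random rectangles between `fake` and `real` (NCHW tensors,
    # width assumed >= 4 * min_dim) and flip the per-pixel targets in those
    # regions. As in the patch above, this runs only half the time so the
    # discriminator cannot key on the mere presence of swap artifacts.
    # Illustrative sketch only, not the repo's API.
    _, _, h, w = fake.shape
    max_dim = w // 4
    assert max_dim >= min_dim
    if random.random() > .5:
        for _ in range(random.randint(0, 4)):
            x, y = random.randint(0, w - min_dim), random.randint(0, h - min_dim)
            sw = min(random.randint(min_dim, max_dim), w - x)
            sh = min(random.randint(min_dim, max_dim), h - y)
            t = fake[:, :, y:y+sh, x:x+sw].clone()
            fake[:, :, y:y+sh, x:x+sw] = real[:, :, y:y+sh, x:x+sw]
            real[:, :, y:y+sh, x:x+sw] = t
            real_target[:, :, y:y+sh, x:x+sw] = 0.0  # this patch of "real" is now generated
            fake_target[:, :, y:y+sh, x:x+sw] = 1.0  # this patch of "fake" is now genuine
    return fake, real, real_target, fake_target

# Example: 1-channel per-pixel target maps, 1.0 = real, 0.0 = fake.
fake, real = torch.rand(1, 3, 128, 128), torch.rand(1, 3, 128, 128)
fake_t, real_t = torch.zeros(1, 1, 128, 128), torch.ones(1, 1, 128, 128)
swap_fake_patches(fake, real, real_t, fake_t)
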
diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index d3253d3f..ef63ab60 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -11,12 +11,12 @@ from switched_conv_util import save_attention_to_image


 class MultiConvBlock(nn.Module):
-    def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth, scale_init=1, bn=False, weight_init_factor=1):
+    def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth, scale_init=1, norm=False, weight_init_factor=1):
         assert depth >= 2
         super(MultiConvBlock, self).__init__()
         self.noise_scale = nn.Parameter(torch.full((1,), fill_value=.01))
-        self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, norm=bn, bias=False, weight_init_factor=weight_init_factor)] +
-                                     [ConvBnLelu(filters_mid, filters_mid, kernel_size, norm=bn, bias=False, weight_init_factor=weight_init_factor) for i in range(depth - 2)] +
+        self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, norm=norm, bias=False, weight_init_factor=weight_init_factor)] +
+                                     [ConvBnLelu(filters_mid, filters_mid, kernel_size, norm=norm, bias=False, weight_init_factor=weight_init_factor) for i in range(depth - 2)] +
                                      [ConvBnLelu(filters_mid, filters_out, kernel_size, activation=False, norm=False, bias=False, weight_init_factor=weight_init_factor)])
         self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init))
         self.bias = nn.Parameter(torch.zeros(1))
@@ -167,7 +167,7 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
         self.final_conv = ConvBnLelu(transformation_filters, 3, norm=False, activation=False, bias=True)
         for _ in range(switch_depth):
             multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions, switch_processing_layers, trans_counts)
-            pretransform_fn = functools.partial(ConvBnLelu, transformation_filters, transformation_filters, bn=False, bias=False, weight_init_factor=.1)
+            pretransform_fn = functools.partial(ConvBnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1)
             transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5), transformation_filters, kernel_size=trans_kernel_sizes, depth=trans_layers, weight_init_factor=.1)
             switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn, pre_transform_block=pretransform_fn, transform_block=transform_fn,
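
The bn= to norm= rename above has to be propagated into every functools.partial
call site because partial binds keyword arguments by name: a partial still
carrying the old keyword fails only when it is eventually invoked, deep inside
network construction. A toy sketch of that failure mode, using a deliberately
simplified stand-in for ConvBnLelu (the real class takes more parameters):

import functools
import torch.nn as nn

class ConvBnLelu(nn.Module):
    # Simplified stand-in: conv -> optional batch-norm -> leaky ReLU.
    def __init__(self, filters_in, filters_out, kernel_size=3, norm=False):
        super().__init__()
        layers = [nn.Conv2d(filters_in, filters_out, kernel_size, padding=kernel_size // 2)]
        if norm:
            layers.append(nn.BatchNorm2d(filters_out))
        layers.append(nn.LeakyReLU())
        self.body = nn.Sequential(*layers)

    def forward(self, x):
        return self.body(x)

pretransform_fn = functools.partial(ConvBnLelu, 64, 64, norm=False)  # matches new signature
stale_fn = functools.partial(ConvBnLelu, 64, 64, bn=False)  # constructs without complaint...
block = pretransform_fn()  # ok: builds ConvBnLelu(64, 64, norm=False)
# stale_fn() would raise TypeError: unexpected keyword argument 'bn'
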
diff --git a/codes/utils/numeric_stability.py b/codes/utils/numeric_stability.py
index 4aa40ba9..bfcedc70 100644
--- a/codes/utils/numeric_stability.py
+++ b/codes/utils/numeric_stability.py
@@ -78,11 +78,7 @@ def test_numeric_stability(mod: nn.Module, format, iterations=50, device='cuda')
         stds.append(torch.std(measure).detach())
     return torch.stack(means), torch.stack(stds)

-'''
-    def __init__(self, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
-                 trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
-                 heightened_final_step=50000, upsample_factor=1, add_scalable_noise_to_transforms=False):
-    '''
+
 if __name__ == "__main__":
     '''
     test_stability(functools.partial(nsg.NestedSwitchedGenerator,
@@ -107,8 +103,8 @@ if __name__ == "__main__":
                                      trans_layers=4, transformation_filters=64, upsample_factor=4),
-                   torch.randn(1, 3, 64, 64),
-                   device='cuda')
+                   torch.randn(1, 3, 32, 32),
+                   device='cpu')
     '''
     test_stability(functools.partial(srg1.ConfigurableSwitchedResidualGenerator,
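
For context on the last file: test_numeric_stability pushes inputs through a
module and records per-iteration output statistics, and the __main__ change
above simply shrinks the commented-out probe to a 32x32 input on CPU. Below is
a minimal sketch of such a stability probe, assuming a hypothetical
probe_stability/make_module pair (this is not the repo's API):

import torch
import torch.nn as nn

def probe_stability(make_module, input_shape, iterations=50, device='cpu'):
    # Re-initialize the module each iteration, run a random input through it,
    # and record the output mean/std to spot activation drift or blow-up
    # across random initializations.
    means, stds = [], []
    for _ in range(iterations):
        mod = make_module().to(device)
        with torch.no_grad():
            out = mod(torch.randn(*input_shape, device=device))
        means.append(out.mean())
        stds.append(out.std())
    return torch.stack(means), torch.stack(stds)

# Probe a tiny conv stack on the same 1x3x32x32 input the patch switches to:
m, s = probe_stability(lambda: nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU()),
                       (1, 3, 32, 32))
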