Fixes, do fake swaps less often in pixgan discriminator

James Betker 2020-07-11 21:22:11 -06:00
parent ba6187859a
commit 14d23b9d20
3 changed files with 25 additions and 25 deletions

View File

@@ -347,20 +347,24 @@ class SRGANModel(BaseModel):
     SWAP_MAX_DIM = w // 4
     SWAP_MIN_DIM = 16
     assert SWAP_MAX_DIM > 0
-    random_swap_count = random.randint(0, 4)
-    for i in range(random_swap_count):
-        # Make the swap across fake_H and var_ref
-        swap_x, swap_y = random.randint(0, w - SWAP_MIN_DIM), random.randint(0, h - SWAP_MIN_DIM)
-        swap_w, swap_h = random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM), random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM)
-        if swap_x + swap_w > w:
-            swap_w = w - swap_x
-        if swap_y + swap_h > h:
-            swap_h = h - swap_y
-        t = fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone()
-        fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)]
-        var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t
-        real[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 0.0
-        fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0
+    if random.random() > .5:  # Make this only happen half the time. Earlier experiments had it happen
+                              # more often and the model was "cheating" by using the presence of
+                              # easily discriminated fake swaps to count the entire generated image
+                              # as fake.
+        random_swap_count = random.randint(0, 4)
+        for i in range(random_swap_count):
+            # Make the swap across fake_H and var_ref
+            swap_x, swap_y = random.randint(0, w - SWAP_MIN_DIM), random.randint(0, h - SWAP_MIN_DIM)
+            swap_w, swap_h = random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM), random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM)
+            if swap_x + swap_w > w:
+                swap_w = w - swap_x
+            if swap_y + swap_h > h:
+                swap_h = h - swap_y
+            t = fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone()
+            fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)]
+            var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t
+            real[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 0.0
+            fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0
     # Interpolate down to the dimensionality that the discriminator uses.
     real = F.interpolate(real, size=disc_output_shape[2:], mode="bilinear")

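The hunk above is the substance of the commit: the random patch swaps between the generated image (fake_H) and the reference (var_ref) are now applied only about half the time, because a discriminator that always sees swapped patches learns to flag the whole image as fake whenever it spots one obviously pasted region. The sketch below restates the idea as a standalone routine; the name swap_fake_patches, the tensor shapes, and the "1.0 means real" label convention are illustrative assumptions, not code from this repository.

# Minimal sketch of the gated patch-swap augmentation (names and label
# convention are assumptions, not taken from the repository).
import random
import torch


def swap_fake_patches(fake, real, fake_labels, real_labels,
                      min_dim=16, max_swaps=4, swap_prob=0.5):
    # With probability `swap_prob`, swap up to `max_swaps` random rectangles
    # between the fake and real batches and flip the per-pixel labels of the
    # swapped regions, so the discriminator has to localize fakes instead of
    # classifying whole images.
    _, _, h, w = fake.shape
    max_dim = w // 4
    assert max_dim >= min_dim  # images must be large enough to host a swap
    if random.random() < swap_prob:  # only do swaps about half the time
        for _ in range(random.randint(0, max_swaps)):
            y = random.randint(0, h - min_dim)
            x = random.randint(0, w - min_dim)
            sh = min(random.randint(min_dim, max_dim), h - y)
            sw = min(random.randint(min_dim, max_dim), w - x)
            # Exchange the rectangle between the two batches.
            t = fake[:, :, y:y + sh, x:x + sw].clone()
            fake[:, :, y:y + sh, x:x + sw] = real[:, :, y:y + sh, x:x + sw]
            real[:, :, y:y + sh, x:x + sw] = t
            # The region in `fake` now holds real pixels, and vice versa.
            fake_labels[:, :, y:y + sh, x:x + sw] = 1.0
            real_labels[:, :, y:y + sh, x:x + sw] = 0.0
    return fake, real, fake_labels, real_labels


if __name__ == "__main__":
    fake = torch.rand(2, 3, 128, 128)
    real = torch.rand(2, 3, 128, 128)
    fake_lbl = torch.zeros(2, 1, 128, 128)  # 0.0 = fake everywhere
    real_lbl = torch.ones(2, 1, 128, 128)   # 1.0 = real everywhere
    swap_fake_patches(fake, real, fake_lbl, real_lbl)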
View File

@@ -11,12 +11,12 @@ from switched_conv_util import save_attention_to_image
 class MultiConvBlock(nn.Module):
-    def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth, scale_init=1, bn=False, weight_init_factor=1):
+    def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth, scale_init=1, norm=False, weight_init_factor=1):
         assert depth >= 2
         super(MultiConvBlock, self).__init__()
         self.noise_scale = nn.Parameter(torch.full((1,), fill_value=.01))
-        self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, norm=bn, bias=False, weight_init_factor=weight_init_factor)] +
-                                     [ConvBnLelu(filters_mid, filters_mid, kernel_size, norm=bn, bias=False, weight_init_factor=weight_init_factor) for i in range(depth - 2)] +
+        self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, norm=norm, bias=False, weight_init_factor=weight_init_factor)] +
+                                     [ConvBnLelu(filters_mid, filters_mid, kernel_size, norm=norm, bias=False, weight_init_factor=weight_init_factor) for i in range(depth - 2)] +
                                      [ConvBnLelu(filters_mid, filters_out, kernel_size, activation=False, norm=False, bias=False, weight_init_factor=weight_init_factor)])
         self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init))
         self.bias = nn.Parameter(torch.zeros(1))
@@ -167,7 +167,7 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
         self.final_conv = ConvBnLelu(transformation_filters, 3, norm=False, activation=False, bias=True)
         for _ in range(switch_depth):
             multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions, switch_processing_layers, trans_counts)
-            pretransform_fn = functools.partial(ConvBnLelu, transformation_filters, transformation_filters, bn=False, bias=False, weight_init_factor=.1)
+            pretransform_fn = functools.partial(ConvBnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1)
             transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5), transformation_filters, kernel_size=trans_kernel_sizes, depth=trans_layers, weight_init_factor=.1)
             switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                        pre_transform_block=pretransform_fn, transform_block=transform_fn,

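The file above is a keyword cleanup: MultiConvBlock's bn argument becomes norm so it matches the ConvBnLelu blocks it builds, and the functools.partial call site in ConfigurableSwitchedResidualGenerator2 is updated from bn=False to norm=False. A partial bound with a stale keyword is easy to miss because nothing fails until the partial is actually invoked; the toy sketch below shows that failure mode (ToyBlock is hypothetical and only stands in for ConvBnLelu after the rename).

# Hypothetical stand-in for ConvBnLelu after its `bn` kwarg became `norm`.
import functools


class ToyBlock:
    def __init__(self, filters_in, filters_out, norm=False):
        self.filters_in = filters_in
        self.filters_out = filters_out
        self.norm = norm


# functools.partial does not validate keywords, so a stale binding is created
# without complaint...
stale = functools.partial(ToyBlock, 64, 64, bn=False)
fixed = functools.partial(ToyBlock, 64, 64, norm=False)

fixed()  # works
try:
    stale()  # ...and only blows up when the block is finally constructed
except TypeError as err:
    print("stale keyword:", err)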
View File

@@ -78,11 +78,7 @@ def test_numeric_stability(mod: nn.Module, format, iterations=50, device='cuda')
         stds.append(torch.std(measure).detach())
     return torch.stack(means), torch.stack(stds)
-'''
-    def __init__(self, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
-                 trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
-                 heightened_final_step=50000, upsample_factor=1, add_scalable_noise_to_transforms=False):
-'''
 if __name__ == "__main__":
     '''
     test_stability(functools.partial(nsg.NestedSwitchedGenerator,
@@ -107,8 +103,8 @@ if __name__ == "__main__":
                                    trans_layers=4,
                                    transformation_filters=64,
                                    upsample_factor=4),
-                   torch.randn(1, 3, 64, 64),
-                   device='cuda')
+                   torch.randn(1, 3, 32, 32),
+                   device='cpu')
     '''
     test_stability(functools.partial(srg1.ConfigurableSwitchedResidualGenerator,