Fixes, do fake swaps less often in pixgan discriminator
parent ba6187859a
commit 14d23b9d20
@@ -347,20 +347,24 @@ class SRGANModel(BaseModel):
                 SWAP_MAX_DIM = w // 4
                 SWAP_MIN_DIM = 16
                 assert SWAP_MAX_DIM > 0
-                random_swap_count = random.randint(0, 4)
-                for i in range(random_swap_count):
-                    # Make the swap across fake_H and var_ref
-                    swap_x, swap_y = random.randint(0, w - SWAP_MIN_DIM), random.randint(0, h - SWAP_MIN_DIM)
-                    swap_w, swap_h = random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM), random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM)
-                    if swap_x + swap_w > w:
-                        swap_w = w - swap_x
-                    if swap_y + swap_h > h:
-                        swap_h = h - swap_y
-                    t = fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone()
-                    fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)]
-                    var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t
-                    real[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 0.0
-                    fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0
+                if random.random() > .5: # Make this only happen half the time. Earlier experiments had it happen
+                                         # more often and the model was "cheating" by using the presence of
+                                         # easily discriminated fake swaps to count the entire generated image
+                                         # as fake.
+                    random_swap_count = random.randint(0, 4)
+                    for i in range(random_swap_count):
+                        # Make the swap across fake_H and var_ref
+                        swap_x, swap_y = random.randint(0, w - SWAP_MIN_DIM), random.randint(0, h - SWAP_MIN_DIM)
+                        swap_w, swap_h = random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM), random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM)
+                        if swap_x + swap_w > w:
+                            swap_w = w - swap_x
+                        if swap_y + swap_h > h:
+                            swap_h = h - swap_y
+                        t = fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone()
+                        fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)]
+                        var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t
+                        real[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 0.0
+                        fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0
 
                 # Interpolate down to the dimensionality that the discriminator uses.
                 real = F.interpolate(real, size=disc_output_shape[2:], mode="bilinear")
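The hunk above gates the pixgan patch-swap augmentation behind a coin flip: only about half of the discriminator batches contain swapped regions at all, and within those batches zero to four rectangular patches are exchanged between the generated and reference images while the per-pixel label maps are flipped to match. Below is a minimal standalone sketch of that logic; the function name, the plain NCHW tensor arguments, and the label-map names are assumptions for illustration, not code from this commit.

import random
import torch

def random_patch_swap(fake_img, real_img, fake_label, real_label, swap_min=16):
    # Sketch of the gated pixgan swap augmentation: it only runs ~50% of the
    # time, so the discriminator cannot rely on swapped patches being present
    # in every batch (which is what this commit changes).
    _, _, h, w = fake_img.shape
    swap_max = min(h, w) // 4
    assert swap_max >= swap_min
    if random.random() > .5:
        for _ in range(random.randint(0, 4)):
            y = random.randint(0, h - swap_min)
            x = random.randint(0, w - swap_min)
            sh = min(random.randint(swap_min, swap_max), h - y)
            sw = min(random.randint(swap_min, swap_max), w - x)
            # Exchange the patch between the generated and reference images.
            t = fake_img[:, :, y:y + sh, x:x + sw].clone()
            fake_img[:, :, y:y + sh, x:x + sw] = real_img[:, :, y:y + sh, x:x + sw]
            real_img[:, :, y:y + sh, x:x + sw] = t
            # Flip the per-pixel targets over the swapped region.
            real_label[:, :, y:y + sh, x:x + sw] = 0.0
            fake_label[:, :, y:y + sh, x:x + sw] = 1.0
    return fake_img, real_img, fake_label, real_label

# Example: 64x64 images with per-pixel "real"/"fake" target maps.
f, r = torch.rand(1, 3, 64, 64), torch.rand(1, 3, 64, 64)
fl, rl = torch.zeros(1, 1, 64, 64), torch.ones(1, 1, 64, 64)
random_patch_swap(f, r, fl, rl)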
@@ -11,12 +11,12 @@ from switched_conv_util import save_attention_to_image
 
 
 class MultiConvBlock(nn.Module):
-    def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth, scale_init=1, bn=False, weight_init_factor=1):
+    def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth, scale_init=1, norm=False, weight_init_factor=1):
         assert depth >= 2
         super(MultiConvBlock, self).__init__()
         self.noise_scale = nn.Parameter(torch.full((1,), fill_value=.01))
-        self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, norm=bn, bias=False, weight_init_factor=weight_init_factor)] +
-                                     [ConvBnLelu(filters_mid, filters_mid, kernel_size, norm=bn, bias=False, weight_init_factor=weight_init_factor) for i in range(depth - 2)] +
+        self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, norm=norm, bias=False, weight_init_factor=weight_init_factor)] +
+                                     [ConvBnLelu(filters_mid, filters_mid, kernel_size, norm=norm, bias=False, weight_init_factor=weight_init_factor) for i in range(depth - 2)] +
                                      [ConvBnLelu(filters_mid, filters_out, kernel_size, activation=False, norm=False, bias=False, weight_init_factor=weight_init_factor)])
         self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init))
         self.bias = nn.Parameter(torch.zeros(1))
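This hunk only renames MultiConvBlock's bn argument to norm so it matches the norm= keyword it forwards to ConvBnLelu. ConvBnLelu's internals are not part of this diff; the sketch below is an assumed illustration of a conv block that toggles batch normalization through such a norm flag, with a hypothetical class name.

import torch
import torch.nn as nn

class ConvNormLelu(nn.Module):
    # Illustrative stand-in for a ConvBnLelu-style block: Conv2d, optional
    # BatchNorm controlled by `norm`, optional LeakyReLU activation.
    def __init__(self, filters_in, filters_out, kernel_size=3, norm=True, activation=True, bias=False):
        super().__init__()
        layers = [nn.Conv2d(filters_in, filters_out, kernel_size, padding=kernel_size // 2, bias=bias)]
        if norm:
            layers.append(nn.BatchNorm2d(filters_out))
        if activation:
            layers.append(nn.LeakyReLU(negative_slope=.2, inplace=True))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)

if __name__ == "__main__":
    block = ConvNormLelu(3, 16, norm=False)  # norm=False -> plain conv + LeakyReLU
    print(block(torch.randn(1, 3, 32, 32)).shape)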
@@ -167,7 +167,7 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
         self.final_conv = ConvBnLelu(transformation_filters, 3, norm=False, activation=False, bias=True)
         for _ in range(switch_depth):
             multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions, switch_processing_layers, trans_counts)
-            pretransform_fn = functools.partial(ConvBnLelu, transformation_filters, transformation_filters, bn=False, bias=False, weight_init_factor=.1)
+            pretransform_fn = functools.partial(ConvBnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1)
             transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5), transformation_filters, kernel_size=trans_kernel_sizes, depth=trans_layers, weight_init_factor=.1)
             switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                        pre_transform_block=pretransform_fn, transform_block=transform_fn,
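The call site above has to follow the same rename because functools.partial freezes keyword arguments by name and only checks them when the partial is invoked: since ConvBnLelu is called with norm= elsewhere in this same diff, a partial that still binds bn=False would raise a TypeError the first time it is used to build a block. A tiny sketch of that behavior, using a hypothetical make_block function in place of ConvBnLelu, is below.

import functools

def make_block(filters_in, filters_out, norm=False, bias=False):
    # Hypothetical stand-in for ConvBnLelu after the bn -> norm rename.
    return {"in": filters_in, "out": filters_out, "norm": norm, "bias": bias}

pretransform_fn = functools.partial(make_block, 64, 64, norm=False, bias=False)
print(pretransform_fn())  # keyword names are bound now but validated at call time

stale_fn = functools.partial(make_block, 64, 64, bn=False)
try:
    stale_fn()
except TypeError as e:
    print("stale keyword:", e)  # bn is no longer a valid parameter name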
@@ -78,11 +78,7 @@ def test_numeric_stability(mod: nn.Module, format, iterations=50, device='cuda')
         stds.append(torch.std(measure).detach())
     return torch.stack(means), torch.stack(stds)
 
-'''
-def __init__(self, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
-             trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
-             heightened_final_step=50000, upsample_factor=1, add_scalable_noise_to_transforms=False):
-'''
+
 if __name__ == "__main__":
     '''
     test_stability(functools.partial(nsg.NestedSwitchedGenerator,
@@ -107,8 +103,8 @@ if __name__ == "__main__":
                                      trans_layers=4,
                                      transformation_filters=64,
                                      upsample_factor=4),
-                   torch.randn(1, 3, 64, 64),
-                   device='cuda')
+                   torch.randn(1, 3, 32, 32),
+                   device='cpu')
 
     '''
     test_stability(functools.partial(srg1.ConfigurableSwitchedResidualGenerator,
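The last two hunks only touch the ad-hoc stability script: a stale commented-out constructor signature is dropped, and the test input shrinks from a 64x64 CUDA tensor to a 32x32 CPU tensor. The harness body is not shown in the diff, so the sketch below is an assumed reconstruction of what a test_numeric_stability-style check does (repeatedly push random inputs through a module and track the output mean/std so drift or blow-ups show up across iterations); the input_shape parameter here stands in for the original's format argument, whose meaning is not visible in this commit.

import torch
import torch.nn as nn

def test_numeric_stability(mod: nn.Module, input_shape, iterations=50, device='cpu'):
    # Assumed sketch: feed fresh random inputs through the module and record
    # the mean/std of the outputs so numeric drift or explosions are easy to spot.
    mod = mod.to(device)
    means, stds = [], []
    with torch.no_grad():
        for _ in range(iterations):
            out = mod(torch.randn(*input_shape, device=device))
            means.append(out.mean().detach())
            stds.append(out.std().detach())
    return torch.stack(means), torch.stack(stds)

if __name__ == "__main__":
    # Mirrors the change above: small 32x32 inputs on CPU keep the check cheap.
    m, s = test_numeric_stability(nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU()),
                                  (1, 3, 32, 32), iterations=10, device='cpu')
    print(m.mean().item(), s.mean().item())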