forked from mrq/DL-Art-School
More attention fixes for switched_spsr
This commit is contained in:
parent
d02509ef97
commit
4e972144ae
|
@@ -439,9 +439,9 @@ class SwitchedSpsr(nn.Module):
|
||||||
switch_filters = nf
|
switch_filters = nf
|
||||||
switch_reductions = 3
|
switch_reductions = 3
|
||||||
switch_processing_layers = 2
|
switch_processing_layers = 2
|
||||||
trans_counts = 8
|
self.transformation_counts = 8
|
||||||
multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions,
|
multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions,
|
||||||
switch_processing_layers, trans_counts)
|
switch_processing_layers, self.transformation_counts)
|
||||||
pretransform_fn = functools.partial(ConvGnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1)
|
pretransform_fn = functools.partial(ConvGnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1)
|
||||||
transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5),
|
transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5),
|
||||||
transformation_filters, kernel_size=3, depth=3,
|
transformation_filters, kernel_size=3, depth=3,
|
||||||
|
@@ -452,12 +452,12 @@ class SwitchedSpsr(nn.Module):
|
||||||
self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||||
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
||||||
attention_norm=True,
|
attention_norm=True,
|
||||||
transform_count=trans_counts, init_temp=10,
|
transform_count=self.transformation_counts, init_temp=10,
|
||||||
add_scalable_noise_to_transforms=True)
|
add_scalable_noise_to_transforms=True)
|
||||||
self.sw2 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
self.sw2 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||||
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
||||||
attention_norm=True,
|
attention_norm=True,
|
||||||
transform_count=trans_counts, init_temp=10,
|
transform_count=self.transformation_counts, init_temp=10,
|
||||||
add_scalable_noise_to_transforms=True)
|
add_scalable_noise_to_transforms=True)
|
||||||
self.feature_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
|
self.feature_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
|
||||||
self.model_upsampler = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)])
|
self.model_upsampler = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)])
|
||||||
|
@@ -470,7 +470,7 @@ class SwitchedSpsr(nn.Module):
|
||||||
self.sw_grad = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
self.sw_grad = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||||
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
||||||
attention_norm=True,
|
attention_norm=True,
|
||||||
transform_count=trans_counts, init_temp=10,
|
transform_count=self.transformation_counts, init_temp=10,
|
||||||
add_scalable_noise_to_transforms=True)
|
add_scalable_noise_to_transforms=True)
|
||||||
# Upsampling
|
# Upsampling
|
||||||
self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
|
self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
|
||||||
|
@@ -487,7 +487,7 @@ class SwitchedSpsr(nn.Module):
|
||||||
self._branch_pretrain_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
self._branch_pretrain_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||||
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
||||||
attention_norm=True,
|
attention_norm=True,
|
||||||
transform_count=trans_counts, init_temp=10,
|
transform_count=self.transformation_counts, init_temp=10,
|
||||||
add_scalable_noise_to_transforms=True)
|
add_scalable_noise_to_transforms=True)
|
||||||
self._branch_pretrain_HR_conv0 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
|
self._branch_pretrain_HR_conv0 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
|
||||||
self._branch_pretrain_HR_conv1 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=False)
|
self._branch_pretrain_HR_conv1 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=False)
|
||||||
|
@@ -531,7 +531,7 @@ class SwitchedSpsr(nn.Module):
|
||||||
temp = max(1, 1 + self.init_temperature *
|
temp = max(1, 1 + self.init_temperature *
|
||||||
(self.final_temperature_step - step) / self.final_temperature_step)
|
(self.final_temperature_step - step) / self.final_temperature_step)
|
||||||
self.set_temperature(temp)
|
self.set_temperature(temp)
|
||||||
if step % 50 == 0:
|
if step % 10 == 0:
|
||||||
output_path = os.path.join(experiments_path, "attention_maps", "a%i")
|
output_path = os.path.join(experiments_path, "attention_maps", "a%i")
|
||||||
prefix = "attention_map_%i_%%i.png" % (step,)
|
prefix = "attention_map_%i_%%i.png" % (step,)
|
||||||
[save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))]
|
[save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user