forked from mrq/DL-Art-School
More attention fixes for switched_spsr
This commit is contained in:
parent d02509ef97
commit 4e972144ae
@@ -439,9 +439,9 @@ class SwitchedSpsr(nn.Module):
         switch_filters = nf
         switch_reductions = 3
         switch_processing_layers = 2
-        trans_counts = 8
+        self.transformation_counts = 8
         multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions,
-                                        switch_processing_layers, trans_counts)
+                                        switch_processing_layers, self.transformation_counts)
         pretransform_fn = functools.partial(ConvGnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1)
         transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5),
                                          transformation_filters, kernel_size=3, depth=3,
|
@@ -452,12 +452,12 @@ class SwitchedSpsr(nn.Module):
         self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                               pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                               attention_norm=True,
-                                              transform_count=trans_counts, init_temp=10,
+                                              transform_count=self.transformation_counts, init_temp=10,
                                               add_scalable_noise_to_transforms=True)
         self.sw2 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                               pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                               attention_norm=True,
-                                              transform_count=trans_counts, init_temp=10,
+                                              transform_count=self.transformation_counts, init_temp=10,
                                               add_scalable_noise_to_transforms=True)
         self.feature_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
         self.model_upsampler = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)])
|
@@ -470,7 +470,7 @@ class SwitchedSpsr(nn.Module):
         self.sw_grad = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                   pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                                   attention_norm=True,
-                                                  transform_count=trans_counts, init_temp=10,
+                                                  transform_count=self.transformation_counts, init_temp=10,
                                                   add_scalable_noise_to_transforms=True)
         # Upsampling
         self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
|
@@ -487,7 +487,7 @@ class SwitchedSpsr(nn.Module):
         self._branch_pretrain_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                               pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                                               attention_norm=True,
-                                                              transform_count=trans_counts, init_temp=10,
+                                                              transform_count=self.transformation_counts, init_temp=10,
                                                               add_scalable_noise_to_transforms=True)
         self._branch_pretrain_HR_conv0 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
         self._branch_pretrain_HR_conv1 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=False)
|
@@ -531,7 +531,7 @@ class SwitchedSpsr(nn.Module):
             temp = max(1, 1 + self.init_temperature *
                        (self.final_temperature_step - step) / self.final_temperature_step)
             self.set_temperature(temp)
-            if step % 50 == 0:
+            if step % 10 == 0:
                 output_path = os.path.join(experiments_path, "attention_maps", "a%i")
                 prefix = "attention_map_%i_%%i.png" % (step,)
                 [save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))]
|
Loading…
Reference in New Issue
Block a user