Experiment: new init and post-switch-conv
This commit is contained in:
parent
480d1299d7
commit
d1d573de07
|
@ -177,8 +177,8 @@ class SwitchComputer(nn.Module):
|
|||
|
||||
|
||||
class ConfigurableSwitchComputer(nn.Module):
|
||||
def __init__(self, multiplexer_net, transform_block, transform_count, init_temp=20,
|
||||
enable_negative_transforms=False, add_scalable_noise_to_transforms=False):
|
||||
def __init__(self, base_filters, multiplexer_net, transform_block, transform_count, init_temp=20,
|
||||
enable_negative_transforms=False, add_scalable_noise_to_transforms=False, init_scalar=1):
|
||||
super(ConfigurableSwitchComputer, self).__init__()
|
||||
self.enable_negative_transforms = enable_negative_transforms
|
||||
|
||||
|
@ -192,10 +192,12 @@ class ConfigurableSwitchComputer(nn.Module):
|
|||
|
||||
# And the switch itself, including learned scalars
|
||||
self.switch = BareConvSwitch(initial_temperature=init_temp)
|
||||
self.scale = nn.Parameter(torch.ones(1))
|
||||
self.post_switch_conv = ConvBnLelu(base_filters, base_filters, bn=False, bias=False)
|
||||
self.scale = nn.Parameter(torch.full((1,), float(init_scalar)))
|
||||
self.bias = nn.Parameter(torch.zeros(1))
|
||||
|
||||
def forward(self, x, output_attention_weights=False):
|
||||
identity = x
|
||||
if self.add_noise:
|
||||
rand_feature = torch.randn_like(x)
|
||||
xformed = [t.forward(x, rand_feature) for t in self.transforms]
|
||||
|
@ -209,6 +211,8 @@ class ConfigurableSwitchComputer(nn.Module):
|
|||
m = F.interpolate(m, size=x.shape[2:], mode='nearest')
|
||||
|
||||
outputs, attention = self.switch(xformed, m, True)
|
||||
outputs = identity + outputs
|
||||
outputs = identity + self.post_switch_conv(outputs)
|
||||
outputs = outputs * self.scale + self.bias
|
||||
if output_attention_weights:
|
||||
return outputs, attention
|
||||
|
@ -349,20 +353,16 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
|
|||
add_scalable_noise_to_transforms=False):
|
||||
super(ConfigurableSwitchedResidualGenerator2, self).__init__()
|
||||
switches = []
|
||||
post_switch_proc = []
|
||||
self.initial_conv = ConvBnLelu(3, transformation_filters, bn=False)
|
||||
self.proc_conv = ConvBnLelu(transformation_filters, transformation_filters, bn=False)
|
||||
self.final_conv = ConvBnLelu(transformation_filters, 3, bn=False, lelu=False)
|
||||
for filters, growth, sw_reduce, sw_proc, trans_count, kernel, layers in zip(switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers):
|
||||
multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, filters, growth, sw_reduce, sw_proc, trans_count)
|
||||
switches.append(ConfigurableSwitchComputer(multiplx_fn, functools.partial(MultiConvBlock, transformation_filters, transformation_filters, transformation_filters, kernel_size=kernel, depth=layers), trans_count, initial_temp, enable_negative_transforms=enable_negative_transforms, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
|
||||
post_switch_proc.append(ConvBnLelu(transformation_filters, transformation_filters, bn=False))
|
||||
initialize_weights(switches, 1)
|
||||
# Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image.
|
||||
initialize_weights([s.transforms for s in switches], .2 / len(switches))
|
||||
switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||
functools.partial(MultiConvBlock, transformation_filters, transformation_filters, transformation_filters, kernel_size=kernel, depth=layers),
|
||||
trans_count, initial_temp, enable_negative_transforms=enable_negative_transforms,
|
||||
add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=.01))
|
||||
self.switches = nn.ModuleList(switches)
|
||||
initialize_weights([p for p in post_switch_proc], .01)
|
||||
self.post_switch_convs = nn.ModuleList(post_switch_proc)
|
||||
self.transformation_counts = trans_counts
|
||||
self.init_temperature = initial_temp
|
||||
self.final_temperature_step = final_temperature_step
|
||||
|
@ -376,11 +376,9 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
|
|||
x = self.initial_conv(x)
|
||||
|
||||
self.attentions = []
|
||||
for i, (sw, conv) in enumerate(zip(self.switches, self.post_switch_convs)):
|
||||
sw_out, att = sw.forward(x, True)
|
||||
for i, sw in enumerate(self.switches):
|
||||
x, att = sw.forward(x, True)
|
||||
self.attentions.append(att)
|
||||
x = x + sw_out
|
||||
x = x + conv(x)
|
||||
|
||||
if self.upsample_factor > 1:
|
||||
x = F.interpolate(x, scale_factor=self.upsample_factor, mode="nearest")
|
||||
|
|
Loading…
Reference in New Issue
Block a user