forked from mrq/DL-Art-School
Fix scaling bug
This commit is contained in:
parent
30653181ba
commit
6ac6c95177
|
@ -192,8 +192,10 @@ class ConfigurableSwitchComputer(nn.Module):
|
||||||
|
|
||||||
# And the switch itself, including learned scalars
|
# And the switch itself, including learned scalars
|
||||||
self.switch = BareConvSwitch(initial_temperature=init_temp)
|
self.switch = BareConvSwitch(initial_temperature=init_temp)
|
||||||
|
self.switch_scale = nn.Parameter(torch.full((1,), float(init_scalar)))
|
||||||
self.post_switch_conv = ConvBnLelu(base_filters, base_filters, bn=False, bias=False)
|
self.post_switch_conv = ConvBnLelu(base_filters, base_filters, bn=False, bias=False)
|
||||||
self.scale = nn.Parameter(torch.full((1,), float(init_scalar)))
|
# The post_switch_conv gets a near-zero scale. The network can decide to magnify it (or not) depending on its needs.
|
||||||
|
self.psc_scale = nn.Parameter(torch.full((1,), float(1e-3)))
|
||||||
self.bias = nn.Parameter(torch.zeros(1))
|
self.bias = nn.Parameter(torch.zeros(1))
|
||||||
|
|
||||||
def forward(self, x, output_attention_weights=False):
|
def forward(self, x, output_attention_weights=False):
|
||||||
|
@ -211,9 +213,9 @@ class ConfigurableSwitchComputer(nn.Module):
|
||||||
m = F.interpolate(m, size=x.shape[2:], mode='nearest')
|
m = F.interpolate(m, size=x.shape[2:], mode='nearest')
|
||||||
|
|
||||||
outputs, attention = self.switch(xformed, m, True)
|
outputs, attention = self.switch(xformed, m, True)
|
||||||
outputs = identity + outputs
|
outputs = identity + outputs * self.switch_scale
|
||||||
#outputs = identity + self.post_switch_conv(outputs)
|
outputs = identity + self.post_switch_conv(outputs) * self.psc_scale
|
||||||
outputs = outputs * self.scale + self.bias
|
outputs = outputs + self.bias
|
||||||
if output_attention_weights:
|
if output_attention_weights:
|
||||||
return outputs, attention
|
return outputs, attention
|
||||||
else:
|
else:
|
||||||
|
@ -361,11 +363,7 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
|
||||||
switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||||
functools.partial(MultiConvBlock, transformation_filters, transformation_filters, transformation_filters, kernel_size=kernel, depth=layers),
|
functools.partial(MultiConvBlock, transformation_filters, transformation_filters, transformation_filters, kernel_size=kernel, depth=layers),
|
||||||
trans_count, initial_temp, enable_negative_transforms=enable_negative_transforms,
|
trans_count, initial_temp, enable_negative_transforms=enable_negative_transforms,
|
||||||
add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=1))
|
add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=.01))
|
||||||
|
|
||||||
initialize_weights(switches, 1)
|
|
||||||
# Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image.
|
|
||||||
initialize_weights([s.transforms for s in switches], .2 / len(switches))
|
|
||||||
|
|
||||||
self.switches = nn.ModuleList(switches)
|
self.switches = nn.ModuleList(switches)
|
||||||
self.transformation_counts = trans_counts
|
self.transformation_counts = trans_counts
|
||||||
|
|
Loading…
Reference in New Issue
Block a user