forked from mrq/DL-Art-School
SRG1 conjoined except ConvBnRelu
This commit is contained in:
parent
c58c2b09ca
commit
416538f31c
|
@ -22,6 +22,15 @@ class ConvBnLelu(nn.Module):
|
|||
else:
|
||||
self.lelu = None
|
||||
|
||||
# Init params.
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, a=.1, mode='fan_out',
|
||||
nonlinearity='leaky_relu' if self.lelu else 'linear')
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
if self.bn:
|
||||
|
@ -32,16 +41,15 @@ class ConvBnLelu(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
class ResidualBranch(nn.Module):
|
||||
def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth, bn=False):
|
||||
class MultiConvBlock(nn.Module):
|
||||
def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth, scale_init=1, bn=False):
|
||||
assert depth >= 2
|
||||
super(ResidualBranch, self).__init__()
|
||||
super(MultiConvBlock, self).__init__()
|
||||
self.noise_scale = nn.Parameter(torch.full((1,), fill_value=.01))
|
||||
self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, bn=bn, bias=False)] +
|
||||
[ConvBnLelu(filters_mid, filters_mid, kernel_size, bn=bn, bias=False) for i in range(depth-2)] +
|
||||
[ConvBnLelu(filters_mid, filters_out, kernel_size, lelu=False, bn=False, bias=False)])
|
||||
|
||||
self.scale = nn.Parameter(torch.ones(1))
|
||||
self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init))
|
||||
self.bias = nn.Parameter(torch.zeros(1))
|
||||
|
||||
def forward(self, x, noise=None):
|
||||
|
@ -139,7 +147,7 @@ class ConfigurableSwitchedResidualGenerator(nn.Module):
|
|||
super(ConfigurableSwitchedResidualGenerator, self).__init__()
|
||||
switches = []
|
||||
for filters, growth, sw_reduce, sw_proc, trans_count, kernel, layers, mid_filters in zip(switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers, trans_filters_mid):
|
||||
switches.append(SwitchComputer(3, filters, growth, functools.partial(ResidualBranch, 3, mid_filters, 3, kernel_size=kernel, depth=layers), trans_count, sw_reduce, sw_proc, initial_temp, enable_negative_transforms=enable_negative_transforms, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
|
||||
switches.append(SwitchComputer(3, filters, growth, functools.partial(MultiConvBlock, 3, mid_filters, 3, kernel_size=kernel, depth=layers), trans_count, sw_reduce, sw_proc, initial_temp, enable_negative_transforms=enable_negative_transforms, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
|
||||
initialize_weights(switches, 1)
|
||||
# Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image.
|
||||
initialize_weights([s.transforms for s in switches], .2 / len(switches))
|
||||
|
|
Loading…
Reference in New Issue
Block a user