SRG2 revival
Big update to SRG2 architecture to pull in a lot of things that have been learned: - Use group norm instead of batch norm - Initialize the weights on the transformations low like is done in RRDB rather than using the scalar. Models live or die by their early stages, and this ones early stage is pretty weak - Transform multiplexer to use u-net like architecture. - Just use one set of configuration variables instead of a list - flat networks performed fine in this regard.
This commit is contained in:
parent
12da993da8
commit
5f2c722a10
|
@ -1,6 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from models.archs.SwitchedResidualGenerator_arch import ConvBnLelu, ConvBnRelu, MultiConvBlock, initialize_weights
|
from models.archs.arch_util import ConvBnLelu, ConvBnRelu
|
||||||
|
from models.archs.SwitchedResidualGenerator_arch import MultiConvBlock
|
||||||
from switched_conv import BareConvSwitch, compute_attention_specificity
|
from switched_conv import BareConvSwitch, compute_attention_specificity
|
||||||
from switched_conv_util import save_attention_to_image
|
from switched_conv_util import save_attention_to_image
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
|
@ -4,20 +4,20 @@ from switched_conv import BareConvSwitch, compute_attention_specificity
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import functools
|
import functools
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from models.archs.arch_util import initialize_weights, ConvBnRelu, ConvBnLelu, ConvBnSilu
|
from models.archs.arch_util import ConvBnLelu, ConvGnSilu
|
||||||
from models.archs.RRDBNet_arch import ResidualDenseBlock_5C
|
from models.archs.RRDBNet_arch import ResidualDenseBlock_5C
|
||||||
from models.archs.spinenet_arch import SpineNet
|
from models.archs.spinenet_arch import SpineNet
|
||||||
from switched_conv_util import save_attention_to_image
|
from switched_conv_util import save_attention_to_image
|
||||||
|
|
||||||
|
|
||||||
class MultiConvBlock(nn.Module):
|
class MultiConvBlock(nn.Module):
|
||||||
def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth, scale_init=1, bn=False):
|
def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth, scale_init=1, bn=False, weight_init_factor=1):
|
||||||
assert depth >= 2
|
assert depth >= 2
|
||||||
super(MultiConvBlock, self).__init__()
|
super(MultiConvBlock, self).__init__()
|
||||||
self.noise_scale = nn.Parameter(torch.full((1,), fill_value=.01))
|
self.noise_scale = nn.Parameter(torch.full((1,), fill_value=.01))
|
||||||
self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, bn=bn, bias=False)] +
|
self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, bn=bn, bias=False, weight_init_factor=weight_init_factor)] +
|
||||||
[ConvBnLelu(filters_mid, filters_mid, kernel_size, bn=bn, bias=False) for i in range(depth-2)] +
|
[ConvBnLelu(filters_mid, filters_mid, kernel_size, bn=bn, bias=False, weight_init_factor=weight_init_factor) for i in range(depth-2)] +
|
||||||
[ConvBnLelu(filters_mid, filters_out, kernel_size, lelu=False, bn=False, bias=False)])
|
[ConvBnLelu(filters_mid, filters_out, kernel_size, lelu=False, bn=False, bias=False, weight_init_factor=weight_init_factor)])
|
||||||
self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init))
|
self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init))
|
||||||
self.bias = nn.Parameter(torch.zeros(1))
|
self.bias = nn.Parameter(torch.zeros(1))
|
||||||
|
|
||||||
|
@ -35,43 +35,56 @@ class MultiConvBlock(nn.Module):
|
||||||
class HalvingProcessingBlock(nn.Module):
|
class HalvingProcessingBlock(nn.Module):
|
||||||
def __init__(self, filters):
|
def __init__(self, filters):
|
||||||
super(HalvingProcessingBlock, self).__init__()
|
super(HalvingProcessingBlock, self).__init__()
|
||||||
self.bnconv1 = ConvBnLelu(filters, filters * 2, stride=2, bn=False, bias=False)
|
self.bnconv1 = ConvGnSilu(filters, filters * 2, stride=2, gn=False, bias=False)
|
||||||
self.bnconv2 = ConvBnLelu(filters * 2, filters * 2, bn=True, bias=False)
|
self.bnconv2 = ConvGnSilu(filters * 2, filters * 2, gn=True, bias=False)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.bnconv1(x)
|
x = self.bnconv1(x)
|
||||||
return self.bnconv2(x)
|
return self.bnconv2(x)
|
||||||
|
|
||||||
|
|
||||||
# Creates a nested series of convolutional blocks. Each block processes the input data in-place and adds
|
class ExpansionBlock(nn.Module):
|
||||||
# filter_growth filters. Return is (nn.Sequential, ending_filters)
|
def __init__(self, filters):
|
||||||
def create_sequential_growing_processing_block(filters_init, filter_growth, num_convs):
|
super(ExpansionBlock, self).__init__()
|
||||||
convs = []
|
self.decimate = ConvGnSilu(filters, filters // 2, kernel_size=1, bias=False, silu=False, gn=False)
|
||||||
current_filters = filters_init
|
self.conjoin = ConvGnSilu(filters, filters // 2, kernel_size=3, bias=True, silu=False, gn=True)
|
||||||
for i in range(num_convs):
|
self.process = ConvGnSilu(filters // 2, filters // 2, kernel_size=3, bias=False, silu=True, gn=True)
|
||||||
convs.append(ConvBnSilu(current_filters, current_filters + filter_growth, bn=True, bias=False))
|
|
||||||
current_filters += filter_growth
|
def forward(self, input, passthrough):
|
||||||
return nn.Sequential(*convs), current_filters
|
x = F.interpolate(input, scale_factor=2, mode="nearest")
|
||||||
|
x = self.decimate(x)
|
||||||
|
x = self.conjoin(torch.cat([x, passthrough], dim=1))
|
||||||
|
return self.process(x)
|
||||||
|
|
||||||
|
|
||||||
|
# This is a classic u-net architecture with the goal of assigning each individual pixel an individual transform
|
||||||
|
# switching set.
|
||||||
class ConvBasisMultiplexer(nn.Module):
|
class ConvBasisMultiplexer(nn.Module):
|
||||||
def __init__(self, input_channels, base_filters, growth, reductions, processing_depth, multiplexer_channels, use_bn=True):
|
def __init__(self, input_channels, base_filters, reductions, processing_depth, multiplexer_channels, use_gn=True):
|
||||||
super(ConvBasisMultiplexer, self).__init__()
|
super(ConvBasisMultiplexer, self).__init__()
|
||||||
self.filter_conv = ConvBnSilu(input_channels, base_filters, bias=True)
|
self.filter_conv = ConvGnSilu(input_channels, base_filters, bias=True)
|
||||||
self.reduction_blocks = nn.Sequential(OrderedDict([('block%i:' % (i,), HalvingProcessingBlock(base_filters * 2 ** i)) for i in range(reductions)]))
|
self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(base_filters * 2 ** i) for i in range(reductions)])
|
||||||
reduction_filters = base_filters * 2 ** reductions
|
reduction_filters = base_filters * 2 ** reductions
|
||||||
self.processing_blocks, self.output_filter_count = create_sequential_growing_processing_block(reduction_filters, growth, processing_depth)
|
self.processing_blocks = nn.Sequential(OrderedDict([('block%i' % (i,), ConvGnSilu(reduction_filters, reduction_filters, bias=False)) for i in range(processing_depth)]))
|
||||||
|
self.expansion_blocks = nn.ModuleList([ExpansionBlock(reduction_filters // (2 ** i)) for i in range(reductions)])
|
||||||
|
|
||||||
gap = self.output_filter_count - multiplexer_channels
|
gap = base_filters - multiplexer_channels
|
||||||
# Hey silly - if you're going to interpolate later, do it here instead. Then add some processing layers to let the model adjust it properly.
|
cbl1_out = ((base_filters - (gap // 2)) // 4) * 4 # Must be multiples of 4 to use with group norm.
|
||||||
self.cbl1 = ConvBnSilu(self.output_filter_count, self.output_filter_count - (gap // 2), bn=use_bn, bias=False)
|
self.cbl1 = ConvGnSilu(base_filters, cbl1_out, gn=use_gn, bias=False, num_groups=4)
|
||||||
self.cbl2 = ConvBnSilu(self.output_filter_count - (gap // 2), self.output_filter_count - (3 * gap // 4), bn=use_bn, bias=False)
|
cbl2_out = ((base_filters - (3 * gap // 4)) // 4) * 4
|
||||||
self.cbl3 = ConvBnSilu(self.output_filter_count - (3 * gap // 4), multiplexer_channels, bias=True)
|
self.cbl2 = ConvGnSilu(cbl1_out, cbl2_out, gn=use_gn, bias=False, num_groups=4)
|
||||||
|
self.cbl3 = ConvGnSilu(cbl2_out, multiplexer_channels, bias=True, gn=False)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.filter_conv(x)
|
x = self.filter_conv(x)
|
||||||
x = self.reduction_blocks(x)
|
reduction_identities = []
|
||||||
|
for b in self.reduction_blocks:
|
||||||
|
reduction_identities.append(x)
|
||||||
|
x = b(x)
|
||||||
x = self.processing_blocks(x)
|
x = self.processing_blocks(x)
|
||||||
|
for i, b in enumerate(self.expansion_blocks):
|
||||||
|
x = b(x, reduction_identities[-i - 1])
|
||||||
|
|
||||||
x = self.cbl1(x)
|
x = self.cbl1(x)
|
||||||
x = self.cbl2(x)
|
x = self.cbl2(x)
|
||||||
x = self.cbl3(x)
|
x = self.cbl3(x)
|
||||||
|
@ -94,13 +107,13 @@ class BackboneMultiplexer(nn.Module):
|
||||||
def __init__(self, backbone: CachedBackboneWrapper, transform_count):
|
def __init__(self, backbone: CachedBackboneWrapper, transform_count):
|
||||||
super(BackboneMultiplexer, self).__init__()
|
super(BackboneMultiplexer, self).__init__()
|
||||||
self.backbone = backbone
|
self.backbone = backbone
|
||||||
self.proc = nn.Sequential(ConvBnSilu(256, 256, kernel_size=3, bias=True),
|
self.proc = nn.Sequential(ConvGnSilu(256, 256, kernel_size=3, bias=True),
|
||||||
ConvBnSilu(256, 256, kernel_size=3, bias=False))
|
ConvGnSilu(256, 256, kernel_size=3, bias=False))
|
||||||
self.up1 = nn.Sequential(ConvBnSilu(256, 128, kernel_size=3, bias=False, bn=False, silu=False),
|
self.up1 = nn.Sequential(ConvGnSilu(256, 128, kernel_size=3, bias=False, gn=False, silu=False),
|
||||||
ConvBnSilu(128, 128, kernel_size=3, bias=False))
|
ConvGnSilu(128, 128, kernel_size=3, bias=False))
|
||||||
self.up2 = nn.Sequential(ConvBnSilu(128, 64, kernel_size=3, bias=False, bn=False, silu=False),
|
self.up2 = nn.Sequential(ConvGnSilu(128, 64, kernel_size=3, bias=False, gn=False, silu=False),
|
||||||
ConvBnSilu(64, 64, kernel_size=3, bias=False))
|
ConvGnSilu(64, 64, kernel_size=3, bias=False))
|
||||||
self.final = ConvBnSilu(64, transform_count, bias=False, bn=False, silu=False)
|
self.final = ConvGnSilu(64, transform_count, bias=False, gn=False, silu=False)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
spine = self.backbone.get_forward_result()
|
spine = self.backbone.get_forward_result()
|
||||||
|
@ -112,13 +125,10 @@ class BackboneMultiplexer(nn.Module):
|
||||||
|
|
||||||
class ConfigurableSwitchComputer(nn.Module):
|
class ConfigurableSwitchComputer(nn.Module):
|
||||||
def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, init_temp=20,
|
def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, init_temp=20,
|
||||||
enable_negative_transforms=False, add_scalable_noise_to_transforms=False, init_scalar=1):
|
add_scalable_noise_to_transforms=False):
|
||||||
super(ConfigurableSwitchComputer, self).__init__()
|
super(ConfigurableSwitchComputer, self).__init__()
|
||||||
self.enable_negative_transforms = enable_negative_transforms
|
|
||||||
|
|
||||||
tc = transform_count
|
tc = transform_count
|
||||||
if self.enable_negative_transforms:
|
|
||||||
tc = transform_count * 2
|
|
||||||
self.multiplexer = multiplexer_net(tc)
|
self.multiplexer = multiplexer_net(tc)
|
||||||
|
|
||||||
self.pre_transform = pre_transform_block()
|
self.pre_transform = pre_transform_block()
|
||||||
|
@ -128,11 +138,11 @@ class ConfigurableSwitchComputer(nn.Module):
|
||||||
|
|
||||||
# And the switch itself, including learned scalars
|
# And the switch itself, including learned scalars
|
||||||
self.switch = BareConvSwitch(initial_temperature=init_temp)
|
self.switch = BareConvSwitch(initial_temperature=init_temp)
|
||||||
self.switch_scale = nn.Parameter(torch.full((1,), float(init_scalar)))
|
self.switch_scale = nn.Parameter(torch.full((1,), float(1)))
|
||||||
self.post_switch_conv = ConvBnLelu(base_filters, base_filters, bn=False, bias=False)
|
self.post_switch_conv = ConvBnLelu(base_filters, base_filters, bn=False, bias=True)
|
||||||
# The post_switch_conv gets a near-zero scale. The network can decide to magnify it (or not) depending on its needs.
|
# The post_switch_conv gets a low scale initially. The network can decide to magnify it (or not)
|
||||||
self.psc_scale = nn.Parameter(torch.full((1,), float(1e-3)))
|
# depending on its needs.
|
||||||
self.bias = nn.Parameter(torch.zeros(1))
|
self.psc_scale = nn.Parameter(torch.full((1,), float(.1)))
|
||||||
|
|
||||||
def forward(self, x, output_attention_weights=False):
|
def forward(self, x, output_attention_weights=False):
|
||||||
identity = x
|
identity = x
|
||||||
|
@ -142,17 +152,12 @@ class ConfigurableSwitchComputer(nn.Module):
|
||||||
|
|
||||||
x = self.pre_transform(x)
|
x = self.pre_transform(x)
|
||||||
xformed = [t.forward(x) for t in self.transforms]
|
xformed = [t.forward(x) for t in self.transforms]
|
||||||
if self.enable_negative_transforms:
|
|
||||||
xformed.extend([-t for t in xformed])
|
|
||||||
|
|
||||||
m = self.multiplexer(identity)
|
m = self.multiplexer(identity)
|
||||||
# Interpolate the multiplexer across the entire shape of the image.
|
|
||||||
m = F.interpolate(m, size=xformed[0].shape[2:], mode='nearest')
|
|
||||||
|
|
||||||
outputs, attention = self.switch(xformed, m, True)
|
outputs, attention = self.switch(xformed, m, True)
|
||||||
outputs = identity + outputs * self.switch_scale
|
outputs = identity + outputs * self.switch_scale
|
||||||
outputs = identity + self.post_switch_conv(outputs) * self.psc_scale
|
outputs = outputs + self.post_switch_conv(outputs) * self.psc_scale
|
||||||
outputs = outputs + self.bias
|
|
||||||
if output_attention_weights:
|
if output_attention_weights:
|
||||||
return outputs, attention
|
return outputs, attention
|
||||||
else:
|
else:
|
||||||
|
@ -163,25 +168,25 @@ class ConfigurableSwitchComputer(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class ConfigurableSwitchedResidualGenerator2(nn.Module):
|
class ConfigurableSwitchedResidualGenerator2(nn.Module):
|
||||||
def __init__(self, switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
|
def __init__(self, switch_depth, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
|
||||||
trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
|
trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
|
||||||
heightened_final_step=50000, upsample_factor=1, enable_negative_transforms=False,
|
heightened_final_step=50000, upsample_factor=1,
|
||||||
add_scalable_noise_to_transforms=False):
|
add_scalable_noise_to_transforms=False):
|
||||||
super(ConfigurableSwitchedResidualGenerator2, self).__init__()
|
super(ConfigurableSwitchedResidualGenerator2, self).__init__()
|
||||||
switches = []
|
switches = []
|
||||||
self.initial_conv = ConvBnLelu(3, transformation_filters, bn=False, lelu=False, bias=True)
|
self.initial_conv = ConvBnLelu(3, transformation_filters, bn=False, lelu=False, bias=True)
|
||||||
self.sw_conv = ConvBnLelu(transformation_filters, transformation_filters, lelu=False, bias=True)
|
|
||||||
self.upconv1 = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=True)
|
self.upconv1 = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=True)
|
||||||
self.upconv2 = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=True)
|
self.upconv2 = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=True)
|
||||||
self.hr_conv = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=True)
|
self.hr_conv = ConvBnLelu(transformation_filters, transformation_filters, bn=False, bias=True)
|
||||||
self.final_conv = ConvBnLelu(transformation_filters, 3, bn=False, lelu=False, bias=True)
|
self.final_conv = ConvBnLelu(transformation_filters, 3, bn=False, lelu=False, bias=True)
|
||||||
for filters, growth, sw_reduce, sw_proc, trans_count, kernel, layers in zip(switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers):
|
for _ in range(switch_depth):
|
||||||
multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, filters, growth, sw_reduce, sw_proc, trans_count)
|
multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions, switch_processing_layers, trans_counts)
|
||||||
|
pretransform_fn = functools.partial(ConvBnLelu, transformation_filters, transformation_filters, bn=False, bias=False, weight_init_factor=.1)
|
||||||
|
transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5), transformation_filters, kernel_size=trans_kernel_sizes, depth=trans_layers, weight_init_factor=.1)
|
||||||
switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||||
pre_transform_block=functools.partial(ConvBnLelu, transformation_filters, transformation_filters, bn=False, bias=False),
|
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
||||||
transform_block=functools.partial(MultiConvBlock, transformation_filters, transformation_filters + growth, transformation_filters, kernel_size=kernel, depth=layers),
|
transform_count=trans_counts, init_temp=initial_temp,
|
||||||
transform_count=trans_count, init_temp=initial_temp, enable_negative_transforms=enable_negative_transforms,
|
add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
|
||||||
add_scalable_noise_to_transforms=add_scalable_noise_to_transforms, init_scalar=.1))
|
|
||||||
|
|
||||||
self.switches = nn.ModuleList(switches)
|
self.switches = nn.ModuleList(switches)
|
||||||
self.transformation_counts = trans_counts
|
self.transformation_counts = trans_counts
|
||||||
|
@ -268,7 +273,7 @@ class ConfigurableSwitchedResidualGenerator3(nn.Module):
|
||||||
pretransform_fn = functools.partial(nn.Sequential, ConvBnLelu(base_filters, base_filters, kernel_size=3, bn=False, lelu=False, bias=False))
|
pretransform_fn = functools.partial(nn.Sequential, ConvBnLelu(base_filters, base_filters, kernel_size=3, bn=False, lelu=False, bias=False))
|
||||||
transform_fn = functools.partial(MultiConvBlock, base_filters, int(base_filters * 1.5), base_filters, kernel_size=3, depth=4)
|
transform_fn = functools.partial(MultiConvBlock, base_filters, int(base_filters * 1.5), base_filters, kernel_size=3, depth=4)
|
||||||
self.switch = ConfigurableSwitchComputer(base_filters, multiplx_fn, pretransform_fn, transform_fn, trans_count, init_temp=initial_temp,
|
self.switch = ConfigurableSwitchComputer(base_filters, multiplx_fn, pretransform_fn, transform_fn, trans_count, init_temp=initial_temp,
|
||||||
enable_negative_transforms=False, add_scalable_noise_to_transforms=True, init_scalar=.1)
|
add_scalable_noise_to_transforms=True, init_scalar=.1)
|
||||||
|
|
||||||
self.transformation_counts = trans_count
|
self.transformation_counts = trans_count
|
||||||
self.init_temperature = initial_temp
|
self.init_temperature = initial_temp
|
||||||
|
|
|
@ -219,7 +219,7 @@ class ConvBnRelu(nn.Module):
|
||||||
''' Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard
|
''' Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard
|
||||||
kernel sizes. '''
|
kernel sizes. '''
|
||||||
class ConvBnSilu(nn.Module):
|
class ConvBnSilu(nn.Module):
|
||||||
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, silu=True, bn=True, bias=True):
|
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, silu=True, bn=True, bias=True, weight_init_factor=1):
|
||||||
super(ConvBnSilu, self).__init__()
|
super(ConvBnSilu, self).__init__()
|
||||||
padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
|
padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
|
||||||
assert kernel_size in padding_map.keys()
|
assert kernel_size in padding_map.keys()
|
||||||
|
@ -237,6 +237,9 @@ class ConvBnSilu(nn.Module):
|
||||||
for m in self.modules():
|
for m in self.modules():
|
||||||
if isinstance(m, nn.Conv2d):
|
if isinstance(m, nn.Conv2d):
|
||||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.silu else 'linear')
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.silu else 'linear')
|
||||||
|
m.weight.data *= weight_init_factor
|
||||||
|
if m.bias is not None:
|
||||||
|
m.bias.data.zero_()
|
||||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||||
nn.init.constant_(m.weight, 1)
|
nn.init.constant_(m.weight, 1)
|
||||||
nn.init.constant_(m.bias, 0)
|
nn.init.constant_(m.bias, 0)
|
||||||
|
@ -254,7 +257,7 @@ class ConvBnSilu(nn.Module):
|
||||||
''' Convenience class with Conv->BN->LeakyReLU. Includes weight initialization and auto-padding for standard
|
''' Convenience class with Conv->BN->LeakyReLU. Includes weight initialization and auto-padding for standard
|
||||||
kernel sizes. '''
|
kernel sizes. '''
|
||||||
class ConvBnLelu(nn.Module):
|
class ConvBnLelu(nn.Module):
|
||||||
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, lelu=True, bn=True, bias=True):
|
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, lelu=True, bn=True, bias=True, weight_init_factor=1):
|
||||||
super(ConvBnLelu, self).__init__()
|
super(ConvBnLelu, self).__init__()
|
||||||
padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
|
padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
|
||||||
assert kernel_size in padding_map.keys()
|
assert kernel_size in padding_map.keys()
|
||||||
|
@ -273,6 +276,9 @@ class ConvBnLelu(nn.Module):
|
||||||
if isinstance(m, nn.Conv2d):
|
if isinstance(m, nn.Conv2d):
|
||||||
nn.init.kaiming_normal_(m.weight, a=.1, mode='fan_out',
|
nn.init.kaiming_normal_(m.weight, a=.1, mode='fan_out',
|
||||||
nonlinearity='leaky_relu' if self.lelu else 'linear')
|
nonlinearity='leaky_relu' if self.lelu else 'linear')
|
||||||
|
m.weight.data *= weight_init_factor
|
||||||
|
if m.bias is not None:
|
||||||
|
m.bias.data.zero_()
|
||||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||||
nn.init.constant_(m.weight, 1)
|
nn.init.constant_(m.weight, 1)
|
||||||
nn.init.constant_(m.bias, 0)
|
nn.init.constant_(m.bias, 0)
|
||||||
|
@ -319,5 +325,42 @@ class ConvGnLelu(nn.Module):
|
||||||
x = self.gn(x)
|
x = self.gn(x)
|
||||||
if self.lelu:
|
if self.lelu:
|
||||||
return self.lelu(x)
|
return self.lelu(x)
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
''' Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard
|
||||||
|
kernel sizes. '''
|
||||||
|
class ConvGnSilu(nn.Module):
|
||||||
|
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, silu=True, gn=True, bias=True, num_groups=8, weight_init_factor=1):
|
||||||
|
super(ConvGnSilu, self).__init__()
|
||||||
|
padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
|
||||||
|
assert kernel_size in padding_map.keys()
|
||||||
|
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
|
||||||
|
if gn:
|
||||||
|
self.gn = nn.GroupNorm(num_groups, filters_out)
|
||||||
|
else:
|
||||||
|
self.gn = None
|
||||||
|
if silu:
|
||||||
|
self.silu = SiLU()
|
||||||
|
else:
|
||||||
|
self.silu = None
|
||||||
|
|
||||||
|
# Init params.
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.silu else 'linear')
|
||||||
|
m.weight.data *= weight_init_factor
|
||||||
|
if m.bias is not None:
|
||||||
|
m.bias.data.zero_()
|
||||||
|
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||||
|
nn.init.constant_(m.weight, 1)
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
if self.gn:
|
||||||
|
x = self.gn(x)
|
||||||
|
if self.silu:
|
||||||
|
return self.silu(x)
|
||||||
else:
|
else:
|
||||||
return x
|
return x
|
|
@ -59,7 +59,7 @@ def define_G(opt, net_key='network_G'):
|
||||||
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
|
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
|
||||||
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
|
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
|
||||||
elif which_model == "ConfigurableSwitchedResidualGenerator2":
|
elif which_model == "ConfigurableSwitchedResidualGenerator2":
|
||||||
netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator2(switch_filters=opt_net['switch_filters'], switch_growths=opt_net['switch_growths'],
|
netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
|
||||||
switch_reductions=opt_net['switch_reductions'],
|
switch_reductions=opt_net['switch_reductions'],
|
||||||
switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
|
switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
|
||||||
trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
|
trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
|
||||||
|
|
|
@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
|
||||||
def main():
|
def main():
|
||||||
#### options
|
#### options
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_div2k_pixgan_rrdb.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_div2k_pixgan_srg2.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||||
help='job launcher')
|
help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
|
|
|
@ -97,20 +97,19 @@ if __name__ == "__main__":
|
||||||
torch.randn(1, 3, 64, 64),
|
torch.randn(1, 3, 64, 64),
|
||||||
device='cuda')
|
device='cuda')
|
||||||
'''
|
'''
|
||||||
'''
|
|
||||||
test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator2,
|
test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator2,
|
||||||
switch_filters=[32,32,32,32],
|
switch_depth=4,
|
||||||
switch_growths=[16,16,16,16],
|
switch_filters=64,
|
||||||
switch_reductions=[4,3,2,1],
|
switch_reductions=4,
|
||||||
switch_processing_layers=[3,3,4,5],
|
switch_processing_layers=2,
|
||||||
trans_counts=[16,16,16,16,16],
|
trans_counts=8,
|
||||||
trans_kernel_sizes=[3,3,3,3,3],
|
trans_kernel_sizes=3,
|
||||||
trans_layers=[3,3,3,3,3],
|
trans_layers=4,
|
||||||
transformation_filters=64,
|
transformation_filters=64,
|
||||||
initial_temp=10),
|
upsample_factor=4),
|
||||||
torch.randn(1, 3, 64, 64),
|
torch.randn(1, 3, 64, 64),
|
||||||
device='cuda')
|
device='cuda')
|
||||||
'''
|
|
||||||
'''
|
'''
|
||||||
test_stability(functools.partial(srg1.ConfigurableSwitchedResidualGenerator,
|
test_stability(functools.partial(srg1.ConfigurableSwitchedResidualGenerator,
|
||||||
switch_filters=[32,32,32,32],
|
switch_filters=[32,32,32,32],
|
||||||
|
@ -125,7 +124,9 @@ if __name__ == "__main__":
|
||||||
torch.randn(1, 3, 64, 64),
|
torch.randn(1, 3, 64, 64),
|
||||||
device='cuda')
|
device='cuda')
|
||||||
'''
|
'''
|
||||||
|
'''
|
||||||
test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator3,
|
test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator3,
|
||||||
64, 16),
|
64, 16),
|
||||||
torch.randn(1, 3, 64, 64),
|
torch.randn(1, 3, 64, 64),
|
||||||
device='cuda')
|
device='cuda')
|
||||||
|
'''
|
Loading…
Reference in New Issue
Block a user