forked from mrq/DL-Art-School

Add SpineNet & integrate with SRG

The new version of SRG uses SpineNet as the switch backbone.

parent 3ed7a2b9ab
commit 703dec4472
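For orientation, a hedged usage sketch of the new generator. It is not part of the commit; the constructor arguments simply mirror the updated test_stability call at the bottom of this diff, and whether it runs depends on the rest of the repository.

# Hypothetical usage sketch (not part of the commit); arguments mirror the
# test_stability() call changed at the end of this diff.
import torch
import models.archs.SwitchedResidualGenerator_arch as srg

gen = srg.ConfigurableSwitchedResidualGenerator3(trans_counts=[8],
                                                 trans_kernel_sizes=[3],
                                                 trans_layers=[3],
                                                 transformation_filters=64,
                                                 initial_temp=10)
out, = gen(torch.randn(1, 3, 128, 128))  # forward() returns a 1-tuple
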
@@ -4,79 +4,11 @@ from switched_conv import BareConvSwitch, compute_attention_specificity
 import torch.nn.functional as F
 import functools
 from collections import OrderedDict
-from models.archs.arch_util import initialize_weights
+from models.archs.arch_util import initialize_weights, ConvBnRelu, ConvBnLelu
+from models.archs.RRDBNet_arch import ResidualDenseBlock_5C
+from models.archs.spinenet_arch import SpineNet
 from switched_conv_util import save_attention_to_image
 
-''' Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard
-kernel sizes. '''
-class ConvBnRelu(nn.Module):
-    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, relu=True, bn=True, bias=True):
-        super(ConvBnRelu, self).__init__()
-        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
-        assert kernel_size in padding_map.keys()
-        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
-        if bn:
-            self.bn = nn.BatchNorm2d(filters_out)
-        else:
-            self.bn = None
-        if relu:
-            self.relu = nn.ReLU()
-        else:
-            self.relu = None
-
-        # Init params.
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.relu else 'linear')
-            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
-                nn.init.constant_(m.weight, 1)
-                nn.init.constant_(m.bias, 0)
-
-    def forward(self, x):
-        x = self.conv(x)
-        if self.bn:
-            x = self.bn(x)
-        if self.relu:
-            return self.relu(x)
-        else:
-            return x
-
-
-''' Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard
-kernel sizes. '''
-class ConvBnLelu(nn.Module):
-    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, lelu=True, bn=True, bias=True):
-        super(ConvBnLelu, self).__init__()
-        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
-        assert kernel_size in padding_map.keys()
-        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
-        if bn:
-            self.bn = nn.BatchNorm2d(filters_out)
-        else:
-            self.bn = None
-        if lelu:
-            self.lelu = nn.LeakyReLU(negative_slope=.1)
-        else:
-            self.lelu = None
-
-        # Init params.
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(m.weight, a=.1, mode='fan_out', nonlinearity='leaky_relu' if self.lelu else 'linear')
-            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
-                nn.init.constant_(m.weight, 1)
-                nn.init.constant_(m.bias, 0)
-
-    def forward(self, x):
-        x = self.conv(x)
-        if self.bn:
-            x = self.bn(x)
-        if self.lelu:
-            return self.lelu(x)
-        else:
-            return x
-
-
 class MultiConvBlock(nn.Module):
     def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth, scale_init=1, bn=False):

@@ -214,7 +146,7 @@ class ConfigurableSwitchComputer(nn.Module):
 
         m = self.multiplexer(identity)
         # Interpolate the multiplexer across the entire shape of the image.
-        m = F.interpolate(m, size=x.shape[2:], mode='nearest')
+        m = F.interpolate(m, size=xformed[0].shape[2:], mode='nearest')
 
         outputs, attention = self.switch(xformed, m, True)
         outputs = identity + outputs * self.switch_scale

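The one-line change above matters because the new pre-transform block in SRG3 downsamples by 4, so the selection map must be resized to the transformed feature size rather than to the input size. A minimal illustration with made-up shapes (none of these values come from the commit):

# Illustration only: why the mask is resized to xformed[0] instead of x.
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 128, 128)          # identity / input to the switch
xformed = [torch.randn(1, 64, 32, 32)]   # SRG3's pre-transform uses stride 4
m = torch.randn(1, 8, 16, 16)            # raw multiplexer output
m = F.interpolate(m, size=xformed[0].shape[2:], mode='nearest')
assert m.shape[2:] == xformed[0].shape[2:]
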
@@ -252,6 +184,22 @@ class ConvBasisMultiplexer(nn.Module):
         return x
 
 
+class SpineNetMultiplexer(nn.Module):
+    def __init__(self, input_channels, transform_count):
+        super(SpineNetMultiplexer, self).__init__()
+        self.backbone = SpineNet('49', in_channels=input_channels)
+        self.rdc1 = ConvBnRelu(256, 128, kernel_size=3, bias=False)
+        self.rdc2 = ConvBnRelu(128, 64, kernel_size=3, bias=False)
+        self.rdc3 = ConvBnRelu(64, transform_count, bias=False, bn=False, relu=False)
+
+    def forward(self, x):
+        spine = self.backbone(x)
+        feat = self.rdc1(spine[0])
+        feat = self.rdc2(feat)
+        feat = self.rdc3(feat)
+        return feat
+
+
 class ConvBasisMultiplexerReducer(nn.Module):
     def __init__(self, input_channels, base_filters, growth, reductions, processing_depth):
         super(ConvBasisMultiplexerReducer, self).__init__()

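For reference, a rough shape walk-through for the new multiplexer. The exact spatial sizes depend on SpineNet's stride pattern (note the stem in this commit drops the maxpool), so the sketch only prints shapes rather than asserting particular values; it is not part of the commit.

# Hypothetical smoke test for SpineNetMultiplexer.
import torch

mux = SpineNetMultiplexer(input_channels=3, transform_count=8)
sel = mux(torch.randn(1, 3, 64, 64))
# spine[0] is SpineNet-49's level-3 endpoint (256 channels); rdc1-rdc3 reduce
# it to transform_count channels, and ConfigurableSwitchComputer then
# nearest-interpolates this map up to the transformed feature size.
print(sel.shape)  # roughly (1, 8, H/4, W/4) given the stride-2, maxpool-free stem
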
@@ -425,3 +373,94 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
             val["switch_%i_specificity" % (i,)] = means[i]
             val["switch_%i_histogram" % (i,)] = hists[i]
         return val
+
+
+class Interpolate(nn.Module):
+    def __init__(self, factor):
+        super(Interpolate, self).__init__()
+        self.factor = factor
+
+    def forward(self, x):
+        return F.interpolate(x, scale_factor=self.factor)
+
+
+class ConfigurableSwitchedResidualGenerator3(nn.Module):
+    def __init__(self, trans_counts, trans_kernel_sizes, trans_layers, transformation_filters,
+                 initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
+                 heightened_final_step=50000, upsample_factor=1, enable_negative_transforms=False,
+                 add_scalable_noise_to_transforms=False):
+        super(ConfigurableSwitchedResidualGenerator3, self).__init__()
+        switches = []
+        for trans_count, kernel, layers in zip(trans_counts, trans_kernel_sizes, trans_layers):
+            multiplx_fn = functools.partial(SpineNetMultiplexer, 3)
+            switches.append(ConfigurableSwitchComputer(
+                base_filters=3, multiplexer_net=multiplx_fn,
+                pre_transform_block=functools.partial(
+                    nn.Sequential,
+                    ConvBnLelu(3, transformation_filters, kernel_size=1, stride=4, bn=False, lelu=False, bias=False),
+                    ResidualDenseBlock_5C(transformation_filters),
+                    ResidualDenseBlock_5C(transformation_filters)),
+                transform_block=functools.partial(
+                    nn.Sequential,
+                    ResidualDenseBlock_5C(transformation_filters),
+                    Interpolate(4),
+                    ConvBnLelu(transformation_filters, transformation_filters // 2, kernel_size=3, bias=False, bn=False),
+                    ConvBnLelu(transformation_filters // 2, 3, kernel_size=1, bias=False, bn=False, lelu=False)),
+                transform_count=trans_count, init_temp=initial_temp,
+                enable_negative_transforms=enable_negative_transforms,
+                add_scalable_noise_to_transforms=add_scalable_noise_to_transforms,
+                init_scalar=.01))
+
+        self.switches = nn.ModuleList(switches)
+        self.transformation_counts = trans_counts
+        self.init_temperature = initial_temp
+        self.final_temperature_step = final_temperature_step
+        self.heightened_temp_min = heightened_temp_min
+        self.heightened_final_step = heightened_final_step
+        self.attentions = None
+        self.upsample_factor = upsample_factor
+
+    def forward(self, x):
+        if self.upsample_factor > 1:
+            x = F.interpolate(x, scale_factor=self.upsample_factor, mode="nearest")
+
+        self.attentions = []
+        for i, sw in enumerate(self.switches):
+            x, att = sw.forward(x, True)
+            self.attentions.append(att)
+
+        return x,
+
+    def set_temperature(self, temp):
+        [sw.set_temperature(temp) for sw in self.switches]
+
+    def update_for_step(self, step, experiments_path='.'):
+        if self.attentions:
+            temp = max(1, int(
+                self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step))
+            if temp == 1 and self.heightened_final_step and self.heightened_final_step != 1:
+                # Once the temperature passes (1) it enters an inverted curve to match the linear curve from above.
+                # without this, the attention specificity "spikes" incredibly fast in the last few iterations.
+                h_steps_total = self.heightened_final_step - self.final_temperature_step
+                h_steps_current = min(step - self.final_temperature_step, h_steps_total)
+                # The "gap" will represent the steps that need to be traveled as a linear function.
+                h_gap = 1 / self.heightened_temp_min
+                temp = h_gap * h_steps_current / h_steps_total
+                # Invert temperature to represent reality on this side of the curve
+                temp = 1 / temp
+            self.set_temperature(temp)
+            if step % 50 == 0:
+                [save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts[i], step,
+                                         "a%i" % (i + 1,)) for i in range(len(self.switches))]
+
+    def get_debug_values(self, step):
+        temp = self.switches[0].switch.temperature
+        mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
+        means = [i[0] for i in mean_hists]
+        hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
+        val = {"switch_temperature": temp}
+        for i in range(len(means)):
+            val["switch_%i_specificity" % (i,)] = means[i]
+            val["switch_%i_histogram" % (i,)] = hists[i]
+        return val

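To make the temperature schedule in update_for_step() easier to follow, here is a stand-alone sketch of the same arithmetic. The parameter values are illustrative only (heightened_temp_min is set below 1 so the second phase is visible) and are not taken from the commit.

# Stand-alone re-statement of the schedule in update_for_step() above.
init_temperature = 10
final_temperature_step = 50000
heightened_temp_min = 0.5       # illustrative; the class defaults to 1
heightened_final_step = 100000  # illustrative

def temperature_at(step):
    # Phase 1: linear decay from init_temperature down to 1.
    temp = max(1, int(init_temperature * (final_temperature_step - step) / final_temperature_step))
    if temp == 1 and heightened_final_step and heightened_final_step != 1:
        # Phase 2: a 1/x-shaped curve that ends at heightened_temp_min.
        h_steps_total = heightened_final_step - final_temperature_step
        h_steps_current = min(step - final_temperature_step, h_steps_total)
        h_gap = 1 / heightened_temp_min
        temp = 1 / (h_gap * h_steps_current / h_steps_total)
    return temp

for s in (0, 25000, 60000, 100000):
    print(s, temperature_at(s))  # 10, 5, 2.5, 0.5
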
codes/models/archs/spinenet_arch.py (new file, 319 lines)

@@ -0,0 +1,319 @@
# Taken and modified from https://github.com/lucifer443/SpineNet-Pytorch/blob/master/mmdet/models/backbones/spinenet.py

import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import kaiming_normal

from torchvision.models.resnet import BasicBlock, Bottleneck
from torch.nn.modules.batchnorm import _BatchNorm
from models.archs.arch_util import ConvBnRelu

def constant_init(module, val, bias=0):
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.constant_(module.weight, val)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)

def kaiming_init(module,
                 a=0,
                 mode='fan_out',
                 nonlinearity='relu',
                 bias=0,
                 distribution='normal'):
    assert distribution in ['uniform', 'normal']
    if distribution == 'uniform':
        nn.init.kaiming_uniform_(
            module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
    else:
        nn.init.kaiming_normal_(
            module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)

FILTER_SIZE_MAP = {
    1: 32,
    2: 64,
    3: 128,
    4: 256,
    5: 256,
    6: 256,
    7: 256,
}

def make_res_layer(block,
                   inplanes,
                   planes,
                   blocks,
                   stride=1,
                   dilation=1):
    downsample = None
    if stride != 1 or inplanes != planes * block.expansion:
        downsample = nn.Sequential(
            nn.Conv2d(
                inplanes,
                planes * block.expansion,
                kernel_size=1,
                stride=stride,
                bias=False),
            nn.BatchNorm2d(planes * block.expansion),
        )

    layers = []
    layers.append(
        block(
            inplanes=inplanes,
            planes=planes,
            stride=stride,
            dilation=dilation,
            downsample=downsample))
    inplanes = planes * block.expansion
    for i in range(1, blocks):
        layers.append(
            block(
                inplanes=inplanes,
                planes=planes,
                stride=1,
                dilation=dilation))

    return nn.Sequential(*layers)

# The fixed SpineNet architecture discovered by NAS.
# Each element represents a specification of a building block:
#   (block_level, block_fn, (input_offset0, input_offset1), is_output).
SPINENET_BLOCK_SPECS = [
    (2, Bottleneck, (None, None), False),  # init block
    (2, Bottleneck, (None, None), False),  # init block
    (2, Bottleneck, (0, 1), False),
    (4, BasicBlock, (0, 1), False),
    (3, Bottleneck, (2, 3), False),
    (4, Bottleneck, (2, 4), False),
    (6, BasicBlock, (3, 5), False),
    (4, Bottleneck, (3, 5), False),
    (5, BasicBlock, (6, 7), False),
    (7, BasicBlock, (6, 8), False),
    (5, Bottleneck, (8, 9), False),
    (5, Bottleneck, (8, 10), False),
    (4, Bottleneck, (5, 10), True),
    (3, Bottleneck, (4, 10), True),
    (5, Bottleneck, (7, 12), True),
    (7, Bottleneck, (5, 14), True),
    (6, Bottleneck, (12, 14), True),
]

SCALING_MAP = {
    '49S': {
        'endpoints_num_filters': 128,
        'filter_size_scale': 0.65,
        'resample_alpha': 0.5,
        'block_repeats': 1,
    },
    '49': {
        'endpoints_num_filters': 256,
        'filter_size_scale': 1.0,
        'resample_alpha': 0.5,
        'block_repeats': 1,
    },
    '96': {
        'endpoints_num_filters': 256,
        'filter_size_scale': 1.0,
        'resample_alpha': 0.5,
        'block_repeats': 2,
    },
    '143': {
        'endpoints_num_filters': 256,
        'filter_size_scale': 1.0,
        'resample_alpha': 1.0,
        'block_repeats': 3,
    },
    '190': {
        'endpoints_num_filters': 512,
        'filter_size_scale': 1.3,
        'resample_alpha': 1.0,
        'block_repeats': 4,
    },
}


class BlockSpec(object):
    """A container class that specifies the block configuration for SpineNet."""

    def __init__(self, level, block_fn, input_offsets, is_output):
        self.level = level
        self.block_fn = block_fn
        self.input_offsets = input_offsets
        self.is_output = is_output


def build_block_specs(block_specs=None):
    """Builds the list of BlockSpec objects for SpineNet."""
    if not block_specs:
        block_specs = SPINENET_BLOCK_SPECS
    return [BlockSpec(*b) for b in block_specs]


class Resample(nn.Module):
    def __init__(self, in_channels, out_channels, scale, block_type, alpha=1.0):
        super(Resample, self).__init__()
        self.scale = scale
        new_in_channels = int(in_channels * alpha)
        if block_type == Bottleneck:
            in_channels *= 4
        self.squeeze_conv = ConvBnRelu(in_channels, new_in_channels, kernel_size=1)
        if scale < 1:
            self.downsample_conv = ConvBnRelu(new_in_channels, new_in_channels, kernel_size=3, stride=2)
        self.expand_conv = ConvBnRelu(new_in_channels, out_channels, kernel_size=1, relu=False)

    def _resize(self, x):
        if self.scale == 1:
            return x
        elif self.scale > 1:
            return F.interpolate(x, scale_factor=self.scale, mode='nearest')
        else:
            x = self.downsample_conv(x)
            if self.scale < 0.5:
                new_kernel_size = 3 if self.scale >= 0.25 else 5
                x = F.max_pool2d(x, kernel_size=new_kernel_size, stride=int(0.5/self.scale), padding=new_kernel_size//2)
            return x

    def forward(self, inputs):
        feat = self.squeeze_conv(inputs)
        feat = self._resize(feat)
        feat = self.expand_conv(feat)
        return feat


class Merge(nn.Module):
    """Merge two input tensors"""
    def __init__(self, block_spec, alpha, filter_size_scale):
        super(Merge, self).__init__()
        out_channels = int(FILTER_SIZE_MAP[block_spec.level] * filter_size_scale)
        if block_spec.block_fn == Bottleneck:
            out_channels *= 4
        self.block = block_spec.block_fn
        self.resample_ops = nn.ModuleList()
        for spec_idx in block_spec.input_offsets:
            spec = BlockSpec(*SPINENET_BLOCK_SPECS[spec_idx])
            in_channels = int(FILTER_SIZE_MAP[spec.level] * filter_size_scale)
            scale = 2**(spec.level - block_spec.level)
            self.resample_ops.append(
                Resample(in_channels, out_channels, scale, spec.block_fn, alpha)
            )

    def forward(self, inputs):
        assert len(inputs) == len(self.resample_ops)
        parent0_feat = self.resample_ops[0](inputs[0])
        parent1_feat = self.resample_ops[1](inputs[1])
        target_feat = parent0_feat + parent1_feat
        return target_feat


class SpineNet(nn.Module):
    """Class to build SpineNet backbone"""
    def __init__(self,
                 arch,
                 in_channels=3,
                 output_level=[3, 4, 5, 6, 7],
                 zero_init_residual=True):
        super(SpineNet, self).__init__()
        self._block_specs = build_block_specs()[2:]
        self._endpoints_num_filters = SCALING_MAP[arch]['endpoints_num_filters']
        self._resample_alpha = SCALING_MAP[arch]['resample_alpha']
        self._block_repeats = SCALING_MAP[arch]['block_repeats']
        self._filter_size_scale = SCALING_MAP[arch]['filter_size_scale']
        self._init_block_fn = Bottleneck
        self._num_init_blocks = 2
        self.zero_init_residual = zero_init_residual
        assert min(output_level) > 2 and max(output_level) < 8, "Output level out of range"
        self.output_level = output_level

        self._make_stem_layer(in_channels)
        self._make_scale_permuted_network()
        self._make_endpoints()

    def _make_stem_layer(self, in_channels):
        """Build the stem network."""
        # Build the first conv and maxpooling layers.
        self.conv1 = ConvBnRelu(
            in_channels,
            64,
            kernel_size=7,
            stride=2)  # Original paper had stride=2 and a maxpool after.

        # Build the initial level 2 blocks.
        self.init_block1 = make_res_layer(
            self._init_block_fn,
            64,
            int(FILTER_SIZE_MAP[2] * self._filter_size_scale),
            self._block_repeats)
        self.init_block2 = make_res_layer(
            self._init_block_fn,
            int(FILTER_SIZE_MAP[2] * self._filter_size_scale) * 4,
            int(FILTER_SIZE_MAP[2] * self._filter_size_scale),
            self._block_repeats)

    def _make_endpoints(self):
        self.endpoint_convs = nn.ModuleDict()
        for block_spec in self._block_specs:
            if block_spec.is_output:
                in_channels = int(FILTER_SIZE_MAP[block_spec.level]*self._filter_size_scale) * 4
                self.endpoint_convs[str(block_spec.level)] = ConvBnRelu(in_channels,
                                                                        self._endpoints_num_filters,
                                                                        kernel_size=1,
                                                                        relu=False)

    def _make_scale_permuted_network(self):
        self.merge_ops = nn.ModuleList()
        self.scale_permuted_blocks = nn.ModuleList()
        for spec in self._block_specs:
            self.merge_ops.append(
                Merge(spec, self._resample_alpha, self._filter_size_scale)
            )
            channels = int(FILTER_SIZE_MAP[spec.level] * self._filter_size_scale)
            in_channels = channels * 4 if spec.block_fn == Bottleneck else channels
            self.scale_permuted_blocks.append(
                make_res_layer(spec.block_fn,
                               in_channels,
                               channels,
                               self._block_repeats)
            )

    def init_weights(self, pretrained=None):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                constant_init(m, 1)
        if self.zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    constant_init(m.norm3, 0)
                elif isinstance(m, BasicBlock):
                    constant_init(m.norm2, 0)

    def forward(self, input):
        feat = self.conv1(input)
        feat1 = self.init_block1(feat)
        feat2 = self.init_block2(feat1)
        block_feats = [feat1, feat2]
        output_feat = {}
        num_outgoing_connections = [0, 0]

        for i, spec in enumerate(self._block_specs):
            target_feat = self.merge_ops[i]([block_feats[feat_idx] for feat_idx in spec.input_offsets])
            # Connect intermediate blocks with outdegree 0 to the output block.
            if spec.is_output:
                for j, (j_feat, j_connections) in enumerate(
                        zip(block_feats, num_outgoing_connections)):
                    if j_connections == 0 and j_feat.shape == target_feat.shape:
                        target_feat += j_feat
                        num_outgoing_connections[j] += 1
            target_feat = F.relu(target_feat, inplace=True)
            target_feat = self.scale_permuted_blocks[i](target_feat)
            block_feats.append(target_feat)
            num_outgoing_connections.append(0)
            for feat_idx in spec.input_offsets:
                num_outgoing_connections[feat_idx] += 1
            if spec.is_output:
                output_feat[spec.level] = target_feat

        return [self.endpoint_convs[str(level)](output_feat[level]) for level in self.output_level]

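A hedged smoke test for the backbone above (not part of the commit): it only instantiates SpineNet-49 and prints the endpoint shapes, since the effective strides differ from the paper (the stem here drops the maxpool).

# Hypothetical smoke test for the new SpineNet backbone.
import torch
from models.archs.spinenet_arch import SpineNet

net = SpineNet('49', in_channels=3)
feats = net(torch.randn(1, 3, 256, 256))
# One tensor per requested output level (3..7 by default), each with
# endpoints_num_filters channels (256 for the '49' scaling).
for level, f in zip(net.output_level, feats):
    print(level, f.shape)
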
@@ -66,6 +66,13 @@ def define_G(opt, net_key='network_G'):
                                             initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
                                             heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
                                             upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
+    elif which_model == "ConfigurableSwitchedResidualGenerator3":
+        netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator3(trans_counts=opt_net['trans_counts'],
+                                            trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
+                                            transformation_filters=opt_net['transformation_filters'],
+                                            initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
+                                            heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
+                                            upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
     elif which_model == "NestedSwitchGenerator":
         netG = ng.NestedSwitchedGenerator(switch_filters=opt_net['switch_filters'],
                                           switch_reductions=opt_net['switch_reductions'],

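For completeness, these are the keys the new branch reads from the network_G options. A hedged Python sketch of an equivalent option dict follows; the values are placeholders and the model-selection key is whatever define_G already uses upstream, neither is specified by this hunk.

# Placeholder options for the ConfigurableSwitchedResidualGenerator3 branch.
# Keys mirror the opt_net[...] lookups above; values are illustrative only.
opt_net = {
    'trans_counts': [8],
    'trans_kernel_sizes': [3],
    'trans_layers': [3],
    'transformation_filters': 64,
    'temperature': 10,
    'temperature_final_step': 50000,
    'heightened_temp_min': 1,
    'heightened_final_step': 50000,
    'add_noise': False,
    # plus the usual model-selection key so which_model resolves to
    # "ConfigurableSwitchedResidualGenerator3"; the dataset scale is passed
    # separately by define_G as upsample_factor.
}
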
@@ -33,7 +33,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_div2k_feat_resgen2_lr.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_div2k_srg3.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)

@@ -162,7 +162,7 @@ def main():
         current_step = resume_state['iter']
         model.resume_training(resume_state)  # handle optimizers and schedulers
     else:
-        current_step = -1
+        current_step = 0
         start_epoch = 0
 
     #### training

@@ -4,7 +4,7 @@ import models.archs.SwitchedResidualGenerator_arch as srg
 import models.archs.NestedSwitchGenerator as nsg
 import functools
 
-blacklisted_modules = [nn.Conv2d, nn.ReLU, nn.LeakyReLU, nn.BatchNorm2d, nn.Softmax]
+blacklisted_modules = [nn.Conv2d, nn.ReLU, nn.LeakyReLU, nn.BatchNorm2d, nn.Softmax, srg.Interpolate]
 def install_forward_trace_hooks(module, id="base"):
     if type(module) in blacklisted_modules:
         return

@@ -96,15 +96,11 @@ if __name__ == "__main__":
                    torch.randn(1, 3, 64, 64),
                    device='cuda')
     '''
-    test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator2,
-                                     switch_filters=[16,16,16,16,16],
-                                     switch_growths=[32,32,32,32,32],
-                                     switch_reductions=[1,1,1,1,1],
-                                     switch_processing_layers=[5,5,5,5,5],
-                                     trans_counts=[8,8,8,8,8],
-                                     trans_kernel_sizes=[3,3,3,3,3],
-                                     trans_layers=[3,3,3,3,3],
+    test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator3,
+                                     trans_counts=[8],
+                                     trans_kernel_sizes=[3],
+                                     trans_layers=[3],
                                      transformation_filters=64,
                                      initial_temp=10),
-                   torch.randn(1, 3, 64, 64),
+                   torch.randn(1, 3, 128, 128),
                    device='cuda')