Add the "BigSwitch"

This commit is contained in:
James Betker 2020-10-13 10:11:10 -06:00
parent ca523215c6
commit 9a5d6162e9
6 changed files with 112 additions and 62 deletions

View File

@@ -147,9 +147,9 @@ class SwitchWithReference(nn.Module):
    def forward(self, x, mplex_ref=None, ref=None):
        if self.ref_join is not None:
            branch, ref_std = self.ref_join(x, ref)
-            return self.switch(branch, True, identity=x, att_in=(branch, mplex_ref)) + (ref_std,)
+            return self.switch(branch, identity=x, att_in=(branch, mplex_ref)) + (ref_std,)
        else:
-            return self.switch(x, True, identity=x, att_in=(x, mplex_ref))
+            return self.switch(x, identity=x, att_in=(x, mplex_ref))
class SSGr1(SwitchModelBase):

View File

@@ -4,11 +4,13 @@ from switched_conv.switched_conv import BareConvSwitch, compute_attention_specif
import torch.nn.functional as F
import functools
from collections import OrderedDict
-from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu, MultiConvBlock, SiLU
+from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu, MultiConvBlock, \
+    SiLU, UpconvBlock, ReferenceJoinBlock
from switched_conv.switched_conv_util import save_attention_to_image_rgb
import os
from models.archs.spinenet_arch import SpineNet
import torchvision
+from utils.util import checkpoint
# VGG-style layer with Conv(stride2)->BN->Activation->Conv->BN->Activation
# Doubles the input filter count.
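The two comment lines above describe the block that follows in the file without showing it. For readers skimming the diff, here is a rough standalone sketch of that pattern; the class name and exact layer choices are illustrative only and are not the project's implementation.

```python
# A hypothetical VGG-style downsampling block matching the comment above:
# strided conv halves the spatial size while the filter count doubles.
import torch
import torch.nn as nn

class VggDownBlock(nn.Module):
    def __init__(self, filters_in):
        super().__init__()
        filters_out = filters_in * 2  # doubles the input filter count
        self.net = nn.Sequential(
            nn.Conv2d(filters_in, filters_out, 3, stride=2, padding=1),
            nn.BatchNorm2d(filters_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(filters_out, filters_out, 3, padding=1),
            nn.BatchNorm2d(filters_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.net(x)

print(VggDownBlock(32)(torch.randn(1, 32, 64, 64)).shape)  # torch.Size([1, 64, 32, 32])
```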
@@ -106,7 +108,7 @@ class ConfigurableSwitchComputer(nn.Module):
    # Regarding inputs: it is acceptable to pass in a tuple/list as an input for (x), but the first element
    # *must* be the actual parameter that gets fed through the network - it is assumed to be the identity.
-    def forward(self, x, output_attention_weights=False, identity=None, att_in=None, fixed_scale=1):
+    def forward(self, x, att_in=None, identity=None, output_attention_weights=True, fixed_scale=1, do_checkpointing=False):
        if isinstance(x, tuple):
            x1 = x[0]
        else:
@@ -131,13 +133,19 @@
            x = self.pre_transform(*x)
        if not isinstance(x, tuple):
            x = (x,)
-        xformed = [t(*x) for t in self.transforms]
+        if do_checkpointing:
+            xformed = [checkpoint(t, *x) for t in self.transforms]
+        else:
+            xformed = [t(*x) for t in self.transforms]
        if not isinstance(att_in, tuple):
            att_in = (att_in,)
        if self.feed_transforms_into_multiplexer:
            att_in = att_in + (torch.stack(xformed, dim=1),)
-        m = self.multiplexer(*att_in)
+        if do_checkpointing:
+            m = checkpoint(self.multiplexer, *att_in)
+        else:
+            m = self.multiplexer(*att_in)
        # It is assumed that [xformed] and [m] are collapsed into tensors at this point.
        outputs, attention = self.switch(xformed, m, True, self.update_norm)
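The lines added in this hunk let callers trade memory for recomputation by routing each transform and the multiplexer through a checkpoint call. Below is a minimal, runnable sketch of that pattern; it uses torch.utils.checkpoint directly as a stand-in for the project's utils.util.checkpoint wrapper, which is assumed to behave the same way.

```python
# Sketch of the do_checkpointing pattern added above (stand-in modules, not the
# project's transforms/multiplexer).
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

transforms = nn.ModuleList([nn.Conv2d(8, 8, 3, padding=1) for _ in range(4)])
multiplexer = nn.Conv2d(8, 4, 1)  # stand-in: one attention logit map per transform

def run(x, do_checkpointing=False):
    if do_checkpointing:
        # Do not store intermediate activations; recompute them during backward.
        xformed = [checkpoint(t, x) for t in transforms]
        m = checkpoint(multiplexer, x)
    else:
        xformed = [t(x) for t in transforms]
        m = multiplexer(x)
    return torch.stack(xformed, dim=1), m

x = torch.randn(2, 8, 16, 16, requires_grad=True)
stacked, logits = run(x, do_checkpointing=True)
(stacked.sum() + logits.sum()).backward()  # gradients flow through the recomputed segments
```

With do_checkpointing=True the per-transform activations are not kept during the forward pass; they are recomputed when backward() reaches each checkpointed segment, which is what makes the larger multiplexer below affordable.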
@@ -592,15 +600,90 @@ class SwitchModelBase(nn.Module):
        return val
+from models.archs.spinenet_arch import make_res_layer, BasicBlock
+class BigMultiplexer(nn.Module):
+    def __init__(self, in_nc, nf, multiplexer_channels):
+        super(BigMultiplexer, self).__init__()
+        self.spine = SpineNet(arch='96', output_level=[3], double_reduce_early=False)
+        self.spine_red_proc = ConvGnSilu(256, nf, kernel_size=1, activation=False, norm=False, bias=False)
+        self.fea_tail = ConvGnSilu(in_nc, nf, kernel_size=7, bias=True, norm=False, activation=False)
+        self.tail_proc = make_res_layer(BasicBlock, nf, nf, 2)
+        self.tail_join = ReferenceJoinBlock(nf)
+        # Blocks used to create the key
+        self.key_process = ConvGnSilu(nf, nf, kernel_size=1, activation=True, norm=False, bias=True)
+        # Postprocessing blocks.
+        self.query_key_combine = ConvGnSilu(nf*2, nf, kernel_size=3, activation=True, norm=False, bias=False)
+        self.cbl0 = ConvGnSilu(nf, nf, kernel_size=3, activation=True, norm=True, bias=False)
+        self.cbl1 = ConvGnSilu(nf, nf // 2, kernel_size=1, norm=True, bias=False, num_groups=4)
+        self.cbl2 = ConvGnSilu(nf // 2, 1, kernel_size=1, norm=False, bias=False)
+    def forward(self, x, transformations):
+        s = self.spine(x)[0]
+        tail = self.fea_tail(x)
+        tail = self.tail_proc(tail)
+        q = F.interpolate(s, scale_factor=2, mode='bilinear')
+        q = self.spine_red_proc(q)
+        q, _ = self.tail_join(q, tail)
+        b, t, f, h, w = transformations.shape
+        k = transformations.view(b * t, f, h, w)
+        k = self.key_process(k)
+        q = q.view(b, 1, f, h, w).repeat(1, t, 1, 1, 1).view(b * t, f, h, w)
+        v = self.query_key_combine(torch.cat([q, k], dim=1))
+        v = self.cbl0(v)
+        v = self.cbl1(v)
+        v = self.cbl2(v)
+        return v.view(b, t, h, w)
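The forward pass above builds one query map per image, tiles it across the t candidate transformations, and collapses each query/key pair to a single logit map per transformation. A shape-only sketch of that bookkeeping, with a channel mean standing in for the key_process, query_key_combine, and cbl0-cbl2 convolutions:

```python
# Shape walk-through of the query/key pairing used by BigMultiplexer above.
# Random tensors only; the real module computes q from SpineNet features.
import torch

b, t, f, h, w = 2, 16, 64, 32, 32
transformations = torch.randn(b, t, f, h, w)   # stacked transform outputs
q = torch.randn(b, f, h, w)                    # one query feature map per image

k = transformations.view(b * t, f, h, w)                          # one key per transform
q = q.view(b, 1, f, h, w).repeat(1, t, 1, 1, 1).view(b * t, f, h, w)
paired = torch.cat([q, k], dim=1)                                 # (b*t, 2f, h, w)
logits = paired.mean(dim=1, keepdim=True)                         # stand-in for cbl0..cbl2
print(logits.view(b, t, h, w).shape)                              # torch.Size([2, 16, 32, 32])
```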
+class TheBigSwitch(SwitchModelBase):
+    def __init__(self, in_nc, nf, xforms=16, upscale=2, init_temperature=10):
+        super(TheBigSwitch, self).__init__(init_temperature, 10000)
+        self.nf = nf
+        self.transformation_counts = xforms
+        self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False)
+        multiplx_fn = functools.partial(BigMultiplexer, in_nc, nf)
+        transform_fn = functools.partial(MultiConvBlock, nf, int(nf * 1.5), nf, kernel_size=3, depth=4, weight_init_factor=.1)
+        self.switch = ConfigurableSwitchComputer(nf, multiplx_fn,
+                                                 pre_transform_block=None, transform_block=transform_fn,
+                                                 attention_norm=True,
+                                                 transform_count=self.transformation_counts, init_temp=init_temperature,
+                                                 add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
+        self.switches = [self.switch]
+        self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
+        self.upsample = UpconvBlock(nf, nf // 2, block=ConvGnLelu, norm=False, activation=True, bias=True)
+        self.final_hr_conv1 = ConvGnLelu(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True)
+        self.final_hr_conv2 = ConvGnLelu(nf // 2, 3, kernel_size=3, norm=False, activation=False, bias=False)
+    def forward(self, x, save_attentions=True):
+        # The attention_maps debugger outputs <x>. Save that here.
+        self.lr = x.detach().cpu()
+        # If we're not saving attention, we also shouldn't be updating the attention norm. This is because the attention
+        # norm should only be getting updates with new data, not recurrent generator sampling.
+        for sw in self.switches:
+            sw.set_update_attention_norm(save_attentions)
+        x1 = self.model_fea_conv(x)
+        x1, a1 = self.switch(x1, att_in=x, do_checkpointing=True)
+        x_out = checkpoint(self.final_lr_conv, x1)
+        x_out = checkpoint(self.upsample, x_out)
+        x_out = checkpoint(self.final_hr_conv2, x_out)
+        if save_attentions:
+            self.attentions = [a1]
+        return x_out,
+if __name__ == '__main__':
+    bb = BackboneEncoder(64)
+    emb = QueryKeyMultiplexer(64, 10)
+    tbs = TheBigSwitch(3, 64)
+    x = torch.randn(4,3,64,64)
+    r = torch.randn(4,3,128,128)
+    xu = torch.randn(4,64,64,64)
+    cp = torch.zeros((4,2), dtype=torch.long)
+    trans = [torch.randn(4,64,64,64) for t in range(10)]
+    b = bb(x, r, cp)
+    emb(xu, b, trans)
+    b = tbs(x)

View File

@@ -479,7 +479,7 @@ class ReferenceJoinBlock(nn.Module):
class UpconvBlock(nn.Module):
    def __init__(self, filters_in, filters_out=None, block=ConvGnSilu, norm=True, activation=True, bias=False):
        super(UpconvBlock, self).__init__()
-        self.process = block(filters_out, filters_out, kernel_size=3, bias=bias, activation=activation, norm=norm)
+        self.process = block(filters_in, filters_out, kernel_size=3, bias=bias, activation=activation, norm=norm)
    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
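The corrected line matters because the conv runs after a 2x nearest-neighbor upsample, which changes spatial size but not channel count, so the conv still receives filters_in channels. A small sketch of the fixed pattern, with a plain Conv2d standing in for the configurable block argument:

```python
# Re-implementation of the UpconvBlock pattern to show why the first argument
# must be filters_in: the upsample leaves the channel count unchanged.
# nn.Conv2d stands in for the project's ConvGnSilu/ConvGnLelu blocks.
import torch
import torch.nn as nn
import torch.nn.functional as F

class UpconvSketch(nn.Module):
    def __init__(self, filters_in, filters_out):
        super().__init__()
        self.process = nn.Conv2d(filters_in, filters_out, kernel_size=3, padding=1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode="nearest")  # channels unchanged here
        return self.process(x)

up = UpconvSketch(64, 32)
print(up(torch.randn(1, 64, 16, 16)).shape)  # torch.Size([1, 32, 32, 32])
```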

View File

@@ -6,42 +6,7 @@ import torch.nn.functional as F
from torch.nn.init import kaiming_normal
from torchvision.models.resnet import BasicBlock, Bottleneck
from torch.nn.modules.batchnorm import _BatchNorm
-''' Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard
-    kernel sizes. '''
-class ConvBnRelu(nn.Module):
-    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, relu=True, bn=True, bias=True):
-        super(ConvBnRelu, self).__init__()
-        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
-        assert kernel_size in padding_map.keys()
-        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
-        if bn:
-            self.bn = nn.BatchNorm2d(filters_out)
-        else:
-            self.bn = None
-        if relu:
-            self.relu = nn.ReLU()
-        else:
-            self.relu = None
-        # Init params.
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.relu else 'linear')
-            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
-                nn.init.constant_(m.weight, 1)
-                nn.init.constant_(m.bias, 0)
-    def forward(self, x):
-        x = self.conv(x)
-        if self.bn:
-            x = self.bn(x)
-        if self.relu:
-            return self.relu(x)
-        else:
-            return x
+from models.archs.arch_util import ConvGnSilu
def constant_init(module, val, bias=0):
@@ -194,10 +159,10 @@ class Resample(nn.Module):
        new_in_channels = int(in_channels * alpha)
        if block_type == Bottleneck:
            in_channels *= 4
-        self.squeeze_conv = ConvBnRelu(in_channels, new_in_channels, kernel_size=1)
+        self.squeeze_conv = ConvGnSilu(in_channels, new_in_channels, kernel_size=1)
        if scale < 1:
-            self.downsample_conv = ConvBnRelu(new_in_channels, new_in_channels, kernel_size=3, stride=2)
-        self.expand_conv = ConvBnRelu(new_in_channels, out_channels, kernel_size=1, relu=False)
+            self.downsample_conv = ConvGnSilu(new_in_channels, new_in_channels, kernel_size=3, stride=2)
+        self.expand_conv = ConvGnSilu(new_in_channels, out_channels, kernel_size=1, activation=False)
    def _resize(self, x):
        if self.scale == 1:
@@ -277,14 +242,14 @@ class SpineNet(nn.Module):
        """Build the stem network."""
        # Build the first conv and maxpooling layers.
        if self._early_double_reduce:
-            self.conv1 = ConvBnRelu(
+            self.conv1 = ConvGnSilu(
                in_channels,
                64,
                kernel_size=7,
                stride=2)
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        else:
-            self.conv1 = ConvBnRelu(
+            self.conv1 = ConvGnSilu(
                in_channels,
                64,
                kernel_size=7,
@@ -308,10 +273,10 @@ class SpineNet(nn.Module):
        for block_spec in self._block_specs:
            if block_spec.is_output:
                in_channels = int(FILTER_SIZE_MAP[block_spec.level]*self._filter_size_scale) * 4
-                self.endpoint_convs[str(block_spec.level)] = ConvBnRelu(in_channels,
+                self.endpoint_convs[str(block_spec.level)] = ConvGnSilu(in_channels,
                        self._endpoints_num_filters,
                        kernel_size=1,
-                       relu=False)
+                       activation=False)
    def _make_scale_permuted_network(self):
        self.merge_ops = nn.ModuleList()

View File

@@ -109,6 +109,8 @@ def define_G(opt, net_key='network_G', scale=None):
                                init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
    elif which_model == 'ssg_teco':
        netG = ssg.StackedSwitchGenerator2xTeco(nf=opt_net['nf'], xforms=opt_net['num_transforms'], init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
+    elif which_model == 'big_switch':
+        netG = SwitchedGen_arch.TheBigSwitch(opt_net['in_nc'], opt_net['nf'], opt_net['num_transforms'], opt_net['scale'], opt_net['temperature'])
    elif which_model == "flownet2":
        from models.flownet2.models import FlowNet2
        ld = torch.load(opt_net['load_path'])
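For reference, a sketch of the options the new 'big_switch' branch reads, written as the opt_net mapping that define_G receives. The five argument keys come straight from the call above; the selector key name and the sample values (taken from TheBigSwitch's defaults and the test stub in this commit) are assumptions about a typical configuration.

```python
# Hypothetical opt_net entries for the new branch; only the five argument keys
# are confirmed by the diff, everything else is an assumption.
opt_net = {
    'which_model_G': 'big_switch',  # selector key name is an assumption
    'in_nc': 3,                     # input channels
    'nf': 64,                       # base filter count
    'num_transforms': 16,           # number of switched transforms (xforms)
    'scale': 2,                     # upscale factor
    'temperature': 10,              # initial switch temperature
}
```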

View File

@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
def main():
    #### options
    parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_pretrain_ssgteco.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_bigswitch.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
@@ -124,7 +124,7 @@ def main():
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True
-    # torch.autograd.set_detect_anomaly(True)
+    #torch.autograd.set_detect_anomaly(True)
    # Save the compiled opt dict to the global loaded_options variable.
    util.loaded_options = opt