Add the "BigSwitch"
parent ca523215c6
commit 9a5d6162e9

@@ -147,9 +147,9 @@ class SwitchWithReference(nn.Module):
     def forward(self, x, mplex_ref=None, ref=None):
         if self.ref_join is not None:
             branch, ref_std = self.ref_join(x, ref)
-            return self.switch(branch, True, identity=x, att_in=(branch, mplex_ref)) + (ref_std,)
+            return self.switch(branch, identity=x, att_in=(branch, mplex_ref)) + (ref_std,)
         else:
-            return self.switch(x, True, identity=x, att_in=(x, mplex_ref))
+            return self.switch(x, identity=x, att_in=(x, mplex_ref))


 class SSGr1(SwitchModelBase):

@@ -4,11 +4,13 @@ from switched_conv.switched_conv import BareConvSwitch, compute_attention_specif
 import torch.nn.functional as F
 import functools
 from collections import OrderedDict
-from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu, MultiConvBlock, SiLU
+from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu, MultiConvBlock, \
+    SiLU, UpconvBlock, ReferenceJoinBlock
 from switched_conv.switched_conv_util import save_attention_to_image_rgb
 import os
 from models.archs.spinenet_arch import SpineNet
 import torchvision
+from utils.util import checkpoint

 # VGG-style layer with Conv(stride2)->BN->Activation->Conv->BN->Activation
 # Doubles the input filter count.

@@ -106,7 +108,7 @@ class ConfigurableSwitchComputer(nn.Module):

     # Regarding inputs: it is acceptable to pass in a tuple/list as an input for (x), but the first element
     # *must* be the actual parameter that gets fed through the network - it is assumed to be the identity.
-    def forward(self, x, output_attention_weights=False, identity=None, att_in=None, fixed_scale=1):
+    def forward(self, x, att_in=None, identity=None, output_attention_weights=True, fixed_scale=1, do_checkpointing=False):
         if isinstance(x, tuple):
             x1 = x[0]
         else:

@@ -131,12 +133,18 @@ class ConfigurableSwitchComputer(nn.Module):
             x = self.pre_transform(*x)
         if not isinstance(x, tuple):
             x = (x,)
+        if do_checkpointing:
+            xformed = [checkpoint(t, *x) for t in self.transforms]
+        else:
             xformed = [t(*x) for t in self.transforms]

         if not isinstance(att_in, tuple):
             att_in = (att_in,)
         if self.feed_transforms_into_multiplexer:
             att_in = att_in + (torch.stack(xformed, dim=1),)
+        if do_checkpointing:
+            m = checkpoint(self.multiplexer, *att_in)
+        else:
             m = self.multiplexer(*att_in)

         # It is assumed that [xformed] and [m] are collapsed into tensors at this point.

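The new do_checkpointing path above trades compute for memory: each transform's activations (and optionally the multiplexer's) are recomputed during backward instead of being stored. A minimal standalone sketch of that pattern, using torch.utils.checkpoint directly rather than the repo's checkpoint wrapper; the transforms and shapes here are illustrative only:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Stand-ins for self.transforms: any list of modules with a matching signature works.
transforms = nn.ModuleList([nn.Conv2d(64, 64, 3, padding=1) for _ in range(4)])
x = torch.randn(2, 64, 32, 32, requires_grad=True)

do_checkpointing = True
if do_checkpointing:
    # Each transform's intermediate activations are recomputed on backward.
    xformed = [checkpoint(t, x) for t in transforms]
else:
    xformed = [t(x) for t in transforms]

stacked = torch.stack(xformed, dim=1)  # (b, t, f, h, w): what gets appended to att_in
print(stacked.shape)                   # torch.Size([2, 4, 64, 32, 32])
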
@@ -592,15 +600,90 @@ class SwitchModelBase(nn.Module):
         return val


+from models.archs.spinenet_arch import make_res_layer, BasicBlock
+class BigMultiplexer(nn.Module):
+    def __init__(self, in_nc, nf, multiplexer_channels):
+        super(BigMultiplexer, self).__init__()
+
+        self.spine = SpineNet(arch='96', output_level=[3], double_reduce_early=False)
+        self.spine_red_proc = ConvGnSilu(256, nf, kernel_size=1, activation=False, norm=False, bias=False)
+        self.fea_tail = ConvGnSilu(in_nc, nf, kernel_size=7, bias=True, norm=False, activation=False)
+        self.tail_proc = make_res_layer(BasicBlock, nf, nf, 2)
+        self.tail_join = ReferenceJoinBlock(nf)
+
+        # Blocks used to create the key
+        self.key_process = ConvGnSilu(nf, nf, kernel_size=1, activation=True, norm=False, bias=True)
+
+        # Postprocessing blocks.
+        self.query_key_combine = ConvGnSilu(nf*2, nf, kernel_size=3, activation=True, norm=False, bias=False)
+        self.cbl0 = ConvGnSilu(nf, nf, kernel_size=3, activation=True, norm=True, bias=False)
+        self.cbl1 = ConvGnSilu(nf, nf // 2, kernel_size=1, norm=True, bias=False, num_groups=4)
+        self.cbl2 = ConvGnSilu(nf // 2, 1, kernel_size=1, norm=False, bias=False)
+
+    def forward(self, x, transformations):
+        s = self.spine(x)[0]
+        tail = self.fea_tail(x)
+        tail = self.tail_proc(tail)
+        q = F.interpolate(s, scale_factor=2, mode='bilinear')
+        q = self.spine_red_proc(q)
+        q, _ = self.tail_join(q, tail)
+
+        b, t, f, h, w = transformations.shape
+        k = transformations.view(b * t, f, h, w)
+        k = self.key_process(k)
+
+        q = q.view(b, 1, f, h, w).repeat(1, t, 1, 1, 1).view(b * t, f, h, w)
+        v = self.query_key_combine(torch.cat([q, k], dim=1))
+        v = self.cbl0(v)
+        v = self.cbl1(v)
+        v = self.cbl2(v)
+
+        return v.view(b, t, h, w)
+
+
+class TheBigSwitch(SwitchModelBase):
+    def __init__(self, in_nc, nf, xforms=16, upscale=2, init_temperature=10):
+        super(TheBigSwitch, self).__init__(init_temperature, 10000)
+        self.nf = nf
+        self.transformation_counts = xforms
+
+        self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False)
+
+        multiplx_fn = functools.partial(BigMultiplexer, in_nc, nf)
+        transform_fn = functools.partial(MultiConvBlock, nf, int(nf * 1.5), nf, kernel_size=3, depth=4, weight_init_factor=.1)
+        self.switch = ConfigurableSwitchComputer(nf, multiplx_fn,
+                                                 pre_transform_block=None, transform_block=transform_fn,
+                                                 attention_norm=True,
+                                                 transform_count=self.transformation_counts, init_temp=init_temperature,
+                                                 add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
+        self.switches = [self.switch]
+
+        self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
+        self.upsample = UpconvBlock(nf, nf // 2, block=ConvGnLelu, norm=False, activation=True, bias=True)
+        self.final_hr_conv1 = ConvGnLelu(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True)
+        self.final_hr_conv2 = ConvGnLelu(nf // 2, 3, kernel_size=3, norm=False, activation=False, bias=False)
+
+    def forward(self, x, save_attentions=True):
+        # The attention_maps debugger outputs <x>. Save that here.
+        self.lr = x.detach().cpu()
+
+        # If we're not saving attention, we also shouldn't be updating the attention norm. This is because the attention
+        # norm should only be getting updates with new data, not recurrent generator sampling.
+        for sw in self.switches:
+            sw.set_update_attention_norm(save_attentions)
+
+        x1 = self.model_fea_conv(x)
+        x1, a1 = self.switch(x1, att_in=x, do_checkpointing=True)
+        x_out = checkpoint(self.final_lr_conv, x1)
+        x_out = checkpoint(self.upsample, x_out)
+        x_out = checkpoint(self.final_hr_conv2, x_out)
+
+        if save_attentions:
+            self.attentions = [a1]
+        return x_out,
+
+
 if __name__ == '__main__':
     bb = BackboneEncoder(64)
     emb = QueryKeyMultiplexer(64, 10)
+    tbs = TheBigSwitch(3, 64)
     x = torch.randn(4,3,64,64)
     r = torch.randn(4,3,128,128)
     xu = torch.randn(4,64,64,64)
     cp = torch.zeros((4,2), dtype=torch.long)

     trans = [torch.randn(4,64,64,64) for t in range(10)]

     b = bb(x, r, cp)
     emb(xu, b, trans)
+    b = tbs(x)

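BigMultiplexer above produces one attention logit per transform per pixel: the SpineNet-derived query is tiled across the transform axis and combined with each transform's output (the keys). Stripped of the real conv stacks, the shape contract looks roughly like this standalone sketch; the sizes are toy values and a single 1x1 conv stands in for query_key_combine/cbl0-2:

import torch
import torch.nn as nn

b, t, f, h, w = 2, 16, 64, 32, 32             # batch, transforms, filters, height, width
q = torch.randn(b, f, h, w)                   # query built from the SpineNet trunk + tail
transformations = torch.randn(b, t, f, h, w)  # stacked transform outputs from the switch

k = transformations.view(b * t, f, h, w)                              # one key map per transform
q = q.view(b, 1, f, h, w).repeat(1, t, 1, 1, 1).view(b * t, f, h, w)  # tile query to match

combine = nn.Conv2d(2 * f, 1, kernel_size=1)  # placeholder for the real postprocessing convs
v = combine(torch.cat([q, k], dim=1))         # (b*t, 1, h, w)
logits = v.view(b, t, h, w)                   # per-transform attention logits
print(logits.shape)                           # torch.Size([2, 16, 32, 32])
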
@@ -479,7 +479,7 @@ class ReferenceJoinBlock(nn.Module):
 class UpconvBlock(nn.Module):
     def __init__(self, filters_in, filters_out=None, block=ConvGnSilu, norm=True, activation=True, bias=False):
         super(UpconvBlock, self).__init__()
-        self.process = block(filters_out, filters_out, kernel_size=3, bias=bias, activation=activation, norm=norm)
+        self.process = block(filters_in, filters_out, kernel_size=3, bias=bias, activation=activation, norm=norm)

     def forward(self, x):
         x = F.interpolate(x, scale_factor=2, mode="nearest")

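The one-line UpconvBlock change above is a bug fix: the inner conv previously consumed filters_out channels, so the block broke whenever filters_in differed from filters_out (as in TheBigSwitch's nf -> nf // 2 upsample). A rough standalone sketch of the intended shape contract, not the repo's exact block:

import torch
import torch.nn as nn
import torch.nn.functional as F

class UpconvSketch(nn.Module):
    """Nearest-neighbour 2x upsample followed by a 3x3 conv, mirroring UpconvBlock's shape contract."""
    def __init__(self, filters_in, filters_out):
        super().__init__()
        # The fix: the conv must consume filters_in channels, not filters_out.
        self.process = nn.Conv2d(filters_in, filters_out, kernel_size=3, padding=1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.process(x)

up = UpconvSketch(64, 32)
print(up(torch.randn(1, 64, 16, 16)).shape)  # torch.Size([1, 32, 32, 32])
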
@@ -6,42 +6,7 @@ import torch.nn.functional as F
 from torch.nn.init import kaiming_normal

 from torchvision.models.resnet import BasicBlock, Bottleneck
 from torch.nn.modules.batchnorm import _BatchNorm
-
-
-''' Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard
-    kernel sizes. '''
-class ConvBnRelu(nn.Module):
-    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, relu=True, bn=True, bias=True):
-        super(ConvBnRelu, self).__init__()
-        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
-        assert kernel_size in padding_map.keys()
-        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
-        if bn:
-            self.bn = nn.BatchNorm2d(filters_out)
-        else:
-            self.bn = None
-        if relu:
-            self.relu = nn.ReLU()
-        else:
-            self.relu = None
-
-        # Init params.
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.relu else 'linear')
-            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
-                nn.init.constant_(m.weight, 1)
-                nn.init.constant_(m.bias, 0)
-
-    def forward(self, x):
-        x = self.conv(x)
-        if self.bn:
-            x = self.bn(x)
-        if self.relu:
-            return self.relu(x)
-        else:
-            return x
+from models.archs.arch_util import ConvGnSilu


 def constant_init(module, val, bias=0):

@@ -194,10 +159,10 @@ class Resample(nn.Module):
         new_in_channels = int(in_channels * alpha)
         if block_type == Bottleneck:
             in_channels *= 4
-        self.squeeze_conv = ConvBnRelu(in_channels, new_in_channels, kernel_size=1)
+        self.squeeze_conv = ConvGnSilu(in_channels, new_in_channels, kernel_size=1)
         if scale < 1:
-            self.downsample_conv = ConvBnRelu(new_in_channels, new_in_channels, kernel_size=3, stride=2)
-        self.expand_conv = ConvBnRelu(new_in_channels, out_channels, kernel_size=1, relu=False)
+            self.downsample_conv = ConvGnSilu(new_in_channels, new_in_channels, kernel_size=3, stride=2)
+        self.expand_conv = ConvGnSilu(new_in_channels, out_channels, kernel_size=1, activation=False)

     def _resize(self, x):
         if self.scale == 1:

@@ -277,14 +242,14 @@ class SpineNet(nn.Module):
         """Build the stem network."""
         # Build the first conv and maxpooling layers.
         if self._early_double_reduce:
-            self.conv1 = ConvBnRelu(
+            self.conv1 = ConvGnSilu(
                 in_channels,
                 64,
                 kernel_size=7,
                 stride=2)
             self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
         else:
-            self.conv1 = ConvBnRelu(
+            self.conv1 = ConvGnSilu(
                 in_channels,
                 64,
                 kernel_size=7,

@@ -308,10 +273,10 @@ class SpineNet(nn.Module):
         for block_spec in self._block_specs:
             if block_spec.is_output:
                 in_channels = int(FILTER_SIZE_MAP[block_spec.level]*self._filter_size_scale) * 4
-                self.endpoint_convs[str(block_spec.level)] = ConvBnRelu(in_channels,
+                self.endpoint_convs[str(block_spec.level)] = ConvGnSilu(in_channels,
                                                                         self._endpoints_num_filters,
                                                                         kernel_size=1,
-                                                                        relu=False)
+                                                                        activation=False)

     def _make_scale_permuted_network(self):
         self.merge_ops = nn.ModuleList()

@@ -109,6 +109,8 @@ def define_G(opt, net_key='network_G', scale=None):
                                          init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
     elif which_model == 'ssg_teco':
         netG = ssg.StackedSwitchGenerator2xTeco(nf=opt_net['nf'], xforms=opt_net['num_transforms'], init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
+    elif which_model == 'big_switch':
+        netG = SwitchedGen_arch.TheBigSwitch(opt_net['in_nc'], opt_net['nf'], opt_net['num_transforms'], opt_net['scale'], opt_net['temperature'])
     elif which_model == "flownet2":
         from models.flownet2.models import FlowNet2
         ld = torch.load(opt_net['load_path'])

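The new define_G branch reads exactly five keys from the generator section of the options file. A hedged sketch of the equivalent construction with placeholder values, assuming the repo's modules are importable; the real settings would live in a YAML options file such as the train_exd_imgset_bigswitch.yml referenced below:

# Keys consumed by the 'big_switch' branch; the values here are placeholders, not defaults.
opt_net = {
    'which_model_G': 'big_switch',
    'in_nc': 3,             # input image channels
    'nf': 64,               # base filter count
    'num_transforms': 16,   # xforms handed to TheBigSwitch
    'scale': 2,             # upscale factor
    'temperature': 10,      # initial switch temperature
}

import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch  # module path assumed
netG = SwitchedGen_arch.TheBigSwitch(opt_net['in_nc'], opt_net['nf'], opt_net['num_transforms'],
                                     opt_net['scale'], opt_net['temperature'])
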
@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_pretrain_ssgteco.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_bigswitch.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()

@@ -124,7 +124,7 @@ def main():

     torch.backends.cudnn.benchmark = True
     # torch.backends.cudnn.deterministic = True
-    # torch.autograd.set_detect_anomaly(True)
+    #torch.autograd.set_detect_anomaly(True)

     # Save the compiled opt dict to the global loaded_options variable.
     util.loaded_options = opt