diff --git a/codes/models/archs/StructuredSwitchedGenerator.py b/codes/models/archs/StructuredSwitchedGenerator.py
index 6df031b6..1a693792 100644
--- a/codes/models/archs/StructuredSwitchedGenerator.py
+++ b/codes/models/archs/StructuredSwitchedGenerator.py
@@ -147,9 +147,9 @@ class SwitchWithReference(nn.Module):
     def forward(self, x, mplex_ref=None, ref=None):
         if self.ref_join is not None:
             branch, ref_std = self.ref_join(x, ref)
-            return self.switch(branch, True, identity=x, att_in=(branch, mplex_ref)) + (ref_std,)
+            return self.switch(branch, identity=x, att_in=(branch, mplex_ref)) + (ref_std,)
         else:
-            return self.switch(x, True, identity=x, att_in=(x, mplex_ref))
+            return self.switch(x, identity=x, att_in=(x, mplex_ref))


 class SSGr1(SwitchModelBase):
diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index b79fd55f..feb6168d 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -4,11 +4,13 @@ from switched_conv.switched_conv import BareConvSwitch, compute_attention_specif
 import torch.nn.functional as F
 import functools
 from collections import OrderedDict
-from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu, MultiConvBlock, SiLU
+from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu, MultiConvBlock, \
+    SiLU, UpconvBlock, ReferenceJoinBlock
 from switched_conv.switched_conv_util import save_attention_to_image_rgb
 import os
 from models.archs.spinenet_arch import SpineNet
 import torchvision
+from utils.util import checkpoint

 # VGG-style layer with Conv(stride2)->BN->Activation->Conv->BN->Activation
 # Doubles the input filter count.
@@ -106,7 +108,7 @@ class ConfigurableSwitchComputer(nn.Module):

     # Regarding inputs: it is acceptable to pass in a tuple/list as an input for (x), but the first element
     # *must* be the actual parameter that gets fed through the network - it is assumed to be the identity.
-    def forward(self, x, output_attention_weights=False, identity=None, att_in=None, fixed_scale=1):
+    def forward(self, x, att_in=None, identity=None, output_attention_weights=True, fixed_scale=1, do_checkpointing=False):
         if isinstance(x, tuple):
             x1 = x[0]
         else:
@@ -131,13 +133,19 @@ class ConfigurableSwitchComputer(nn.Module):
             x = self.pre_transform(*x)
             if not isinstance(x, tuple):
                 x = (x,)
-        xformed = [t(*x) for t in self.transforms]
+        if do_checkpointing:
+            xformed = [checkpoint(t, *x) for t in self.transforms]
+        else:
+            xformed = [t(*x) for t in self.transforms]

         if not isinstance(att_in, tuple):
             att_in = (att_in,)
         if self.feed_transforms_into_multiplexer:
             att_in = att_in + (torch.stack(xformed, dim=1),)
-        m = self.multiplexer(*att_in)
+        if do_checkpointing:
+            m = checkpoint(self.multiplexer, *att_in)
+        else:
+            m = self.multiplexer(*att_in)

         # It is assumed that [xformed] and [m] are collapsed into tensors at this point.
         outputs, attention = self.switch(xformed, m, True, self.update_norm)
@@ -592,15 +600,90 @@ class SwitchModelBase(nn.Module):
         return val


+from models.archs.spinenet_arch import make_res_layer, BasicBlock
+class BigMultiplexer(nn.Module):
+    def __init__(self, in_nc, nf, multiplexer_channels):
+        super(BigMultiplexer, self).__init__()
+
+        self.spine = SpineNet(arch='96', output_level=[3], double_reduce_early=False)
+        self.spine_red_proc = ConvGnSilu(256, nf, kernel_size=1, activation=False, norm=False, bias=False)
+        self.fea_tail = ConvGnSilu(in_nc, nf, kernel_size=7, bias=True, norm=False, activation=False)
+        self.tail_proc = make_res_layer(BasicBlock, nf, nf, 2)
+        self.tail_join = ReferenceJoinBlock(nf)
+
+        # Blocks used to create the key
+        self.key_process = ConvGnSilu(nf, nf, kernel_size=1, activation=True, norm=False, bias=True)
+
+        # Postprocessing blocks.
+        self.query_key_combine = ConvGnSilu(nf*2, nf, kernel_size=3, activation=True, norm=False, bias=False)
+        self.cbl0 = ConvGnSilu(nf, nf, kernel_size=3, activation=True, norm=True, bias=False)
+        self.cbl1 = ConvGnSilu(nf, nf // 2, kernel_size=1, norm=True, bias=False, num_groups=4)
+        self.cbl2 = ConvGnSilu(nf // 2, 1, kernel_size=1, norm=False, bias=False)
+
+    def forward(self, x, transformations):
+        s = self.spine(x)[0]
+        tail = self.fea_tail(x)
+        tail = self.tail_proc(tail)
+        q = F.interpolate(s, scale_factor=2, mode='bilinear')
+        q = self.spine_red_proc(q)
+        q, _ = self.tail_join(q, tail)
+
+        b, t, f, h, w = transformations.shape
+        k = transformations.view(b * t, f, h, w)
+        k = self.key_process(k)
+
+        q = q.view(b, 1, f, h, w).repeat(1, t, 1, 1, 1).view(b * t, f, h, w)
+        v = self.query_key_combine(torch.cat([q, k], dim=1))
+        v = self.cbl0(v)
+        v = self.cbl1(v)
+        v = self.cbl2(v)
+
+        return v.view(b, t, h, w)
+
+
+class TheBigSwitch(SwitchModelBase):
+    def __init__(self, in_nc, nf, xforms=16, upscale=2, init_temperature=10):
+        super(TheBigSwitch, self).__init__(init_temperature, 10000)
+        self.nf = nf
+        self.transformation_counts = xforms
+
+        self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False)
+
+        multiplx_fn = functools.partial(BigMultiplexer, in_nc, nf)
+        transform_fn = functools.partial(MultiConvBlock, nf, int(nf * 1.5), nf, kernel_size=3, depth=4, weight_init_factor=.1)
+        self.switch = ConfigurableSwitchComputer(nf, multiplx_fn,
+                                                 pre_transform_block=None, transform_block=transform_fn,
+                                                 attention_norm=True,
+                                                 transform_count=self.transformation_counts, init_temp=init_temperature,
+                                                 add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
+        self.switches = [self.switch]
+
+        self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
+        self.upsample = UpconvBlock(nf, nf // 2, block=ConvGnLelu, norm=False, activation=True, bias=True)
+        self.final_hr_conv1 = ConvGnLelu(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True)
+        self.final_hr_conv2 = ConvGnLelu(nf // 2, 3, kernel_size=3, norm=False, activation=False, bias=False)
+
+    def forward(self, x, save_attentions=True):
+        # The attention_maps debugger outputs the raw input (x). Save that here.
+        self.lr = x.detach().cpu()
+
+        # If we're not saving attention, we also shouldn't be updating the attention norm. This is because the attention
+        # norm should only be getting updates with new data, not recurrent generator sampling.
+        for sw in self.switches:
+            sw.set_update_attention_norm(save_attentions)
+
+        x1 = self.model_fea_conv(x)
+        x1, a1 = self.switch(x1, att_in=x, do_checkpointing=True)
+        x_out = checkpoint(self.final_lr_conv, x1)
+        x_out = checkpoint(self.upsample, x_out)
+        x_out = checkpoint(self.final_hr_conv2, x_out)
+
+        if save_attentions:
+            self.attentions = [a1]
+        return x_out,
+
+
 if __name__ == '__main__':
-    bb = BackboneEncoder(64)
-    emb = QueryKeyMultiplexer(64, 10)
+    tbs = TheBigSwitch(3, 64)
     x = torch.randn(4,3,64,64)
-    r = torch.randn(4,3,128,128)
-    xu = torch.randn(4,64,64,64)
-    cp = torch.zeros((4,2), dtype=torch.long)
-
-    trans = [torch.randn(4,64,64,64) for t in range(10)]
-
-    b = bb(x, r, cp)
-    emb(xu, b, trans)
\ No newline at end of file
+    b = tbs(x)
\ No newline at end of file
diff --git a/codes/models/archs/arch_util.py b/codes/models/archs/arch_util.py
index 6fefa9f4..418d26bb 100644
--- a/codes/models/archs/arch_util.py
+++ b/codes/models/archs/arch_util.py
@@ -479,7 +479,7 @@ class ReferenceJoinBlock(nn.Module):
 class UpconvBlock(nn.Module):
     def __init__(self, filters_in, filters_out=None, block=ConvGnSilu, norm=True, activation=True, bias=False):
         super(UpconvBlock, self).__init__()
-        self.process = block(filters_out, filters_out, kernel_size=3, bias=bias, activation=activation, norm=norm)
+        self.process = block(filters_in, filters_out, kernel_size=3, bias=bias, activation=activation, norm=norm)

     def forward(self, x):
         x = F.interpolate(x, scale_factor=2, mode="nearest")
diff --git a/codes/models/archs/spinenet_arch.py b/codes/models/archs/spinenet_arch.py
index 1e5a192b..99f95fa7 100644
--- a/codes/models/archs/spinenet_arch.py
+++ b/codes/models/archs/spinenet_arch.py
@@ -6,42 +6,7 @@ import torch.nn.functional as F
 from torch.nn.init import kaiming_normal
 from torchvision.models.resnet import BasicBlock, Bottleneck
-from torch.nn.modules.batchnorm import _BatchNorm
-
-
-''' Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard
-    kernel sizes. '''
-class ConvBnRelu(nn.Module):
-    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, relu=True, bn=True, bias=True):
-        super(ConvBnRelu, self).__init__()
-        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
-        assert kernel_size in padding_map.keys()
-        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
-        if bn:
-            self.bn = nn.BatchNorm2d(filters_out)
-        else:
-            self.bn = None
-        if relu:
-            self.relu = nn.ReLU()
-        else:
-            self.relu = None
-
-        # Init params.
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.relu else 'linear')
-            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
-                nn.init.constant_(m.weight, 1)
-                nn.init.constant_(m.bias, 0)
-
-    def forward(self, x):
-        x = self.conv(x)
-        if self.bn:
-            x = self.bn(x)
-        if self.relu:
-            return self.relu(x)
-        else:
-            return x
+from models.archs.arch_util import ConvGnSilu


 def constant_init(module, val, bias=0):
@@ -194,10 +159,10 @@ class Resample(nn.Module):
         new_in_channels = int(in_channels * alpha)
         if block_type == Bottleneck:
             in_channels *= 4
-        self.squeeze_conv = ConvBnRelu(in_channels, new_in_channels, kernel_size=1)
+        self.squeeze_conv = ConvGnSilu(in_channels, new_in_channels, kernel_size=1)
         if scale < 1:
-            self.downsample_conv = ConvBnRelu(new_in_channels, new_in_channels, kernel_size=3, stride=2)
-        self.expand_conv = ConvBnRelu(new_in_channels, out_channels, kernel_size=1, relu=False)
+            self.downsample_conv = ConvGnSilu(new_in_channels, new_in_channels, kernel_size=3, stride=2)
+        self.expand_conv = ConvGnSilu(new_in_channels, out_channels, kernel_size=1, activation=False)

     def _resize(self, x):
         if self.scale == 1:
@@ -277,14 +242,14 @@ class SpineNet(nn.Module):
         """Build the stem network."""
         # Build the first conv and maxpooling layers.
         if self._early_double_reduce:
-            self.conv1 = ConvBnRelu(
+            self.conv1 = ConvGnSilu(
                 in_channels,
                 64,
                 kernel_size=7,
                 stride=2)
             self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
         else:
-            self.conv1 = ConvBnRelu(
+            self.conv1 = ConvGnSilu(
                 in_channels,
                 64,
                 kernel_size=7,
@@ -308,10 +273,10 @@ class SpineNet(nn.Module):
         for block_spec in self._block_specs:
             if block_spec.is_output:
                 in_channels = int(FILTER_SIZE_MAP[block_spec.level]*self._filter_size_scale) * 4
-                self.endpoint_convs[str(block_spec.level)] = ConvBnRelu(in_channels,
+                self.endpoint_convs[str(block_spec.level)] = ConvGnSilu(in_channels,
                                                                         self._endpoints_num_filters,
                                                                         kernel_size=1,
-                                                                        relu=False)
+                                                                        activation=False)

     def _make_scale_permuted_network(self):
         self.merge_ops = nn.ModuleList()
diff --git a/codes/models/networks.py b/codes/models/networks.py
index c3737c09..987ddb19 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -109,6 +109,8 @@ def define_G(opt, net_key='network_G', scale=None):
                                             init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
     elif which_model == 'ssg_teco':
         netG = ssg.StackedSwitchGenerator2xTeco(nf=opt_net['nf'], xforms=opt_net['num_transforms'], init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
+    elif which_model == 'big_switch':
+        netG = SwitchedGen_arch.TheBigSwitch(opt_net['in_nc'], opt_net['nf'], opt_net['num_transforms'], opt_net['scale'], opt_net['temperature'])
     elif which_model == "flownet2":
         from models.flownet2.models import FlowNet2
         ld = torch.load(opt_net['load_path'])
diff --git a/codes/train2.py b/codes/train2.py
index 404bdfd4..9a196009 100644
--- a/codes/train2.py
+++ b/codes/train2.py
@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_pretrain_ssgteco.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_bigswitch.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
@@ -124,7 +124,7 @@ def main():

     torch.backends.cudnn.benchmark = True
     # torch.backends.cudnn.deterministic = True
-    # torch.autograd.set_detect_anomaly(True)
+    #torch.autograd.set_detect_anomaly(True)

     # Save the compiled opt dict to the global loaded_options variable.
     util.loaded_options = opt
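
Note on the do_checkpointing path added to ConfigurableSwitchComputer.forward(): the patch imports checkpoint from utils.util, whose implementation is not part of this diff. The sketch below is a minimal, hypothetical stand-in that assumes that helper behaves like torch.utils.checkpoint.checkpoint: each transform branch and the multiplexer discard their intermediate activations in the forward pass and recompute them during backward, trading extra compute for a much smaller peak memory footprint. TinySwitchComputer and its layer sizes are invented purely for illustration.

# Hedged sketch of the checkpointing pattern used in the patch; not the repo's actual classes.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint  # assumed equivalent of utils.util.checkpoint

class TinySwitchComputer(nn.Module):
    """Minimal stand-in for ConfigurableSwitchComputer: N transforms plus a multiplexer."""
    def __init__(self, nf=16, transform_count=4):
        super().__init__()
        self.transforms = nn.ModuleList(
            [nn.Sequential(nn.Conv2d(nf, nf, 3, padding=1), nn.ReLU()) for _ in range(transform_count)])
        # The multiplexer sees the input plus the transform outputs, as in the patch.
        self.multiplexer = nn.Conv2d(nf * (transform_count + 1), transform_count, 1)

    def forward(self, x, do_checkpointing=False):
        if do_checkpointing:
            # Activations inside each transform are recomputed during backward.
            xformed = [checkpoint(t, x) for t in self.transforms]
        else:
            xformed = [t(x) for t in self.transforms]
        att_in = torch.cat([x] + xformed, dim=1)
        m = checkpoint(self.multiplexer, att_in) if do_checkpointing else self.multiplexer(att_in)
        weights = torch.softmax(m, dim=1)                    # (b, t, h, w)
        stacked = torch.stack(xformed, dim=1)                # (b, t, f, h, w)
        return (stacked * weights.unsqueeze(2)).sum(dim=1)   # attention-weighted sum of transforms

if __name__ == '__main__':
    net = TinySwitchComputer()
    inp = torch.randn(2, 16, 32, 32, requires_grad=True)
    out = net(inp, do_checkpointing=True)
    out.mean().backward()  # gradients still flow through the recomputed segments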
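Note on the arch_util.py one-liner: UpconvBlock interpolates spatially before convolving, so its conv must accept filters_in channels; building it with filters_out inputs only happened to work when the two matched, and TheBigSwitch's UpconvBlock(nf, nf // 2) exercises the mismatched case. A minimal sketch of the corrected shape flow, with a plain nn.Conv2d standing in for the configurable block argument:

# Hedged sketch of the UpconvBlock fix; Conv2d is a stand-in for the repo's block types.
import torch
import torch.nn as nn
import torch.nn.functional as F

class UpconvBlockSketch(nn.Module):
    def __init__(self, filters_in, filters_out):
        super().__init__()
        # Nearest-neighbor upsampling leaves the channel count at filters_in,
        # so the conv must map filters_in -> filters_out (the old code used filters_out -> filters_out).
        self.process = nn.Conv2d(filters_in, filters_out, kernel_size=3, padding=1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode="nearest")  # (b, filters_in, 2h, 2w)
        return self.process(x)                                # (b, filters_out, 2h, 2w)

if __name__ == '__main__':
    up = UpconvBlockSketch(64, 32)
    print(up(torch.randn(1, 64, 16, 16)).shape)  # torch.Size([1, 32, 32, 32])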