From fc623d4b5a2071a238565bbee0ae7b64159df9ca Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Fri, 23 Apr 2021 17:16:46 -0600
Subject: [PATCH] Add segformer model. Start work on BYOL adaptation that will
 support training it.

---
 .../models/byol/byol_for_semantic_chaining.py | 327 ++++++++++++++++++
 codes/models/segformer/backbone.py            | 132 +++++++
 codes/models/segformer/segformer.py           |  89 +++++
 3 files changed, 548 insertions(+)
 create mode 100644 codes/models/byol/byol_for_semantic_chaining.py
 create mode 100644 codes/models/segformer/backbone.py
 create mode 100644 codes/models/segformer/segformer.py

diff --git a/codes/models/byol/byol_for_semantic_chaining.py b/codes/models/byol/byol_for_semantic_chaining.py
new file mode 100644
index 00000000..b92cbc28
--- /dev/null
+++ b/codes/models/byol/byol_for_semantic_chaining.py
@@ -0,0 +1,327 @@
+import copy
+import os
+import random
+from functools import wraps
+import kornia.augmentation as augs
+
+import torch
+import torch.nn.functional as F
+import torchvision
+from PIL import Image
+from kornia import filters, apply_hflip
+from torch import nn
+from torchvision.transforms import ToTensor
+
+from data.byol_attachment import RandomApply
+from trainer.networks import register_model, create_model
+from utils.util import checkpoint, opt_get
+
+
+def default(val, def_val):
+    return def_val if val is None else val
+
+
+def flatten(t):
+    return t.reshape(t.shape[0], -1)
+
+
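+# Lazily constructs a module and caches it as an instance attribute (e.g. the projector), so its
+# dimensions can be inferred from the first activation it sees and it is only ever built once.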
+def singleton(cache_key):
+    def inner_fn(fn):
+        @wraps(fn)
+        def wrapper(self, *args, **kwargs):
+            instance = getattr(self, cache_key)
+            if instance is not None:
+                return instance
+
+            instance = fn(self, *args, **kwargs)
+            setattr(self, cache_key, instance)
+            return instance
+
+        return wrapper
+
+    return inner_fn
+
+
+def get_module_device(module):
+    return next(module.parameters()).device
+
+
+def set_requires_grad(model, val):
+    for p in model.parameters():
+        p.requires_grad = val
+
+
+# loss fn
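+# For L2-normalized x and y this equals 2 - 2 * cosine_similarity(x, y): the BYOL regression objective.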
+def loss_fn(x, y):
+    x = F.normalize(x, dim=-1, p=2)
+    y = F.normalize(y, dim=-1, p=2)
+    return 2 - 2 * (x * y).sum(dim=-1)
+
+
+# exponential moving average
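+# The target encoder tracks the online encoder as old * beta + (1 - beta) * new at each update step.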
+class EMA():
+    def __init__(self, beta):
+        super().__init__()
+        self.beta = beta
+
+    def update_average(self, old, new):
+        if old is None:
+            return new
+        return old * self.beta + (1 - self.beta) * new
+
+
+def update_moving_average(ema_updater, ma_model, current_model):
+    for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
+        old_weight, up_weight = ma_params.data, current_params.data
+        ma_params.data = ema_updater.update_average(old_weight, up_weight)
+
+
+# MLP class for projector and predictor
+class MLP(nn.Module):
+    def __init__(self, dim, projection_size, hidden_size=4096):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(dim, hidden_size),
+            nn.BatchNorm1d(hidden_size),
+            nn.ReLU(inplace=True),
+            nn.Linear(hidden_size, projection_size)
+        )
+
+    def forward(self, x):
+        x = flatten(x)
+        return self.net(x)
+
+
+# A wrapper class for training against networks that do not collapse into a small-dimensioned latent.
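+# Rather than flattening the full feature map, it first downsamples it 4x spatially with two strided convolutions.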
+class StructuralMLP(nn.Module):
+    def __init__(self, dim, projection_size, hidden_size=4096):
+        super().__init__()
+        b, c, h, w = dim
+        flattened_dim = c * (h // 4) * (w // 4)
+        self.net = nn.Sequential(
+            nn.Conv2d(c, c, kernel_size=3, padding=1, stride=2),
+            nn.BatchNorm2d(c),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(c, c, kernel_size=3, padding=1, stride=2),
+            nn.BatchNorm2d(c),
+            nn.ReLU(inplace=True),
+            nn.Flatten(),
+            nn.Linear(flattened_dim, hidden_size),
+            nn.BatchNorm1d(hidden_size),
+            nn.ReLU(inplace=True),
+            nn.Linear(hidden_size, projection_size)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+# a wrapper class for the base neural network
+# will manage the interception of the hidden layer output
+# and pipe it into the projector and predictor nets
+class NetWrapper(nn.Module):
+    def __init__(self, net, projection_size, projection_hidden_size, layer=-2, use_structural_mlp=False):
+        super().__init__()
+        self.net = net
+        self.layer = layer
+
+        self.projector = None
+        self.projection_size = projection_size
+        self.projection_hidden_size = projection_hidden_size
+        self.structural_mlp = use_structural_mlp
+
+        self.hidden = None
+        self.hook_registered = False
+
+    def _find_layer(self):
+        if type(self.layer) == str:
+            modules = dict([*self.net.named_modules()])
+            return modules.get(self.layer, None)
+        elif type(self.layer) == int:
+            children = [*self.net.children()]
+            return children[self.layer]
+        return None
+
+    def _hook(self, _, __, output):
+        self.hidden = output
+
+    def _register_hook(self):
+        layer = self._find_layer()
+        assert layer is not None, f'hidden layer ({self.layer}) not found'
+        handle = layer.register_forward_hook(self._hook)
+        self.hook_registered = True
+
+    @singleton('projector')
+    def _get_projector(self, hidden):
+        if self.structural_mlp:
+            projector = StructuralMLP(hidden.shape, self.projection_size, self.projection_hidden_size)
+        else:
+            _, dim = hidden.flatten(1,-1).shape
+            projector = MLP(dim, self.projection_size, self.projection_hidden_size)
+        return projector.to(hidden)
+
+    def get_representation(self, x):
+        if self.layer == -1:
+            return self.net(x)
+
+        if not self.hook_registered:
+            self._register_hook()
+
+        unused = self.net(x)
+        hidden = self.hidden
+        self.hidden = None
+        assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
+        return hidden
+
+    def forward(self, x):
+        representation = self.get_representation(x)
+        projector = self._get_projector(representation)
+        projection = checkpoint(projector, representation)
+        return projection
+
+
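+# Standard BYOL wrapper: an online encoder + predictor are trained to regress the projections
+# produced by an EMA "target" copy of the encoder on a second augmented view of the same image.
+# The loss is symmetrized across the two views (see forward()).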
+class BYOL(nn.Module):
+    def __init__(
+            self,
+            net,
+            image_size,
+            hidden_layer=-2,
+            projection_size=256,
+            projection_hidden_size=4096,
+            moving_average_decay=0.99,
+            use_momentum=True,
+            structural_mlp=False,
+            do_augmentation=False  # In DLAS, augmentation was intended to be done at the dataset level. For massive
+                                   # batch sizes this can overwhelm the CPU, though, and it becomes desirable to do
+                                   # the augmentations on the GPU instead.
+    ):
+        super().__init__()
+
+        self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer,
+                                         use_structural_mlp=structural_mlp)
+
+        self.do_aug = do_augmentation
+        if self.do_aug:
+            augmentations = [
+                RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
+                augs.RandomGrayscale(p=0.2),
+                augs.RandomHorizontalFlip(),
+                RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
+                # self.cropped_img_size was never defined; crop to the configured image_size instead.
+                augs.RandomResizedCrop((image_size, image_size))]
+            self.aug = nn.Sequential(*augmentations)
+        self.use_momentum = use_momentum
+        self.target_encoder = None
+        self.target_ema_updater = EMA(moving_average_decay)
+
+        self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)
+
+        # get device of network and make wrapper same device
+        device = get_module_device(net)
+        self.to(device)
+
+        # send a mock image tensor to instantiate singleton parameters
+        self.forward(torch.randn(2, 3, image_size, image_size, device=device),
+                     torch.randn(2, 3, image_size, image_size, device=device))
+
+    @singleton('target_encoder')
+    def _get_target_encoder(self):
+        target_encoder = copy.deepcopy(self.online_encoder)
+        set_requires_grad(target_encoder, False)
+        for p in target_encoder.parameters():
+            p.DO_NOT_TRAIN = True
+        return target_encoder
+
+    def reset_moving_average(self):
+        del self.target_encoder
+        self.target_encoder = None
+
+    def update_for_step(self, step, __):
+        assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder'
+        assert self.target_encoder is not None, 'target encoder has not been created yet'
+        update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)
+
+    def get_debug_values(self, step, __):
+        # In the BYOL paper, this is made to increase over time. Not yet implemented, but still logging the value.
+        return {'target_ema_beta': self.target_ema_updater.beta}
+
+    def visual_dbg(self, step, path):
+        if self.do_aug:
+            torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,)))
+            torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))
+
+    def forward(self, image_one, image_two):
+        if self.do_aug:
+            image_one = self.aug(image_one)
+            image_two = self.aug(image_two)
+            # Keep copies on hand for visual_dbg.
+            self.im1 = image_one.detach().clone()
+            self.im2 = image_two.detach().clone()
+
+        online_proj_one = self.online_encoder(image_one)
+        online_proj_two = self.online_encoder(image_two)
+
+        online_pred_one = self.online_predictor(online_proj_one)
+        online_pred_two = self.online_predictor(online_proj_two)
+
+        with torch.no_grad():
+            target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder
+            target_proj_one = target_encoder(image_one).detach()
+            target_proj_two = target_encoder(image_two).detach()
+
+        loss_one = loss_fn(online_pred_one, target_proj_two.detach())
+        loss_two = loss_fn(online_pred_two, target_proj_one.detach())
+
+        loss = loss_one + loss_two
+        return loss.mean()
+
+
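+# Work-in-progress augmentor for the "semantic chaining" variant: applies the usual BYOL
+# augmentations to an image while tracking where a given (t, l) point ends up, by inverting the
+# random resized crop and horizontal flip parameters.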
+class PointwiseAugmentor(nn.Module):
+    def __init__(self, img_size=224):
+        super().__init__()
+        self.jitter = RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8)
+        self.gray = augs.RandomGrayscale(p=0.2)
+        self.blur = RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)
+        self.rrc = augs.RandomResizedCrop((img_size, img_size))
+
+    # Maps a point in the *destination* (cropped) image back to the corresponding point in the source image, using the kornia RRC params.
+    def reverse_rrc(self, dest_point, params):
+        dh, dw = params['dst'][:,2,1]-params['dst'][:,0,1], params['dst'][:,2,0] - params['dst'][:,0,0]
+        sh, sw = params['src'][:,2,1]-params['src'][:,0,1], params['src'][:,2,0] - params['src'][:,0,0]
+        scale_h, scale_w = sh.float() / dh.float(), sw.float() / dw.float()
+        t, l = dest_point
+        t = (t.float() * scale_h).int()
+        l = (l.float() * scale_w).int()
+        return t + params['src'][:,0,1], l + params['src'][:,0,0]
+
+    def reverse_horizontal_flip(self, pt, input):
+        t, l = pt
+        center = input.shape[-1] // 2
+        return t, 2 * center - l
+
+    def forward(self, x, points):
+        d = self.jitter(x)
+        d = self.gray(d)
+        will_flip = random.random() > .5
+        if will_flip:
+            d = apply_hflip(d)
+        d = self.blur(d)
+        params = self.rrc.generate_parameters(d.shape)
+        d = self.rrc(d, params=params)
+
+        rev = self.reverse_rrc(points, params)
+        if will_flip:
+            rev = self.reverse_horizontal_flip(rev, x)
+        # Presumed intended output (still work-in-progress): the augmented image and the tracked point within it.
+        return d, rev
+
+if __name__ == '__main__':
+    p = PointwiseAugmentor(256)
+    t = ToTensor()(Image.open('E:\\4k6k\\datasets\\ns_images\\imagesets\\000001_152761.jpg')).unsqueeze(0).repeat(8,1,1,1)
+    points = (torch.randint(0,224,(t.shape[0],)),torch.randint(0,224,(t.shape[0],)))
+    p(t, points)
+
+
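+# Hypothetical example of the options this registration consumes (key names taken from the lookups below):
+#   subnet: <opts for the wrapped network>, image_size: 224, hidden_layer: -2,
+#   use_structural_mlp: false, gpu_augmentation: false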
+@register_model
+def register_byol(opt_net, opt):
+    subnet = create_model(opt, opt_net['subnet'])
+    return BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
+                structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False),
+                do_augmentation=opt_get(opt_net, ['gpu_augmentation'], False))
\ No newline at end of file
diff --git a/codes/models/segformer/backbone.py b/codes/models/segformer/backbone.py
new file mode 100644
index 00000000..0c2a06f1
--- /dev/null
+++ b/codes/models/segformer/backbone.py
@@ -0,0 +1,132 @@
+# Adapted from torchvision's resnet.py, modified to support gradient checkpointing and to return
+# intermediate layer outputs instead of classification logits.
+
+import torch
+import torch.nn as nn
+from torchvision.models.resnet import BasicBlock, Bottleneck
+from torchvision.models.utils import load_state_dict_from_url
+import torchvision
+
+
+__all__ = ['Backbone', 'backbone18', 'backbone34', 'backbone50', 'backbone101',
+           'backbone152']
+
+from trainer.networks import register_model
+from utils.util import checkpoint
+
+model_urls = {
+    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
+    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
+    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
+}
+
+
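+# torchvision's ResNet with the avgpool/fc head removed; returns the outputs of all four residual
+# stages, each computed under gradient checkpointing.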
+class Backbone(torchvision.models.resnet.ResNet):
+
+    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
+                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
+                 norm_layer=None):
+        super().__init__(block, layers, num_classes, zero_init_residual, groups, width_per_group,
+                         replace_stride_with_dilation, norm_layer)
+        del self.fc
+        del self.avgpool
+
+    def _forward_impl(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        l1 = checkpoint(self.layer1, x)
+        l2 = checkpoint(self.layer2, l1)
+        l3 = checkpoint(self.layer3, l2)
+        l4 = checkpoint(self.layer4, l3)
+
+        return l1, l2, l3, l4
+
+    def forward(self, x):
+        return self._forward_impl(x)
+
+
+def _backbone(arch, block, layers, pretrained, progress, **kwargs):
+    model = Backbone(block, layers, **kwargs)
+    if pretrained:
+        state_dict = load_state_dict_from_url(model_urls[arch],
+                                              progress=progress)
+        # The fc layer was deleted from Backbone, so its weights in the checkpoint must be ignored.
+        model.load_state_dict(state_dict, strict=False)
+    return model
+
+
+def backbone18(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-18 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _backbone('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
+                   **kwargs)
+
+
+def backbone34(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-34 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _backbone('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
+                   **kwargs)
+
+
+def backbone50(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-50 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _backbone('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
+                   **kwargs)
+
+
+def backbone101(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-101 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _backbone('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
+                   **kwargs)
+
+
+def backbone152(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-152 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _backbone('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
+                   **kwargs)
+
+
+@register_model
+def register_resnet50(opt_net, opt):
+    # `resnet50` is not defined in this module; the stock torchvision model appears to be the intent,
+    # since the custom head replaces the 2048-dim fc layer that Backbone removes.
+    model = torchvision.models.resnet50(pretrained=opt_net['pretrained'])
+    if opt_net['custom_head_logits']:
+        model.fc = nn.Linear(512 * 4, opt_net['custom_head_logits'])
+    return model
+
diff --git a/codes/models/segformer/segformer.py b/codes/models/segformer/segformer.py
new file mode 100644
index 00000000..b5a061c9
--- /dev/null
+++ b/codes/models/segformer/segformer.py
@@ -0,0 +1,89 @@
+import math
+
+import torch
+import torch.nn as nn
+from tqdm import tqdm
+
+from models.segformer.backbone import backbone50
+
+
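+# Convolves a backbone feature map (optionally with a second, dilated conv to widen the receptive
+# field), then samples a single (i, j) location and projects it to output_channels with a linear layer.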
+class DilatorModule(nn.Module):
+    def __init__(self, input_channels, output_channels, max_dilation):
+        super().__init__()
+        self.max_dilation = max_dilation
+        self.conv1 = nn.Conv2d(input_channels, input_channels, kernel_size=3, padding=1, dilation=1, bias=True)
+        if max_dilation > 1:
+            self.bn = nn.BatchNorm2d(input_channels)
+            self.relu = nn.ReLU()
+            self.conv2 = nn.Conv2d(input_channels, input_channels, kernel_size=3, padding=1, dilation=max_dilation, bias=True)
+        self.dense = nn.Linear(input_channels, output_channels, bias=True)
+
+    def forward(self, inp, loc):
+        x = self.conv1(inp)
+        if self.max_dilation > 1:
+            x = self.bn(self.relu(x))
+            x = self.conv2(x)
+
+        # This could likely be made (possibly substantially) more efficient by only computing these convolutions in a neighborhood of the sampled location.
+        i, j = loc
+        x = x[:,:,i,j]
+        return self.dense(x)
+
+
+# Grabbed from torch examples: https://github.com/pytorch/examples/blob/master/word_language_model/model.py
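+# Standard sinusoidal encoding, added over the token (sequence) dimension in forward().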
+class PositionalEncoding(nn.Module):
+    def __init__(self, d_model, max_len=5000):
+        super(PositionalEncoding, self).__init__()
+
+        pe = torch.zeros(max_len, d_model)
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0).transpose(0, 1)
+        self.register_buffer('pe', pe)
+
+    def forward(self, x):
+        x = x + self.pe[:x.size(0), :]
+        return x
+
+
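+# Sketch of the segformer model: multi-scale ResNet-50 features are converted into a small set of
+# tokens for a queried image position via DilatorModules at several dilation rates, positionally
+# encoded, and processed by a 16-layer transformer encoder.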
+class Segformer(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.backbone = backbone50()
+        backbone_channels = [256, 512, 1024, 2048]
+        dilations = [[1,2,3,4],[1,2,3],[1,2],[1]]
+        final_latent_channels = 2048
+        dilators = []
+        for ic, dis in zip(backbone_channels, dilations):
+            layer_dilators = []
+            for di in dis:
+                layer_dilators.append(DilatorModule(ic, final_latent_channels, di))
+            dilators.append(nn.ModuleList(layer_dilators))
+        self.dilators = nn.ModuleList(dilators)
+
+        self.token_position_encoder = PositionalEncoding(final_latent_channels, max_len=10)
+        self.transformer_layers = nn.Sequential(*[nn.TransformerEncoderLayer(final_latent_channels, nhead=4) for _ in range(16)])
+
+    def forward(self, x, pos):
+        layers = self.backbone(x)
+        tokens = []  # avoid shadowing the builtin `set`
+        i, j = pos[0] // 4, pos[1] // 4
+        for layer_out, dilator in zip(layers, self.dilators):
+            for subdilator in dilator:
+                tokens.append(subdilator(layer_out, (i, j)))
+            i, j = i // 2, j // 2
+
+        # The torch transformer expects the sequence (set) dimension to be 0.
+        tokens = torch.stack(tokens, dim=0)
+        tokens = self.token_position_encoder(tokens)
+        tokens = self.transformer_layers(tokens)
+        return tokens
+
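+# Simple smoke/throughput test on random inputs.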
+
+if __name__ == '__main__':
+    model = Segformer().to('cuda')
+    for j in tqdm(range(1000)):
+        test_tensor = torch.randn(64,3,224,224).cuda()
+        model(test_tensor, (43, 73))
\ No newline at end of file