From fc623d4b5a2071a238565bbee0ae7b64159df9ca Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 23 Apr 2021 17:16:46 -0600
Subject: [PATCH] Add segformer model. Start work on BYOL adaptation that will support training it.

---
 .../models/byol/byol_for_semantic_chaining.py | 327 ++++++++++++++++++
 codes/models/segformer/backbone.py            | 132 +++++++
 codes/models/segformer/segformer.py           |  89 +++++
 3 files changed, 548 insertions(+)
 create mode 100644 codes/models/byol/byol_for_semantic_chaining.py
 create mode 100644 codes/models/segformer/backbone.py
 create mode 100644 codes/models/segformer/segformer.py

diff --git a/codes/models/byol/byol_for_semantic_chaining.py b/codes/models/byol/byol_for_semantic_chaining.py
new file mode 100644
index 00000000..b92cbc28
--- /dev/null
+++ b/codes/models/byol/byol_for_semantic_chaining.py
@@ -0,0 +1,327 @@
+import copy
+import os
+import random
+from functools import wraps
+import kornia.augmentation as augs
+
+import torch
+import torch.nn.functional as F
+import torchvision
+from PIL import Image
+from kornia import filters, apply_hflip
+from torch import nn
+from torchvision.transforms import ToTensor
+
+from data.byol_attachment import RandomApply
+from trainer.networks import register_model, create_model
+from utils.util import checkpoint, opt_get
+
+
+def default(val, def_val):
+    return def_val if val is None else val
+
+
+def flatten(t):
+    return t.reshape(t.shape[0], -1)
+
+
+def singleton(cache_key):
+    def inner_fn(fn):
+        @wraps(fn)
+        def wrapper(self, *args, **kwargs):
+            instance = getattr(self, cache_key)
+            if instance is not None:
+                return instance
+
+            instance = fn(self, *args, **kwargs)
+            setattr(self, cache_key, instance)
+            return instance
+
+        return wrapper
+
+    return inner_fn
+
+
+def get_module_device(module):
+    return next(module.parameters()).device
+
+
+def set_requires_grad(model, val):
+    for p in model.parameters():
+        p.requires_grad = val
+
+
+# loss fn
+def loss_fn(x, y):
+    x = F.normalize(x, dim=-1, p=2)
+    y = F.normalize(y, dim=-1, p=2)
+    return 2 - 2 * (x * y).sum(dim=-1)
+
+
+# exponential moving average
+class EMA():
+    def __init__(self, beta):
+        super().__init__()
+        self.beta = beta
+
+    def update_average(self, old, new):
+        if old is None:
+            return new
+        return old * self.beta + (1 - self.beta) * new
+
+
+def update_moving_average(ema_updater, ma_model, current_model):
+    for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
+        old_weight, up_weight = ma_params.data, current_params.data
+        ma_params.data = ema_updater.update_average(old_weight, up_weight)
+
+
+# MLP class for projector and predictor
+class MLP(nn.Module):
+    def __init__(self, dim, projection_size, hidden_size=4096):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(dim, hidden_size),
+            nn.BatchNorm1d(hidden_size),
+            nn.ReLU(inplace=True),
+            nn.Linear(hidden_size, projection_size)
+        )
+
+    def forward(self, x):
+        x = flatten(x)
+        return self.net(x)
+
+
+# A wrapper class for training against networks that do not collapse into a small-dimensioned latent.
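+# Rather than flattening the hooked feature map directly into a linear layer, the map is
+# first reduced by two stride-2 convolutions and only then projected, which keeps the first
+# Linear layer a manageable size for large spatial latents.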
+class StructuralMLP(nn.Module):
+    def __init__(self, dim, projection_size, hidden_size=4096):
+        super().__init__()
+        b, c, h, w = dim
+        # The two stride-2 convolutions below quarter the spatial dimensions; this assumes
+        # h and w are divisible by 4.
+        flattened_dim = c * (h // 4) * (w // 4)
+        self.net = nn.Sequential(
+            nn.Conv2d(c, c, kernel_size=3, padding=1, stride=2),
+            nn.BatchNorm2d(c),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(c, c, kernel_size=3, padding=1, stride=2),
+            nn.BatchNorm2d(c),
+            nn.ReLU(inplace=True),
+            nn.Flatten(),
+            nn.Linear(flattened_dim, hidden_size),
+            nn.BatchNorm1d(hidden_size),
+            nn.ReLU(inplace=True),
+            nn.Linear(hidden_size, projection_size)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+# a wrapper class for the base neural network
+# will manage the interception of the hidden layer output
+# and pipe it into the projector and predictor nets
+class NetWrapper(nn.Module):
+    def __init__(self, net, projection_size, projection_hidden_size, layer=-2, use_structural_mlp=False):
+        super().__init__()
+        self.net = net
+        self.layer = layer
+
+        self.projector = None
+        self.projection_size = projection_size
+        self.projection_hidden_size = projection_hidden_size
+        self.structural_mlp = use_structural_mlp
+
+        self.hidden = None
+        self.hook_registered = False
+
+    def _find_layer(self):
+        if type(self.layer) == str:
+            modules = dict([*self.net.named_modules()])
+            return modules.get(self.layer, None)
+        elif type(self.layer) == int:
+            children = [*self.net.children()]
+            return children[self.layer]
+        return None
+
+    def _hook(self, _, __, output):
+        self.hidden = output
+
+    def _register_hook(self):
+        layer = self._find_layer()
+        assert layer is not None, f'hidden layer ({self.layer}) not found'
+        layer.register_forward_hook(self._hook)
+        self.hook_registered = True
+
+    @singleton('projector')
+    def _get_projector(self, hidden):
+        if self.structural_mlp:
+            projector = StructuralMLP(hidden.shape, self.projection_size, self.projection_hidden_size)
+        else:
+            _, dim = hidden.flatten(1, -1).shape
+            projector = MLP(dim, self.projection_size, self.projection_hidden_size)
+        return projector.to(hidden)
+
+    def get_representation(self, x):
+        if self.layer == -1:
+            return self.net(x)
+
+        if not self.hook_registered:
+            self._register_hook()
+
+        unused = self.net(x)
+        hidden = self.hidden
+        self.hidden = None
+        assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
+        return hidden
+
+    def forward(self, x):
+        representation = self.get_representation(x)
+        projector = self._get_projector(representation)
+        projection = checkpoint(projector, representation)
+        return projection
+
+
+class BYOL(nn.Module):
+    def __init__(
+        self,
+        net,
+        image_size,
+        hidden_layer=-2,
+        projection_size=256,
+        projection_hidden_size=4096,
+        moving_average_decay=0.99,
+        use_momentum=True,
+        structural_mlp=False,
+        do_augmentation=False  # In DLAS this was intended to be done at the dataset level. For massive batch sizes
+                               # this can overwhelm the CPU though, and it becomes desirable to do the augmentations
+                               # on the GPU again.
+    ):
+        super().__init__()
+
+        self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer,
+                                         use_structural_mlp=structural_mlp)
+
+        self.do_aug = do_augmentation
+        if self.do_aug:
+            augmentations = [
+                RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
+                augs.RandomGrayscale(p=0.2),
+                augs.RandomHorizontalFlip(),
+                RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
+                augs.RandomResizedCrop((image_size, image_size))]  # crop both views back to the training resolution
+            self.aug = nn.Sequential(*augmentations)
+        self.use_momentum = use_momentum
+        self.target_encoder = None
+        self.target_ema_updater = EMA(moving_average_decay)
+
+        self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)
+
+        # get device of network and make wrapper same device
+        device = get_module_device(net)
+        self.to(device)
+
+        # send a mock image tensor to instantiate singleton parameters
+        self.forward(torch.randn(2, 3, image_size, image_size, device=device),
+                     torch.randn(2, 3, image_size, image_size, device=device))
+
+    @singleton('target_encoder')
+    def _get_target_encoder(self):
+        target_encoder = copy.deepcopy(self.online_encoder)
+        set_requires_grad(target_encoder, False)
+        for p in target_encoder.parameters():
+            p.DO_NOT_TRAIN = True
+        return target_encoder
+
+    def reset_moving_average(self):
+        del self.target_encoder
+        self.target_encoder = None
+
+    def update_for_step(self, step, __):
+        assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder'
+        assert self.target_encoder is not None, 'target encoder has not been created yet'
+        update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)
+
+    def get_debug_values(self, step, __):
+        # In the BYOL paper, this is made to increase over time. Not yet implemented, but still logging the value.
+        return {'target_ema_beta': self.target_ema_updater.beta}
+
+    def visual_dbg(self, step, path):
+        if self.do_aug:
+            torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,)))
+            torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))
+
+    def forward(self, image_one, image_two):
+        if self.do_aug:
+            image_one = self.aug(image_one)
+            image_two = self.aug(image_two)
+            # Keep copies on hand for visual_dbg.
+            self.im1 = image_one.detach().clone()
+            self.im2 = image_two.detach().clone()
+
+        online_proj_one = self.online_encoder(image_one)
+        online_proj_two = self.online_encoder(image_two)
+
+        online_pred_one = self.online_predictor(online_proj_one)
+        online_pred_two = self.online_predictor(online_proj_two)
+
+        with torch.no_grad():
+            target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder
+            target_proj_one = target_encoder(image_one).detach()
+            target_proj_two = target_encoder(image_two).detach()
+
+        loss_one = loss_fn(online_pred_one, target_proj_two.detach())
+        loss_two = loss_fn(online_pred_two, target_proj_one.detach())
+
+        loss = loss_one + loss_two
+        return loss.mean()
+
+
+class PointwiseAugmentor(nn.Module):
+    def __init__(self, img_size=224):
+        super().__init__()
+        self.jitter = RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8)
+        self.gray = augs.RandomGrayscale(p=0.2)
+        self.blur = RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)
+        self.rrc = augs.RandomResizedCrop((img_size, img_size))
+
+    # Given a point in the *destination* image, returns the same point in the source image, given the kornia RRC params.
+    def reverse_rrc(self, dest_point, params):
+        dh, dw = params['dst'][:, 2, 1] - params['dst'][:, 0, 1], params['dst'][:, 2, 0] - params['dst'][:, 0, 0]
+        sh, sw = params['src'][:, 2, 1] - params['src'][:, 0, 1], params['src'][:, 2, 0] - params['src'][:, 0, 0]
+        scale_h, scale_w = sh.float() / dh.float(), sw.float() / dw.float()
+        t, l = dest_point
+        t = (t.float() * scale_h).int()
+        l = (l.float() * scale_w).int()
+        return t + params['src'][:, 0, 1], l + params['src'][:, 0, 0]
+
+    def reverse_horizontal_flip(self, pt, input):
+        t, l = pt
+        center = input.shape[-1] // 2
+        return t, 2 * center - l
+
+    def forward(self, x, points):
+        d = self.jitter(x)
+        d = self.gray(d)
+        will_flip = random.random() > .5
+        if will_flip:
+            d = apply_hflip(d)
+        d = self.blur(d)
+        params = self.rrc.generate_parameters(d.shape)
+        d = self.rrc(d, params=params)
+
+        # Map the queried points in the augmented image back to their locations in the source image.
+        rev = self.reverse_rrc(points, params)
+        if will_flip:
+            rev = self.reverse_horizontal_flip(rev, x)
+        return d, rev
+
+
+if __name__ == '__main__':
+    p = PointwiseAugmentor(256)
+    t = ToTensor()(Image.open('E:\\4k6k\\datasets\\ns_images\\imagesets\\000001_152761.jpg')).unsqueeze(0).repeat(8, 1, 1, 1)
+    points = (torch.randint(0, 224, (t.shape[0],)), torch.randint(0, 224, (t.shape[0],)))
+    p(t, points)
+
+
+@register_model
+def register_byol(opt_net, opt):
+    subnet = create_model(opt, opt_net['subnet'])
+    return BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
+                structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False),
+                do_augmentation=opt_get(opt_net, ['gpu_augmentation'], False))
\ No newline at end of file
diff --git a/codes/models/segformer/backbone.py b/codes/models/segformer/backbone.py
new file mode 100644
index 00000000..0c2a06f1
--- /dev/null
+++ b/codes/models/segformer/backbone.py
@@ -0,0 +1,132 @@
+# A direct copy of torchvision's resnet.py modified to support gradient checkpointing.
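+# Relative to the torchvision version, the classification head (avgpool/fc) is removed and
+# forward() returns the outputs of all four residual stages, each wrapped in checkpoint()
+# so their activations are recomputed during the backward pass instead of being stored.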
+
+import torch
+import torch.nn as nn
+from torchvision.models.resnet import BasicBlock, Bottleneck
+from torchvision.models.utils import load_state_dict_from_url
+import torchvision
+
+
+__all__ = ['Backbone', 'backbone18', 'backbone34', 'backbone50', 'backbone101',
+           'backbone152']
+
+from trainer.networks import register_model
+from utils.util import checkpoint
+
+model_urls = {
+    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
+    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
+    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
+}
+
+
+class Backbone(torchvision.models.resnet.ResNet):
+
+    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
+                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
+                 norm_layer=None):
+        super().__init__(block, layers, num_classes, zero_init_residual, groups, width_per_group,
+                         replace_stride_with_dilation, norm_layer)
+        del self.fc
+        del self.avgpool
+
+    def _forward_impl(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        l1 = checkpoint(self.layer1, x)
+        l2 = checkpoint(self.layer2, l1)
+        l3 = checkpoint(self.layer3, l2)
+        l4 = checkpoint(self.layer4, l3)
+
+        return l1, l2, l3, l4
+
+    def forward(self, x):
+        return self._forward_impl(x)
+
+
+def _backbone(arch, block, layers, pretrained, progress, **kwargs):
+    model = Backbone(block, layers, **kwargs)
+    if pretrained:
+        state_dict = load_state_dict_from_url(model_urls[arch],
+                                              progress=progress)
+        # The pretrained checkpoint still contains fc.* weights, which have no destination now
+        # that the classification head is removed, so load non-strictly.
+        model.load_state_dict(state_dict, strict=False)
+    return model
+
+
+def backbone18(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-18 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _backbone('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
+                     **kwargs)
+
+
+def backbone34(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-34 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _backbone('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
+                     **kwargs)
+
+
+def backbone50(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-50 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _backbone('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
+                     **kwargs)
+
+
+def backbone101(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-101 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _backbone('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
+                     **kwargs)
+
+
+def backbone152(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-152 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _backbone('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
+                     **kwargs)
+
+
+@register_model
+def register_resnet50(opt_net, opt):
+    # The Backbone variants above have no fc head, so the stock torchvision resnet50 is used here.
+    model = torchvision.models.resnet50(pretrained=opt_net['pretrained'])
+    if opt_net['custom_head_logits']:
+        model.fc = nn.Linear(512 * 4, opt_net['custom_head_logits'])
+    return model
+
diff --git a/codes/models/segformer/segformer.py b/codes/models/segformer/segformer.py
new file mode 100644
index 00000000..b5a061c9
--- /dev/null
+++ b/codes/models/segformer/segformer.py
@@ -0,0 +1,89 @@
+import math
+
+import torch
+import torch.nn as nn
+from tqdm import tqdm
+
+from models.segformer.backbone import backbone50
+
+
+class DilatorModule(nn.Module):
+    def __init__(self, input_channels, output_channels, max_dilation):
+        super().__init__()
+        self.max_dilation = max_dilation
+        self.conv1 = nn.Conv2d(input_channels, input_channels, kernel_size=3, padding=1, dilation=1, bias=True)
+        if max_dilation > 1:
+            self.bn = nn.BatchNorm2d(input_channels)
+            self.relu = nn.ReLU()
+            # Padding matches the dilation so the spatial size is preserved and the (i, j) lookup below stays aligned.
+            self.conv2 = nn.Conv2d(input_channels, input_channels, kernel_size=3, padding=max_dilation, dilation=max_dilation, bias=True)
+        self.dense = nn.Linear(input_channels, output_channels, bias=True)
+
+    def forward(self, inp, loc):
+        x = self.conv1(inp)
+        if self.max_dilation > 1:
+            x = self.bn(self.relu(x))
+            x = self.conv2(x)
+
+        # This can be made (possibly substantially) more efficient by only computing these convolutions across a subset of the image. Possibly.
+        i, j = loc
+        x = x[:, :, i, j]
+        return self.dense(x)
+
+
+# Grabbed from torch examples: https://github.com/pytorch/examples/blob/master/word_language_model/model.py#L65
+class PositionalEncoding(nn.Module):
+    def __init__(self, d_model, max_len=5000):
+        super(PositionalEncoding, self).__init__()
+
+        pe = torch.zeros(max_len, d_model)
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0).transpose(0, 1)
+        self.register_buffer('pe', pe)
+
+    def forward(self, x):
+        x = x + self.pe[:x.size(0), :]
+        return x
+
+
+class Segformer(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.backbone = backbone50()
+        backbone_channels = [256, 512, 1024, 2048]
+        dilations = [[1, 2, 3, 4], [1, 2, 3], [1, 2], [1]]
+        final_latent_channels = 2048
+        dilators = []
+        for ic, dis in zip(backbone_channels, dilations):
+            layer_dilators = []
+            for di in dis:
+                layer_dilators.append(DilatorModule(ic, final_latent_channels, di))
+            dilators.append(nn.ModuleList(layer_dilators))
+        self.dilators = nn.ModuleList(dilators)
+
+        self.token_position_encoder = PositionalEncoding(final_latent_channels, max_len=10)
+        self.transformer_layers = nn.Sequential(*[nn.TransformerEncoderLayer(final_latent_channels, nhead=4) for _ in range(16)])
+
+    def forward(self, x, pos):
+        layers = self.backbone(x)
+        set = []
+        # The backbone downsamples by 4x before its first stage and by 2x at each later stage,
+        # so the queried position is scaled accordingly for each layer output.
+        i, j = pos[0] // 4, pos[1] // 4
+        for layer_out, dilator in zip(layers, self.dilators):
+            for subdilator in dilator:
+                set.append(subdilator(layer_out, (i, j)))
+            i, j = i // 2, j // 2
+
+        # The torch transformer expects the set dimension to be 0.
+        set = torch.stack(set, dim=0)
+        set = self.token_position_encoder(set)
+        set = self.transformer_layers(set)
+        return set
+
+
+if __name__ == '__main__':
+    model = Segformer().to('cuda')
+    for j in tqdm(range(1000)):
+        test_tensor = torch.randn(64, 3, 224, 224).cuda()
+        model(test_tensor, (43, 73))
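
For reference, a minimal usage sketch of the BYOL wrapper introduced above (not part of the patch). It assumes the repository's codes/ directory is the import root and substitutes a stock torchvision resnet50 for the subnet, roughly what register_byol would construct from an options dict with image_size=224 and hidden_layer=-2:

import torch
from torchvision.models import resnet50

from models.byol.byol_for_semantic_chaining import BYOL

# hidden_layer=-2 hooks the penultimate child module of the subnet (avgpool for a torchvision
# resnet); its flattened output feeds the MLP projector.
net = resnet50(pretrained=False)
byol = BYOL(net, image_size=224, hidden_layer=-2)

# Two views of the same batch. With do_augmentation left at its default of False, the views
# are expected to have been augmented upstream (e.g. in the dataset pipeline).
view_one = torch.randn(4, 3, 224, 224)
view_two = torch.randn(4, 3, 224, 224)

loss = byol(view_one, view_two)
loss.backward()

# After each optimizer step, update the target encoder as an EMA of the online encoder.
byol.update_for_step(1, None)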