diff --git a/codes/data/image_folder_dataset.py b/codes/data/image_folder_dataset.py index d1b69edb..61ac2520 100644 --- a/codes/data/image_folder_dataset.py +++ b/codes/data/image_folder_dataset.py @@ -37,6 +37,8 @@ class ImageFolderDataset: if 'normalize' in opt.keys(): if opt['normalize'] == 'stylegan2_norm': self.normalize = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + elif opt['normalize'] == 'imagenet': + self.normalize = Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), inplace=True) else: raise Exception('Unsupported normalize') else: diff --git a/codes/models/byol/byol_for_semantic_chaining.py b/codes/models/byol/byol_for_semantic_chaining.py index b92cbc28..8de0846e 100644 --- a/codes/models/byol/byol_for_semantic_chaining.py +++ b/codes/models/byol/byol_for_semantic_chaining.py @@ -51,6 +51,54 @@ def set_requires_grad(model, val): p.requires_grad = val +# Specialized augmentor class that applies a set of image transformations on points as well, allowing one to track +# where a point in the src image is located in the dest image. Restricts transformation such that this is possible. +class PointwiseAugmentor(nn.Module): + def __init__(self, img_size=224): + super().__init__() + self.jitter = RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8) + self.gray = augs.RandomGrayscale(p=0.2) + self.blur = RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1) + self.rrc = augs.RandomResizedCrop((img_size, img_size), same_on_batch=True) + + # Given a point in the source image, returns the same point in the source image, given the kornia RRC params. + def rrc_on_point(self, src_point, params): + dh, dw = params['dst'][:,2,1]-params['dst'][:,0,1], params['dst'][:,2,0] - params['dst'][:,0,0] + sh, sw = params['src'][:,2,1]-params['src'][:,0,1], params['src'][:,2,0] - params['src'][:,0,0] + scale_h, scale_w = sh.float() / dh.float(), sw.float() / dw.float() + t, l = src_point[0] - params['src'][0,0,1], src_point[1] - params['src'][0,0,0] + t = (t.float() / scale_h[0]).long() + l = (l.float() / scale_w[0]).long() + return torch.stack([t,l]) + + def flip_on_point(self, pt, input): + t, l = pt[0], pt[1] + center = input.shape[-1] // 2 + return t, 2 * center - l + + def forward(self, x, point): + d = self.jitter(x) + d = self.gray(d) + will_flip = random.random() > .5 + if will_flip: + d = apply_hflip(d) + point = self.flip_on_point(point, x) + d = self.blur(d) + + invalid = True + while invalid: + params = self.rrc.generate_parameters(d.shape) + potential = self.rrc_on_point(point, params) + # '10' is an arbitrary number: we want to provide some margin. Making predictions at the very edge of an image is not very useful. 
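# Illustrative sketch (not part of the diff): the margin test below, factored into a
# standalone helper so the rejection rule is easy to unit-test. margin=10 mirrors the
# arbitrary border chosen here; a candidate point is rejected whenever it falls within
# `margin` pixels of any edge of the (height x width) crop.
def point_clears_margin(pt, height, width, margin=10):
    t, l = int(pt[0]), int(pt[1])
    return margin < t <= height - margin and margin < l <= width - margin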
+ if potential[0] <= 10 or potential[1] <= 10 or potential[0] > x.shape[-2]-10 or potential[1] > x.shape[-1]-10: + continue + d = self.rrc(d, params=params) + point = potential + invalid = False + + return d, point + + # loss fn def loss_fn(x, y): x = F.normalize(x, dim=-1, p=2) @@ -160,21 +208,21 @@ class NetWrapper(nn.Module): projector = MLP(dim, self.projection_size, self.projection_hidden_size) return projector.to(hidden) - def get_representation(self, x): + def get_representation(self, x, pt): if self.layer == -1: - return self.net(x) + return self.net(x, pt) if not self.hook_registered: self._register_hook() - unused = self.net(x) + unused = self.net(x, pt) hidden = self.hidden self.hidden = None assert hidden is not None, f'hidden layer {self.layer} never emitted an output' return hidden - def forward(self, x): - representation = self.get_representation(x) + def forward(self, x, pt): + representation = self.get_representation(x, pt) projector = self._get_projector(representation) projection = checkpoint(projector, representation) return projection @@ -191,24 +239,13 @@ class BYOL(nn.Module): moving_average_decay=0.99, use_momentum=True, structural_mlp=False, - do_augmentation=False # In DLAS this was intended to be done at the dataset level. For massive batch sizes - # this can overwhelm the CPU though, and it becomes desirable to do the augmentations - # on the GPU again. ): super().__init__() self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer, use_structural_mlp=structural_mlp) - self.do_aug = do_augmentation - if self.do_aug: - augmentations = [ \ - RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8), - augs.RandomGrayscale(p=0.2), - augs.RandomHorizontalFlip(), - RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1), - augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))] - self.aug = nn.Sequential(*augmentations) + self.aug = PointwiseAugmentor(image_size) self.use_momentum = use_momentum self.target_encoder = None self.target_ema_updater = EMA(moving_average_decay) @@ -220,8 +257,7 @@ class BYOL(nn.Module): self.to(device) # send a mock image tensor to instantiate singleton parameters - self.forward(torch.randn(2, 3, image_size, image_size, device=device), - torch.randn(2, 3, image_size, image_size, device=device)) + self.forward(torch.randn(2, 3, image_size, image_size, device=device)) @singleton('target_encoder') def _get_target_encoder(self): @@ -245,28 +281,32 @@ class BYOL(nn.Module): return {'target_ema_beta': self.target_ema_updater.beta} def visual_dbg(self, step, path): - if self.do_aug: - torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,))) - torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,))) + torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,))) + torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,))) - def forward(self, image_one, image_two): - if self.do_aug: - image_one = self.aug(image_one) - image_two = self.aug(image_two) - # Keep copies on hand for visual_dbg. 
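# Illustrative sketch (not part of the diff): how the new forward() below marks the tracked
# point in the detached debug copies that visual_dbg() saves. A small square around the point
# is painted with a constant value so the correspondence between the two augmented views is
# visible in the saved PNGs; half=3 and value=1 match the constants used below.
import torch

def mark_point(img, pt, half=3, value=1.0):
    marked = img.detach().clone()
    t, l = int(pt[0]), int(pt[1])
    marked[:, :, t - half:t + half, l - half:l + half] = value
    return marked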
- self.im1 = image_one.detach().copy() - self.im2 = image_two.detach().copy() + def forward(self, image): + _, _, h, w = image.shape + point = torch.randint(h//8, 7*h//8, (2,)).long().to(image.device) - online_proj_one = self.online_encoder(image_one) - online_proj_two = self.online_encoder(image_two) + image_one, pt_one = self.aug(image, point) + image_two, pt_two = self.aug(image, point) + + # Keep copies on hand for visual_dbg. + self.im1 = image_one.detach().clone() + self.im1[:,:,pt_one[0]-3:pt_one[0]+3,pt_one[1]-3:pt_one[1]+3] = 1 + self.im2 = image_two.detach().clone() + self.im2[:,:,pt_two[0]-3:pt_two[0]+3,pt_two[1]-3:pt_two[1]+3] = 1 + + online_proj_one = self.online_encoder(image_one, pt_one) + online_proj_two = self.online_encoder(image_two, pt_two) online_pred_one = self.online_predictor(online_proj_one) online_pred_two = self.online_predictor(online_proj_two) with torch.no_grad(): target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder - target_proj_one = target_encoder(image_one).detach() - target_proj_two = target_encoder(image_two).detach() + target_proj_one = target_encoder(image_one, pt_one).detach() + target_proj_two = target_encoder(image_two, pt_two).detach() loss_one = loss_fn(online_pred_one, target_proj_two.detach()) loss_two = loss_fn(online_pred_two, target_proj_one.detach()) @@ -275,53 +315,20 @@ class BYOL(nn.Module): return loss.mean() -class PointwiseAugmentor(nn.Module): - def __init__(self, img_size=224): - super().__init__() - self.jitter = RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8) - self.gray = augs.RandomGrayscale(p=0.2) - self.blur = RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1) - self.rrc = augs.RandomResizedCrop((img_size, img_size)) - - # Given a point in the *destination* image, returns the same point in the source image, given the kornia RRC params. 
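# Illustrative sketch (not part of the diff): the point mapping implied by kornia's
# RandomResizedCrop params, in both directions. params['src'] / params['dst'] are assumed to
# hold corner quads in (x, y) order with the top-left corner first and the bottom-right corner
# third, which is the layout both the new rrc_on_point() above (source image -> crop) and the
# reverse_rrc() being deleted below (crop -> source image) rely on.
import torch

def _rrc_scales(params):
    # Height/width scale factors (source extent divided by crop extent) for batch element 0.
    sh = (params['src'][0, 2, 1] - params['src'][0, 0, 1]).float()
    sw = (params['src'][0, 2, 0] - params['src'][0, 0, 0]).float()
    dh = (params['dst'][0, 2, 1] - params['dst'][0, 0, 1]).float()
    dw = (params['dst'][0, 2, 0] - params['dst'][0, 0, 0]).float()
    return sh / dh, sw / dw

def src_point_to_crop(pt, params):
    pt = torch.as_tensor(pt)
    scale_h, scale_w = _rrc_scales(params)
    t = (pt[0] - params['src'][0, 0, 1]).float() / scale_h
    l = (pt[1] - params['src'][0, 0, 0]).float() / scale_w
    return torch.stack([t.long(), l.long()])

def crop_point_to_src(pt, params):
    pt = torch.as_tensor(pt)
    scale_h, scale_w = _rrc_scales(params)
    t = pt[0].float() * scale_h + params['src'][0, 0, 1].float()
    l = pt[1].float() * scale_w + params['src'][0, 0, 0].float()
    return torch.stack([t.long(), l.long()])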
- def reverse_rrc(self, dest_point, params): - dh, dw = params['dst'][:,2,1]-params['dst'][:,0,1], params['dst'][:,2,0] - params['dst'][:,0,0] - sh, sw = params['src'][:,2,1]-params['src'][:,0,1], params['src'][:,2,0] - params['src'][:,0,0] - scale_h, scale_w = sh.float() / dh.float(), sw.float() / dw.float() - t, l = dest_point - t = (t.float() * scale_h).int() - l = (l.float() * scale_w).int() - return t + params['src'][:,0,1], l + params['src'][:,0,0] - - def reverse_horizontal_flip(self, pt, input): - t, l = pt - center = input.shape[-1] // 2 - return t, 2 * center - l - - def forward(self, x, points): - d = self.jitter(x) - d = self.gray(d) - will_flip = random.random() > .5 - if will_flip: - d = apply_hflip(d) - d = self.blur(d) - params = self.rrc.generate_parameters(d.shape) - d = self.rrc(d, params=params) - - rev = self.reverse_rrc(points, params) - if will_flip: - rev = self.reverse_horizontal_flip(rev, x) - if __name__ == '__main__': - p = PointwiseAugmentor(256) - t = ToTensor()(Image.open('E:\\4k6k\\datasets\\ns_images\\imagesets\\000001_152761.jpg')).unsqueeze(0).repeat(8,1,1,1) - points = (torch.randint(0,224,(t.shape[0],)),torch.randint(0,224,(t.shape[0],))) - p(t, points) + pa = PointwiseAugmentor(256) + for j in range(100): + t = ToTensor()(Image.open('E:\\4k6k\\datasets\\ns_images\\imagesets\\000001_152761.jpg')).unsqueeze(0).repeat(8,1,1,1) + p = torch.randint(50,180,(2,)) + augmented, dp = pa(t, p) + t, p = pa(t, p) + t[:,:,p[0]-3:p[0]+3,p[1]-3:p[1]+3] = 0 + torchvision.utils.save_image(t, f"{j}_src.png") + augmented[:,:,dp[0]-3:dp[0]+3,dp[1]-3:dp[1]+3] = 0 + torchvision.utils.save_image(augmented, f"{j}_dst.png") @register_model -def register_byol(opt_net, opt): +def register_pixel_local_byol(opt_net, opt): subnet = create_model(opt, opt_net['subnet']) - return BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'], - structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False), - do_augmentation=opt_get(opt_net, ['gpu_augmentation'], False)) \ No newline at end of file + return BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer']) \ No newline at end of file diff --git a/codes/models/segformer/__init__.py b/codes/models/segformer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/codes/models/segformer/segformer.py b/codes/models/segformer/segformer.py index b5a061c9..a2bdb103 100644 --- a/codes/models/segformer/segformer.py +++ b/codes/models/segformer/segformer.py @@ -2,11 +2,30 @@ import math import torch import torch.nn as nn +import torchvision from tqdm import tqdm from models.segformer.backbone import backbone50 +# torch.gather() which operates as it always fucking should have: pulling indexes from the input. 
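# Illustrative sketch (not part of the diff): a naive loop-based reference for the gather_2d()
# defined just below. For each batch element it picks the channel vector at one (row, col)
# location of a [b, c, h, w] tensor, which is the same thing gather_2d() computes with
# torch.gather over a flattened h*w axis. Useful as a sanity check, e.g.
# torch.allclose(gather_2d(x, idx), gather_2d_reference(x, idx)) for b > 1 and c > 1
# (gather_2d squeezes singleton dims), with idx of shape [b, 2] holding (row, col) pairs.
import torch

def gather_2d_reference(input, index):
    b = input.shape[0]
    return torch.stack([input[i, :, index[i, 0], index[i, 1]] for i in range(b)], dim=0)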
+from trainer.networks import register_model + + +def gather_2d(input, index): + b, c, h, w = input.shape + nodim = input.view(b, c, h * w) + ind_nd = index[:, 0]*w + index[:, 1] + ind_nd = ind_nd.unsqueeze(1) + ind_nd = ind_nd.repeat((1, c)) + ind_nd = ind_nd.unsqueeze(2) + result = torch.gather(nodim, dim=2, index=ind_nd) + result = result.squeeze() + if b == 1: + result = result.unsqueeze(0) + return result + + class DilatorModule(nn.Module): def __init__(self, input_channels, output_channels, max_dilation): super().__init__() @@ -15,7 +34,7 @@ class DilatorModule(nn.Module): if max_dilation > 1: self.bn = nn.BatchNorm2d(input_channels) self.relu = nn.ReLU() - self.conv2 = nn.Conv2d(input_channels, input_channels, kernel_size=3, padding=1, dilation=max_dilation, bias=True) + self.conv2 = nn.Conv2d(input_channels, input_channels, kernel_size=3, padding=max_dilation, dilation=max_dilation, bias=True) self.dense = nn.Linear(input_channels, output_channels, bias=True) def forward(self, inp, loc): @@ -24,9 +43,8 @@ class DilatorModule(nn.Module): x = self.bn(self.relu(x)) x = self.conv2(x) - # This can be made (possibly substantially) more efficient by only computing these convolutions across a subset of the image. Possibly. - i, j = loc - x = x[:,:,i,j] + # This can be made more efficient by only computing these convolutions across a subset of the image. Possibly. + x = gather_2d(x, loc).contiguous() return self.dense(x) @@ -48,13 +66,22 @@ class PositionalEncoding(nn.Module): return x -class Segformer(nn.Module): +# Simple mean() layer encoded into a class so that BYOL can grab it. +class Tail(nn.Module): def __init__(self): super().__init__() + + def forward(self, x): + return x.mean(dim=0) + + +class Segformer(nn.Module): + def __init__(self, latent_channels=1024, layers=8): + super().__init__() self.backbone = backbone50() backbone_channels = [256, 512, 1024, 2048] dilations = [[1,2,3,4],[1,2,3],[1,2],[1]] - final_latent_channels = 2048 + final_latent_channels = latent_channels dilators = [] for ic, dis in zip(backbone_channels, dilations): layer_dilators = [] @@ -64,26 +91,37 @@ class Segformer(nn.Module): self.dilators = nn.ModuleList(dilators) self.token_position_encoder = PositionalEncoding(final_latent_channels, max_len=10) - self.transformer_layers = nn.Sequential(*[nn.TransformerEncoderLayer(final_latent_channels, nhead=4) for _ in range(16)]) + self.transformer_layers = nn.Sequential(*[nn.TransformerEncoderLayer(final_latent_channels, nhead=4) for _ in range(layers)]) + self.tail = Tail() def forward(self, x, pos): layers = self.backbone(x) set = [] - i, j = pos[0] // 4, pos[1] // 4 + + # A single position can be optionally given, in which case we need to expand it to represent the entire input. + if pos.shape == (2,): + pos = pos.unsqueeze(0).repeat(x.shape[0],1) + + pos = pos // 4 for layer_out, dilator in zip(layers, self.dilators): for subdilator in dilator: - set.append(subdilator(layer_out, (i, j))) - i, j = i // 2, j // 2 + set.append(subdilator(layer_out, pos)) + pos = pos // 2 # The torch transformer expects the set dimension to be 0. 
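# Illustrative sketch (not part of the diff): the tensor layout produced by the stack below.
# Each dilator emits a [batch, channels] embedding for the queried point; with the default
# dilation config there are 4+3+2+1 = 10 of them, and stacking on dim 0 gives
# [10, batch, channels], the (sequence, batch, features) order nn.TransformerEncoderLayer
# expects by default. Tail() then means over dim 0, collapsing the token set back to a single
# [batch, channels] latent.
import torch
import torch.nn as nn

tokens = torch.stack([torch.randn(4, 1024) for _ in range(10)], dim=0)  # [10, 4, 1024]
encoded = nn.TransformerEncoderLayer(d_model=1024, nhead=4)(tokens)     # [10, 4, 1024]
latent = encoded.mean(dim=0)                                            # [4, 1024], what Tail() returns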
set = torch.stack(set, dim=0) set = self.token_position_encoder(set) set = self.transformer_layers(set) - return set + return self.tail(set) + + +@register_model +def register_segformer(opt_net, opt): + return Segformer() if __name__ == '__main__': model = Segformer().to('cuda') for j in tqdm(range(1000)): test_tensor = torch.randn(64,3,224,224).cuda() - model(test_tensor, (43, 73)) \ No newline at end of file + print(model(test_tensor, torch.randint(0,224,(64,2)).cuda()).shape) \ No newline at end of file diff --git a/codes/train.py b/codes/train.py index 2fdfb4b3..b809a286 100644 --- a/codes/train.py +++ b/codes/train.py @@ -295,7 +295,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_cats_stylegan2_rosinality.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_segformer_xx.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() diff --git a/codes/trainer/eval/single_point_pair_contrastive_eval.py b/codes/trainer/eval/single_point_pair_contrastive_eval.py index fccebbd5..6a84cb34 100644 --- a/codes/trainer/eval/single_point_pair_contrastive_eval.py +++ b/codes/trainer/eval/single_point_pair_contrastive_eval.py @@ -11,6 +11,7 @@ import trainer.eval.evaluator as evaluator from pytorch_fid import fid_score from data.image_pair_with_corresponding_points_dataset import ImagePairWithCorrespondingPointsDataset +from models.segformer.segformer import Segformer from utils.util import opt_get # Uses two datasets: a "similar" and "dissimilar" dataset, each of which contains pairs of images and similar/dissimilar @@ -23,15 +24,17 @@ class SinglePointPairContrastiveEval(evaluator.Evaluator): self.batch_sz = opt_eval['batch_size'] self.eval_qty = opt_eval['quantity'] assert self.eval_qty % self.batch_sz == 0 - self.similar_set = DataLoader(ImagePairWithCorrespondingPointsDataset(**opt_eval['similar_set_args']), shuffle=False, batch_size=self.batch_sz) - self.dissimilar_set = DataLoader(ImagePairWithCorrespondingPointsDataset(**opt_eval['dissimilar_set_args']), shuffle=False, batch_size=self.batch_sz) + self.similar_set = DataLoader(ImagePairWithCorrespondingPointsDataset(opt_eval['similar_set_args']), shuffle=False, batch_size=self.batch_sz) + self.dissimilar_set = DataLoader(ImagePairWithCorrespondingPointsDataset(opt_eval['dissimilar_set_args']), shuffle=False, batch_size=self.batch_sz) + # Hack to make this work with the BYOL generator. 
TODO: fix
+        self.model = self.model.online_encoder.net
 
-    def get_l2_score(self, dl):
+    def get_l2_score(self, dl, dev):
         distances = []
         l2 = MSELoss()
         for i, data in tqdm(enumerate(dl)):
-            latent1 = self.model(data['img1'], data['coords1'])
-            latent2 = self.model(data['img2'], data['coords2'])
+            latent1 = self.model(data['img1'].to(dev), torch.stack(data['coords1'], dim=1).to(dev))
+            latent2 = self.model(data['img2'].to(dev), torch.stack(data['coords2'], dim=1).to(dev))
             distances.append(l2(latent1, latent2))
             if i * self.batch_sz >= self.eval_qty:
                 break
@@ -40,9 +43,30 @@ class SinglePointPairContrastiveEval(evaluator.Evaluator):
 
     def perform_eval(self):
         self.model.eval()
-        print("Computing contrastive eval on similar set")
-        similars = self.get_l2_score(self.similar_set)
-        print("Computing contrastive eval on dissimilar set")
-        dissimilars = self.get_l2_score(self.dissimilar_set)
-        print(f"Eval done. val_similar_lq: {similars.item()}; val_dissimilar_l2: {dissimilars.item()}")
-        return {"val_similar_l2": similars.item(), "val_dissimilar_l2": dissimilars.item()} \ No newline at end of file
+        with torch.no_grad():
+            dev = next(self.model.parameters()).device
+            print("Computing contrastive eval on similar set")
+            similars = self.get_l2_score(self.similar_set, dev)
+            print("Computing contrastive eval on dissimilar set")
+            dissimilars = self.get_l2_score(self.dissimilar_set, dev)
+            diff = dissimilars.item() - similars.item()
+            print(f"Eval done. val_similar_l2: {similars.item()}; val_dissimilar_l2: {dissimilars.item()}; val_diff: {diff}")
+        self.model.train()
+        return {"val_similar_l2": similars.item(), "val_dissimilar_l2": dissimilars.item(), "val_diff": diff}
+
+
+if __name__ == '__main__':
+    model = Segformer(1024, 4).cuda()
+    eval = SinglePointPairContrastiveEval(model, {
+        'batch_size': 8,
+        'quantity': 32,
+        'similar_set_args': {
+            'path': 'E:\\4k6k\\datasets\\ns_images\\segformer_validation\\similar',
+            'size': 256
+        },
+        'dissimilar_set_args': {
+            'path': 'E:\\4k6k\\datasets\\ns_images\\segformer_validation\\dissimilar',
+            'size': 256
+        },
+    }, {})
+    eval.perform_eval()
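# Illustrative sketch (not part of the diff): a minimal smoke test wiring the pieces of this
# patch together, a Segformer backbone wrapped by the pixel-local BYOL trainer, the same
# pairing that register_segformer() and register_pixel_local_byol() expose to the trainer.
# hidden_layer='tail' is an assumption: the comment on Tail() says it exists "so that BYOL
# can grab it", but the actual value lives in the train_byol_segformer_xx.yml options file
# referenced by train.py, which is not part of this diff.
import torch
from models.byol.byol_for_semantic_chaining import BYOL
from models.segformer.segformer import Segformer

net = Segformer(latent_channels=1024, layers=4)
byol = BYOL(net, image_size=224, hidden_layer='tail')  # hidden_layer value is assumed, see note above
loss = byol(torch.randn(2, 3, 224, 224))
loss.backward()
print(loss.item())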