import math
import copy
import os
import random
from functools import wraps, partial
from math import floor

import torch
import torchvision
from torch import nn, einsum
import torch.nn.functional as F

from kornia import augmentation as augs
from kornia import filters, color
from einops import rearrange

from trainer.networks import register_model, create_model


# helper functions

def identity(t):
    """Return ``t`` unchanged (used as the "no flip" function)."""
    return t


def default(val, def_val):
    """Return ``def_val`` when ``val`` is None, otherwise ``val``."""
    return def_val if val is None else val


def rand_true(prob):
    """Return True with probability ``prob`` (drawn from Python's ``random`` module)."""
    return random.random() < prob


def singleton(cache_key):
    """Decorator that caches a method's result on ``self.<cache_key>``.

    The attribute named ``cache_key`` must already exist on the instance (it is
    read with ``getattr`` and no default). While it is None the wrapped method
    is called and its result stored there; afterwards the stored value is
    returned without calling the method again.
    """
    def inner_fn(fn):
        @wraps(fn)
        def wrapper(self, *args, **kwargs):
            instance = getattr(self, cache_key)
            if instance is not None:
                return instance

            instance = fn(self, *args, **kwargs)
            setattr(self, cache_key, instance)
            return instance
        return wrapper
    return inner_fn


def get_module_device(module):
    """Return the device of ``module``'s first parameter."""
    return next(module.parameters()).device


def set_requires_grad(model, val):
    """Set ``requires_grad = val`` on every parameter of ``model``."""
    for p in model.parameters():
        p.requires_grad = val


def cutout_coordinates(image, ratio_range=(0.6, 0.8)):
    """Choose a random crop window inside ``image``.

    ``image`` is unpacked as 4-D ``(_, _, height, width)``. A single ratio is
    drawn uniformly from ``ratio_range`` and applied to both height and width,
    and the window is placed uniformly at random inside the image.

    Returns:
        ``(((y0, y1), (x0, x1)), ratio)``: integer half-open slice bounds and
        the ratio that was drawn.
    """
    _, _, orig_h, orig_w = image.shape

    ratio_lo, ratio_hi = ratio_range
    random_ratio = ratio_lo + random.random() * (ratio_hi - ratio_lo)
    w, h = floor(random_ratio * orig_w), floor(random_ratio * orig_h)
    coor_x = floor((orig_w - w) * random.random())
    coor_y = floor((orig_h - h) * random.random())
    return ((coor_y, coor_y + h), (coor_x, coor_x + w)), random_ratio


def cutout_and_resize(image, coordinates, output_size=None, mode='nearest'):
    """Crop ``image`` to ``coordinates`` and resize the crop with ``F.interpolate``.

    ``output_size`` defaults to the input's own spatial size (``image.shape[2:]``),
    so by default the crop is stretched back to the full input resolution.
    """
    shape = image.shape
    output_size = default(output_size, shape[2:])
    (y0, y1), (x0, x1) = coordinates
    cutout_image = image[:, :, y0:y1, x0:x1]
    return F.interpolate(cutout_image, size=output_size, mode=mode)


def scale_coords(coords, scale):
    """Divide each of the four bounds in ``coords`` by ``scale`` and truncate to int.

    Returns a new ``[[y0, y1], [x0, x1]]`` list of lists.
    """
    return [[int(c / scale) for c in pair] for pair in coords]


def reverse_cutout_and_resize(image, coordinates, scale_reduction, mode='nearest'):
    """Undo ``cutout_and_resize`` on a feature map.

    ``coordinates`` (in input-image units) are divided by ``scale_reduction``
    to get the window in ``image``'s grid. ``image`` is then resized down to
    that window and pasted into an all-zero tensor shaped like ``image``.

    Returns:
        The pasted tensor, or None when the scaled window has non-positive
        height or width.
    """
    blank = torch.zeros_like(image)
    coordinates = scale_coords(coordinates, scale_reduction)
    (y0, y1), (x0, x1) = coordinates
    orig_cutout_shape = (y1 - y0, x1 - x0)
    if orig_cutout_shape[0] <= 0 or orig_cutout_shape[1] <= 0:
        return None

    un_resized_img = F.interpolate(image, size=orig_cutout_shape, mode=mode)
    blank[:, :, y0:y1, x0:x1] = un_resized_img
    return blank


def compute_shared_coords(coords1, coords2, scale_reduction):
    """Intersect two cutout windows after scaling both by ``scale_reduction``.

    Returns:
        ``((top, bottom), (left, right))`` of the overlap, or None when the
        overlap is empty along either axis.
    """
    (y1_t, y1_b), (x1_l, x1_r) = scale_coords(coords1, scale_reduction)
    (y2_t, y2_b), (x2_l, x2_r) = scale_coords(coords2, scale_reduction)
    shared = ((max(y1_t, y2_t), min(y1_b, y2_b)),
              (max(x1_l, x2_l), min(x1_r, x2_r)))
    # Each entry is a (lo, hi) pair; the intersection is empty when hi <= lo.
    # (Comparing the pair itself to 0 can never be true.)
    for lo, hi in shared:
        if hi <= lo:
            return None
    return shared


def get_shared_region(proj_pixel_one, proj_pixel_two, cutout_coordinates_one, cutout_coordinates_two,
                      flip_image_one_fn, flip_image_two_fn, img_orig_shape, interp_mode):
    """Extract the part of two pixel projections that covers the same image area.

    Each projection is un-flipped with its flip function, mapped back onto the
    full latent grid with ``reverse_cutout_and_resize``, and cropped to the
    intersection of the two cutout windows.

    Args:
        img_orig_shape: shape of the (resized) cutout images fed to the encoder;
            its last entry is compared with the projection's last entry to get
            the encoder's spatial reduction factor.

    Returns:
        ``(region_one, region_two)`` with identical spatial size, or None (a
        bare None, not a pair) when no region could be extracted.
    """
    # Unflip the pixel projections
    proj_pixel_one = flip_image_one_fn(proj_pixel_one)
    proj_pixel_two = flip_image_two_fn(proj_pixel_two)

    # Undo the cutout and resize, taking into account the scale reduction applied by the encoder.
    # scale_coords divides by this factor, so it must be image_size / latent_size to map
    # image coordinates onto the latent grid. Only the width is measured; the same factor is
    # assumed for the height (NOTE(review): confirm for non-square inputs).
    scale_reduction = img_orig_shape[-1] / proj_pixel_one.shape[-1]
    proj_pixel_one = reverse_cutout_and_resize(proj_pixel_one, cutout_coordinates_one, scale_reduction, mode=interp_mode)
    proj_pixel_two = reverse_cutout_and_resize(proj_pixel_two, cutout_coordinates_two, scale_reduction, mode=interp_mode)
    if proj_pixel_one is None or proj_pixel_two is None:
        print("Could not extract projected image region. The selected cutout coordinates were smaller than the aggregate size of one latent block!")
        return None

    # Compute the shared coordinates for the two cutouts:
    shared_coords = compute_shared_coords(cutout_coordinates_one, cutout_coordinates_two, scale_reduction)
    if shared_coords is None:
        print("No shared coordinates for this iteration (probably should just recompute those coordinates earlier..")
        return None
    (yt, yb), (xl, xr) = shared_coords

    return proj_pixel_one[:, :, yt:yb, xl:xr], proj_pixel_two[:, :, yt:yb, xl:xr]


# augmentation utils

class RandomApply(nn.Module):
    """Apply ``fn`` to the input with probability ``p``; otherwise pass it through."""

    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x):
        if random.random() > self.p:
            return x
        return self.fn(x)


# exponential moving average

class EMA:
    """Exponential moving average with decay ``beta``."""

    def __init__(self, beta):
        self.beta = beta

    def update_average(self, old, new):
        """Return ``beta * old + (1 - beta) * new``, or ``new`` when ``old`` is None."""
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new


def update_moving_average(ema_updater, ma_model, current_model):
    """Blend each parameter of ``current_model`` into the matching parameter of ``ma_model``.

    Parameters are paired positionally via ``zip`` over ``.parameters()``.
    """
    for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
        old_weight, up_weight = ma_params.data, current_params.data
        ma_params.data = ema_updater.update_average(old_weight, up_weight)


# loss fn

def loss_fn(x, y):
    """Return ``2 - 2 * cos(x, y)`` computed along the last dimension."""
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)


# classes

class MLP(nn.Module):
    """Linear -> BatchNorm1d -> ReLU -> Linear head."""

    def __init__(self, chan, chan_out=256, inner_dim=2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(chan, inner_dim),
            nn.BatchNorm1d(inner_dim),
            nn.ReLU(),
            nn.Linear(inner_dim, chan_out)
        )

    def forward(self, x):
        return self.net(x)


class ConvMLP(nn.Module):
    """Per-pixel MLP built from 1x1 convolutions: Conv -> BatchNorm2d -> ReLU -> Conv."""

    def __init__(self, chan, chan_out=256, inner_dim=2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(chan, inner_dim, 1),
            nn.BatchNorm2d(inner_dim),
            nn.ReLU(),
            nn.Conv2d(inner_dim, chan_out, 1)
        )

    def forward(self, x):
        return self.net(x)


class PPM(nn.Module):
    """Pixel propagation module: similarity-weighted aggregation over all positions.

    Args:
        chan: channel count of the input (and output) feature map.
        num_layers: depth of the 1x1-conv transform (0 = identity, 1 = one conv,
            2 = conv-BN-ReLU-conv).
        gamma: exponent applied to the ReLU-clamped cosine similarity.
    """

    def __init__(
        self,
        *,
        chan,
        num_layers=1,
        gamma=2
    ):
        super().__init__()
        self.gamma = gamma

        if num_layers == 0:
            self.transform_net = nn.Identity()
        elif num_layers == 1:
            self.transform_net = nn.Conv2d(chan, chan, 1)
        elif num_layers == 2:
            self.transform_net = nn.Sequential(
                nn.Conv2d(chan, chan, 1),
                nn.BatchNorm2d(chan),
                nn.ReLU(),
                nn.Conv2d(chan, chan, 1)
            )
        else:
            raise ValueError('num_layers must be one of 0, 1, or 2')

    def forward(self, x):
        # Cosine similarity between every pair of spatial positions -> (b, h, w, h, w).
        xi = x[:, :, :, :, None, None]
        xj = x[:, :, None, None, :, :]
        similarity = F.relu(F.cosine_similarity(xi, xj, dim=1)) ** self.gamma

        transform_out = self.transform_net(x)
        # Output position (x, y) is the similarity-weighted sum of transformed features over all (h, w).
        out = einsum('b x y h w, b c h w -> b c x y', similarity, transform_out)
        return out


# a wrapper class for the base neural network
# will manage the interception of the hidden layer output
# and pipe it into the projecter and predictor nets

class NetWrapper(nn.Module):
    """Wrap ``net`` and project two of its intermediate outputs.

    Forward hooks capture the outputs of ``layer_pixel`` and ``layer_instance``
    (a module name from ``named_modules()`` or an index into ``children()``).
    The pixel output goes through a ``ConvMLP`` and the flattened instance
    output through an ``MLP``. Both projectors are built lazily on first use so
    their input width can be read from the captured tensors.
    """

    def __init__(
        self,
        *,
        net,
        instance_projection_size,
        instance_projection_hidden_size,
        pix_projection_size,
        pix_projection_hidden_size,
        layer_pixel=-2,
        layer_instance=-2
    ):
        super().__init__()
        self.net = net
        self.layer_pixel = layer_pixel
        self.layer_instance = layer_instance

        self.pixel_projector = None
        self.instance_projector = None

        self.instance_projection_size = instance_projection_size
        self.instance_projection_hidden_size = instance_projection_hidden_size
        self.pix_projection_size = pix_projection_size
        self.pix_projection_hidden_size = pix_projection_hidden_size

        self.hidden_pixel = None
        self.hidden_instance = None
        self.hook_registered = False

    def _find_layer(self, layer_id):
        """Resolve ``layer_id`` (module name or child index) to a module, or None."""
        if isinstance(layer_id, str):
            modules = dict([*self.net.named_modules()])
            return modules.get(layer_id, None)
        elif isinstance(layer_id, int):
            children = [*self.net.children()]
            return children[layer_id]
        return None

    def _hook(self, attr_name, _, __, output):
        # Forward-hook callback: stash the layer output on self.<attr_name>.
        setattr(self, attr_name, output)

    def _register_hook(self):
        pixel_layer = self._find_layer(self.layer_pixel)
        instance_layer = self._find_layer(self.layer_instance)

        assert pixel_layer is not None, f'hidden layer ({self.layer_pixel}) not found'
        assert instance_layer is not None, f'hidden layer ({self.layer_instance}) not found'

        pixel_layer.register_forward_hook(partial(self._hook, 'hidden_pixel'))
        instance_layer.register_forward_hook(partial(self._hook, 'hidden_instance'))
        self.hook_registered = True

    @singleton('pixel_projector')
    def _get_pixel_projector(self, hidden):
        _, dim, *_ = hidden.shape
        projector = ConvMLP(dim, self.pix_projection_size, self.pix_projection_hidden_size)
        return projector.to(hidden)

    @singleton('instance_projector')
    def _get_instance_projector(self, hidden):
        _, dim = hidden.shape
        projector = MLP(dim, self.instance_projection_size, self.instance_projection_hidden_size)
        return projector.to(hidden)

    def get_representation(self, x):
        """Run ``net`` on ``x`` and return the hooked ``(pixel, instance)`` outputs."""
        if not self.hook_registered:
            self._register_hook()

        _ = self.net(x)
        hidden_pixel = self.hidden_pixel
        hidden_instance = self.hidden_instance
        # Clear the captures so a stale tensor is never returned on a later call.
        self.hidden_pixel = None
        self.hidden_instance = None
        assert hidden_pixel is not None, f'hidden pixel layer {self.layer_pixel} never emitted an output'
        assert hidden_instance is not None, f'hidden instance layer {self.layer_instance} never emitted an output'
        return hidden_pixel, hidden_instance

    def forward(self, x):
        pixel_representation, instance_representation = self.get_representation(x)
        instance_representation = instance_representation.flatten(1)

        pixel_projector = self._get_pixel_projector(pixel_representation)
        instance_projector = self._get_instance_projector(instance_representation)

        pixel_projection = pixel_projector(pixel_representation)
        instance_projection = instance_projector(instance_representation)
        return pixel_projection, instance_projection


# main class

class PixelCL(nn.Module):
    """Pixel-level + instance-level self-supervised learner around ``net``.

    ``forward`` returns ``(instance_loss, pix_loss, positive_pixel_pairs)``.
    The target encoder is a gradient-free deep copy of the online encoder,
    created on the first forward and updated via ``update_moving_average``.
    """

    def __init__(
        self,
        net,
        image_size,
        hidden_layer_pixel=-2,
        hidden_layer_instance=-2,
        instance_projection_size=256,
        instance_projection_hidden_size=2048,
        pix_projection_size=256,
        pix_projection_hidden_size=2048,
        augment_fn=None,
        augment_fn2=None,
        prob_rand_hflip=0.25,
        moving_average_decay=0.99,
        ppm_num_layers=1,
        ppm_gamma=2,
        distance_thres=0.7,
        similarity_temperature=0.3,
        cutout_ratio_range=(0.6, 0.8),
        cutout_interpolate_mode='nearest',
        coord_cutout_interpolate_mode='bilinear',
        max_latent_dim=None  # When set, this is the number of stochastically extracted pixels from the latent to extract. Must have an integer square root.
    ):
        super().__init__()

        DEFAULT_AUG = nn.Sequential(
            RandomApply(augs.ColorJitter(0.6, 0.6, 0.6, 0.2), p=0.8),
            augs.RandomGrayscale(p=0.2),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            augs.RandomSolarize(p=0.5),
            # Normalize left out because it should be done at the model level.
        )

        self.augment1 = default(augment_fn, DEFAULT_AUG)
        self.augment2 = default(augment_fn2, self.augment1)
        self.prob_rand_hflip = prob_rand_hflip

        self.online_encoder = NetWrapper(
            net=net,
            instance_projection_size=instance_projection_size,
            instance_projection_hidden_size=instance_projection_hidden_size,
            pix_projection_size=pix_projection_size,
            pix_projection_hidden_size=pix_projection_hidden_size,
            layer_pixel=hidden_layer_pixel,
            layer_instance=hidden_layer_instance
        )

        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)

        # NOTE(review): distance_thres and similarity_temperature are stored but not read anywhere in this module.
        self.distance_thres = distance_thres
        self.similarity_temperature = similarity_temperature

        # This requirement is due to the way that these are processed, not a hard requirement.
        # Only validated when set: the default of None disables sub-sampling (math.sqrt(None) would raise).
        if max_latent_dim is not None:
            sqdim = int(math.sqrt(max_latent_dim))
            assert sqdim * sqdim == max_latent_dim, 'max_latent_dim must have an integer square root'
        self.max_latent_dim = max_latent_dim

        self.propagate_pixels = PPM(
            chan=pix_projection_size,
            num_layers=ppm_num_layers,
            gamma=ppm_gamma
        )

        self.cutout_ratio_range = cutout_ratio_range
        self.cutout_interpolate_mode = cutout_interpolate_mode
        # NOTE(review): coord_cutout_interpolate_mode is stored but forward() uses cutout_interpolate_mode for both.
        self.coord_cutout_interpolate_mode = coord_cutout_interpolate_mode

        # instance level predictor
        self.online_predictor = MLP(instance_projection_size, instance_projection_size, instance_projection_hidden_size)

        # get device of network and make wrapper same device
        device = get_module_device(net)
        self.to(device)

        # send a mock image tensor to instantiate singleton parameters
        self.forward(torch.randn(2, 3, image_size, image_size, device=device))

    @singleton('target_encoder')
    def _get_target_encoder(self):
        target_encoder = copy.deepcopy(self.online_encoder)
        set_requires_grad(target_encoder, False)
        return target_encoder

    def reset_moving_average(self):
        del self.target_encoder
        self.target_encoder = None

    def update_moving_average(self):
        assert self.target_encoder is not None, 'target encoder has not been created yet'
        update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)

    def forward(self, x):
        """Compute the self-supervised losses for the image batch ``x``.

        Two random cutouts of ``x`` are taken, each optionally flipped and then
        augmented. Both are encoded by the online and the (no-grad) target
        encoder. The pixel projections are mapped back onto the image region
        the two cutouts share.

        Returns:
            ``(instance_loss, pix_loss, positive_pixel_pairs)``, where
            ``positive_pixel_pairs`` is the size of the shared latent region
            before any ``max_latent_dim`` sub-sampling. When no shared region
            can be extracted, ``pix_loss`` and ``positive_pixel_pairs`` are 0.
        """
        prob_flip = self.prob_rand_hflip

        rand_flip_fn = lambda t: torch.flip(t, dims=(-1,))
        flip_image_one, flip_image_two = rand_true(prob_flip), rand_true(prob_flip)
        flip_image_one_fn = rand_flip_fn if flip_image_one else identity
        flip_image_two_fn = rand_flip_fn if flip_image_two else identity

        cutout_coordinates_one, _ = cutout_coordinates(x, self.cutout_ratio_range)
        cutout_coordinates_two, _ = cutout_coordinates(x, self.cutout_ratio_range)

        image_one_cutout = cutout_and_resize(x, cutout_coordinates_one, mode=self.cutout_interpolate_mode)
        image_two_cutout = cutout_and_resize(x, cutout_coordinates_two, mode=self.cutout_interpolate_mode)

        image_one_cutout = flip_image_one_fn(image_one_cutout)
        image_two_cutout = flip_image_two_fn(image_two_cutout)

        image_one_cutout, image_two_cutout = self.augment1(image_one_cutout), self.augment2(image_two_cutout)

        # Detached copies of the augmented inputs, saved for visual_dbg().
        self.aug1 = image_one_cutout.detach().clone()
        self.aug2 = image_two_cutout.detach().clone()

        proj_pixel_one, proj_instance_one = self.online_encoder(image_one_cutout)
        proj_pixel_two, proj_instance_two = self.online_encoder(image_two_cutout)

        # get_shared_region returns a bare None (not a pair) on failure, so keep the result whole
        # and check it before unpacking.
        online_shared = get_shared_region(proj_pixel_one, proj_pixel_two, cutout_coordinates_one,
                                          cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn,
                                          image_one_cutout.shape, self.cutout_interpolate_mode)

        with torch.no_grad():
            # Always build/run the target encoder: the instance loss needs it even without a shared region.
            target_encoder = self._get_target_encoder()
            target_proj_pixel_one, target_proj_instance_one = target_encoder(image_one_cutout)
            target_proj_pixel_two, target_proj_instance_two = target_encoder(image_two_cutout)
            target_shared = None
            if online_shared is not None:
                target_shared = get_shared_region(target_proj_pixel_one, target_proj_pixel_two,
                                                  cutout_coordinates_one, cutout_coordinates_two,
                                                  flip_image_one_fn, flip_image_two_fn,
                                                  image_one_cutout.shape, self.cutout_interpolate_mode)

        # get instance level loss
        pred_instance_one = self.online_predictor(proj_instance_one)
        pred_instance_two = self.online_predictor(proj_instance_two)

        loss_instance_one = loss_fn(pred_instance_one, target_proj_instance_two.detach())
        loss_instance_two = loss_fn(pred_instance_two, target_proj_instance_one.detach())

        instance_loss = (loss_instance_one + loss_instance_two).mean()

        if online_shared is None or target_shared is None:
            return instance_loss, 0, 0

        proj_pixel_one, proj_pixel_two = online_shared
        target_proj_pixel_one, target_proj_pixel_two = target_shared

        b, c, pp_h, pp_w = proj_pixel_one.shape
        positive_pixel_pairs = pp_h * pp_w
        if positive_pixel_pairs == 0:
            return instance_loss, 0, 0

        # If max_latent_dim is specified, stochastically extract latents from the shared areas.
        if self.max_latent_dim and positive_pixel_pairs > self.max_latent_dim:
            # Draw ONE random subset of the pp_h * pp_w shared positions and apply it to all four maps,
            # so the online and target projections stay spatially aligned with each other.
            keep = torch.randperm(positive_pixel_pairs, device=proj_pixel_one.device)[:self.max_latent_dim]
            # For compatibility with the existing pixpro code, reshape this stochastic sampling back into a 2d "square".
            # Note that the actual structure no longer matters going forwards. Pixels are only compared to themselves
            # and others without regards to the original image structure.
            sqdim = int(math.sqrt(self.max_latent_dim))

            def subsample(t):
                return t.reshape(b, c, positive_pixel_pairs)[:, :, keep].reshape(b, c, sqdim, sqdim)

            proj_pixel_one, proj_pixel_two, target_proj_pixel_one, target_proj_pixel_two = map(
                subsample, (proj_pixel_one, proj_pixel_two, target_proj_pixel_one, target_proj_pixel_two))

        # flatten all the pixel projections
        flatten = lambda t: rearrange(t, 'b c h w -> b c (h w)')
        target_proj_pixel_one, target_proj_pixel_two = list(map(flatten, (target_proj_pixel_one, target_proj_pixel_two)))

        # calculate pix pro loss
        propagated_pixels_one = self.propagate_pixels(proj_pixel_one)
        propagated_pixels_two = self.propagate_pixels(proj_pixel_two)

        propagated_pixels_one, propagated_pixels_two = list(map(flatten, (propagated_pixels_one, propagated_pixels_two)))

        # NOTE(review): these are (b, N, N) similarity matrices over every (i, j) pixel pair and the loss
        # averages all of them, not only spatially corresponding (i, i) pairs. Confirm this is intended.
        propagated_similarity_one_two = F.cosine_similarity(propagated_pixels_one[..., :, None], target_proj_pixel_two[..., None, :], dim=1)
        propagated_similarity_two_one = F.cosine_similarity(propagated_pixels_two[..., :, None], target_proj_pixel_one[..., None, :], dim=1)

        loss_pixpro_one_two = - propagated_similarity_one_two.mean()
        loss_pixpro_two_one = - propagated_similarity_two_one.mean()

        pix_loss = (loss_pixpro_one_two + loss_pixpro_two_one) / 2

        return instance_loss, pix_loss, positive_pixel_pairs

    # Allows visualizing what the augmentor is up to.
    def visual_dbg(self, step, path):
        """Save the most recent augmented inputs as ``<step>_aug1.png`` / ``<step>_aug2.png`` in ``path``."""
        if not hasattr(self, 'aug1'):
            return
        torchvision.utils.save_image(self.aug1, os.path.join(path, "%i_aug1.png" % (step,)))
        torchvision.utils.save_image(self.aug2, os.path.join(path, "%i_aug2.png" % (step,)))


@register_model
def register_pixel_contrastive_learner(opt_net, opt):
    """Build a ``PixelCL`` around the sub-network described by ``opt_net['subnet']``.

    If ``opt_net`` contains ``'subnet_pretrain_path'``, that state dict is
    loaded into the sub-network non-strictly before wrapping.
    ``opt_net['kwargs']`` is forwarded to ``PixelCL``.
    """
    subnet = create_model(opt, opt_net['subnet'])
    kwargs = opt_net['kwargs']
    if 'subnet_pretrain_path' in opt_net:
        # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
        sd = torch.load(opt_net['subnet_pretrain_path'])
        subnet.load_state_dict(sd, strict=False)
    return PixelCL(subnet, **kwargs)