From 19475a072fc8bccaca411a519ad2674b6a692f0b Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 13 Jan 2021 11:26:51 -0700 Subject: [PATCH] Pixpro: Rather than using a latent square for pixpro, use an entirely stochastic sampling of the pixels --- .../pixpro_lucidrains.py | 38 +++++++------------ 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/codes/models/pixel_level_contrastive_learning/pixpro_lucidrains.py b/codes/models/pixel_level_contrastive_learning/pixpro_lucidrains.py index 1536a6a7..7c4f5199 100644 --- a/codes/models/pixel_level_contrastive_learning/pixpro_lucidrains.py +++ b/codes/models/pixel_level_contrastive_learning/pixpro_lucidrains.py @@ -332,7 +332,7 @@ class PixelCL(nn.Module): cutout_ratio_range = (0.6, 0.8), cutout_interpolate_mode = 'nearest', coord_cutout_interpolate_mode = 'bilinear', - max_latent_dim = None # This is in latent space, not image space, so dimensionality reduction of your network must be accounted for. + max_latent_dim = None # When set, this is the number of stochastically extracted pixels from the latent to extract. Must have an integer square root. ): super().__init__() @@ -364,6 +364,9 @@ class PixelCL(nn.Module): self.distance_thres = distance_thres self.similarity_temperature = similarity_temperature self.alpha = alpha + + # This requirement is due to the way that these are processed, not a hard requirement. + assert math.sqrt(max_latent_dim) == int(math.sqrt(max_latent_dim)) self.max_latent_dim = max_latent_dim self.propagate_pixels = PPM( @@ -429,9 +432,6 @@ class PixelCL(nn.Module): proj_pixel_one, proj_pixel_two = get_shared_region(proj_pixel_one, proj_pixel_two, cutout_coordinates_one, cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn, image_one_cutout.shape, self.cutout_interpolate_mode) - sim_region_img_one, sim_region_img_two = get_shared_region(image_one_cutout, image_two_cutout, cutout_coordinates_one, - cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn, - image_one_cutout.shape, self.cutout_interpolate_mode) if proj_pixel_one is None or proj_pixel_two is None: positive_pixel_pairs = 0 else: @@ -445,25 +445,17 @@ class PixelCL(nn.Module): cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn, image_one_cutout.shape, self.cutout_interpolate_mode) - # Apply max_latent_dim if needed. + # If max_latent_dim is specified, stochastically extract latents from the shared areas. _, _, pp_h, pp_w = proj_pixel_one.shape - if self.max_latent_dim and pp_h > self.max_latent_dim: - margin = pp_h - self.max_latent_dim - loc = random.randint(0, margin) - loce = loc + self.max_latent_dim - proj_pixel_one, proj_pixel_two = proj_pixel_one[:, :, loc:loce, :], proj_pixel_two[:, :, loc:loce, :] - target_proj_pixel_one, target_proj_pixel_two = target_proj_pixel_one[:, :, loc:loce, :], target_proj_pixel_two[:, :, loc:loce, :] - sim_region_img_one, sim_region_img_two = sim_region_img_one[:, :, loc:loce, :], sim_region_img_two[:, :, loc:loce, :] - if self.max_latent_dim and pp_w > self.max_latent_dim: - margin = pp_w - self.max_latent_dim - loc = random.randint(0, margin) - loce = loc + self.max_latent_dim - proj_pixel_one, proj_pixel_two = proj_pixel_one[:, :, :, loc:loce], proj_pixel_two[:, :, :, loc:loce] - target_proj_pixel_one, target_proj_pixel_two = target_proj_pixel_one[:, :, :, loc:loce], target_proj_pixel_two[:, :, :, loc:loce] - sim_region_img_one, sim_region_img_two = sim_region_img_one[:, :, :, loc:loce], sim_region_img_two[:, :, :, loc:loce] - # Stash these away for debugging purposes. - self.sim_region_img_one = sim_region_img_one.detach().clone() - self.sim_region_img_two = sim_region_img_two.detach().clone() + if self.max_latent_dim and (pp_h * pp_w) > self.max_latent_dim: + prob = torch.full((self.max_latent_dim,), 1 / (self.max_latent_dim)) + latents = [proj_pixel_one, proj_pixel_two, target_proj_pixel_one, target_proj_pixel_two] + extracted = [] + for l in latents: + l = l.view(-1, -1, pp_h * pp_w) + l = l[prob.multinomial(num_samples=self.max_latent_dim, replacement=False)] + extracted.append(l.reshape(-1, -1, math.sqrt(self.max_latent_dim), math.sqrt(self.max_latent_dim))) + proj_pixel_one, proj_pixel_two, target_proj_pixel_one, target_proj_pixel_two = extracted # flatten all the pixel projections flatten = lambda t: rearrange(t, 'b c h w -> b c (h w)') @@ -503,8 +495,6 @@ class PixelCL(nn.Module): return torchvision.utils.save_image(self.aug1, os.path.join(path, "%i_aug1.png" % (step,))) torchvision.utils.save_image(self.aug2, os.path.join(path, "%i_aug2.png" % (step,))) - torchvision.utils.save_image(self.sim_region_img_one, os.path.join(path, "%i_sim1.png" % (step,))) - torchvision.utils.save_image(self.sim_region_img_two, os.path.join(path, "%i_sim2.png" % (step,))) @register_model