diff --git a/codes/models/pixel_level_contrastive_learning/pixpro_lucidrains.py b/codes/models/pixel_level_contrastive_learning/pixpro_lucidrains.py
index 7c4f5199..a40be66c 100644
--- a/codes/models/pixel_level_contrastive_learning/pixpro_lucidrains.py
+++ b/codes/models/pixel_level_contrastive_learning/pixpro_lucidrains.py
@@ -446,15 +446,19 @@ class PixelCL(nn.Module):
                                               image_one_cutout.shape, self.cutout_interpolate_mode)
 
         # If max_latent_dim is specified, stochastically extract latents from the shared areas.
-        _, _, pp_h, pp_w = proj_pixel_one.shape
+        b, c, pp_h, pp_w = proj_pixel_one.shape
         if self.max_latent_dim and (pp_h * pp_w) > self.max_latent_dim:
             prob = torch.full((self.max_latent_dim,), 1 / (self.max_latent_dim))
             latents = [proj_pixel_one, proj_pixel_two, target_proj_pixel_one, target_proj_pixel_two]
             extracted = []
             for l in latents:
-                l = l.view(-1, -1, pp_h * pp_w)
-                l = l[prob.multinomial(num_samples=self.max_latent_dim, replacement=False)]
-                extracted.append(l.reshape(-1, -1, math.sqrt(self.max_latent_dim), math.sqrt(self.max_latent_dim)))
+                l = l.reshape(b, c, pp_h * pp_w)
+                l = l[:, :, prob.multinomial(num_samples=self.max_latent_dim, replacement=False)]
+                # For compatibility with the existing pixpro code, reshape this stochastic sampling back into a 2d "square".
+                # Note that the actual structure no longer matters going forward. Pixels are only compared to themselves and
+                # others without regard to structure.
+                sqdim = int(math.sqrt(self.max_latent_dim))
+                extracted.append(l.reshape(b, c, sqdim, sqdim))
             proj_pixel_one, proj_pixel_two, target_proj_pixel_one, target_proj_pixel_two = extracted
 
         # flatten all the pixel projections
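
Below is a minimal standalone sketch (not part of the diff) of the new sampling path, tracing one of the four latents through it. The values of b, c, pp_h, pp_w and max_latent_dim are illustrative assumptions; the shapes and the multinomial call mirror what the changed lines above do.

```python
import math
import torch

# Illustrative shapes (assumptions for this sketch, not values from the diff).
b, c, pp_h, pp_w = 2, 64, 16, 16
max_latent_dim = 64  # assumed to be a perfect square, as the final reshape requires

proj_pixel_one = torch.randn(b, c, pp_h, pp_w)

if max_latent_dim and (pp_h * pp_w) > max_latent_dim:
    # Uniform weights over max_latent_dim slots; multinomial without replacement
    # returns a random permutation of the indices 0..max_latent_dim-1, which are
    # then used to index the flattened pixel dimension.
    prob = torch.full((max_latent_dim,), 1 / max_latent_dim)
    idx = prob.multinomial(num_samples=max_latent_dim, replacement=False)

    flat = proj_pixel_one.reshape(b, c, pp_h * pp_w)
    sampled = flat[:, :, idx]  # (b, c, max_latent_dim), same indices for every batch element

    # Reshape back into a 2d "square" purely so downstream pixpro code that expects a
    # 4D map keeps working; the spatial arrangement is meaningless after sampling.
    sqdim = int(math.sqrt(max_latent_dim))
    proj_pixel_one = sampled.reshape(b, c, sqdim, sqdim)

print(proj_pixel_one.shape)  # torch.Size([2, 64, 8, 8])
```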