Fix pixpro stochastic sampling bugs

This commit is contained in:
James Betker 2021-01-13 11:34:24 -07:00
parent 19475a072f
commit 8990801a3f

View File

@ -446,15 +446,19 @@ class PixelCL(nn.Module):
image_one_cutout.shape, self.cutout_interpolate_mode)
# If max_latent_dim is specified, stochastically extract latents from the shared areas.
_, _, pp_h, pp_w = proj_pixel_one.shape
b, c, pp_h, pp_w = proj_pixel_one.shape
if self.max_latent_dim and (pp_h * pp_w) > self.max_latent_dim:
prob = torch.full((self.max_latent_dim,), 1 / (self.max_latent_dim))
latents = [proj_pixel_one, proj_pixel_two, target_proj_pixel_one, target_proj_pixel_two]
extracted = []
for l in latents:
l = l.view(-1, -1, pp_h * pp_w)
l = l[prob.multinomial(num_samples=self.max_latent_dim, replacement=False)]
extracted.append(l.reshape(-1, -1, math.sqrt(self.max_latent_dim), math.sqrt(self.max_latent_dim)))
l = l.reshape(b, c, pp_h * pp_w)
l = l[:, :, prob.multinomial(num_samples=self.max_latent_dim, replacement=False)]
# For compatibility with the existing pixpro code, reshape this stochastic sampling back into a 2d "square".
# Note that the actual structure no longer matters going forward. Pixels are only compared to themselves and others without regard
# to structure.
sqdim = int(math.sqrt(self.max_latent_dim))
extracted.append(l.reshape(b, c, sqdim, sqdim))
proj_pixel_one, proj_pixel_two, target_proj_pixel_one, target_proj_pixel_two = extracted
# flatten all the pixel projections