Fix pixpro stochastic sampling bugs
This commit is contained in:
parent
19475a072f
commit
8990801a3f
|
@ -446,15 +446,19 @@ class PixelCL(nn.Module):
|
|||
image_one_cutout.shape, self.cutout_interpolate_mode)
|
||||
|
||||
# If max_latent_dim is specified, stochastically extract latents from the shared areas.
|
||||
_, _, pp_h, pp_w = proj_pixel_one.shape
|
||||
b, c, pp_h, pp_w = proj_pixel_one.shape
|
||||
if self.max_latent_dim and (pp_h * pp_w) > self.max_latent_dim:
|
||||
prob = torch.full((self.max_latent_dim,), 1 / (self.max_latent_dim))
|
||||
latents = [proj_pixel_one, proj_pixel_two, target_proj_pixel_one, target_proj_pixel_two]
|
||||
extracted = []
|
||||
for l in latents:
|
||||
l = l.view(-1, -1, pp_h * pp_w)
|
||||
l = l[prob.multinomial(num_samples=self.max_latent_dim, replacement=False)]
|
||||
extracted.append(l.reshape(-1, -1, math.sqrt(self.max_latent_dim), math.sqrt(self.max_latent_dim)))
|
||||
l = l.reshape(b, c, pp_h * pp_w)
|
||||
l = l[:, :, prob.multinomial(num_samples=self.max_latent_dim, replacement=False)]
|
||||
# For compatibility with the existing pixpro code, reshape this stochastic sampling back into a 2d "square".
|
||||
# Note that the actual structure no longer matters going forwards. Pixels are only compared to themselves and others without regards
|
||||
# to structure.
|
||||
sqdim = int(math.sqrt(self.max_latent_dim))
|
||||
extracted.append(l.reshape(b, c, sqdim, sqdim))
|
||||
proj_pixel_one, proj_pixel_two, target_proj_pixel_one, target_proj_pixel_two = extracted
|
||||
|
||||
# flatten all the pixel projections
|
||||
|
|
Loading…
Reference in New Issue
Block a user