Pixpro: Rather than using a latent square for pixpro, use an entirely stochastic sampling of the pixels

This commit is contained in:
James Betker 2021-01-13 11:26:51 -07:00
parent d1007ccfe7
commit 19475a072f

View File

@ -332,7 +332,7 @@ class PixelCL(nn.Module):
cutout_ratio_range = (0.6, 0.8),
cutout_interpolate_mode = 'nearest',
coord_cutout_interpolate_mode = 'bilinear',
max_latent_dim = None # This is in latent space, not image space, so dimensionality reduction of your network must be accounted for.
max_latent_dim = None # When set, this is the number of stochastically extracted pixels from the latent to extract. Must have an integer square root.
):
super().__init__()
@ -364,6 +364,9 @@ class PixelCL(nn.Module):
self.distance_thres = distance_thres
self.similarity_temperature = similarity_temperature
self.alpha = alpha
# This requirement is due to the way that these are processed, not a hard requirement.
assert math.sqrt(max_latent_dim) == int(math.sqrt(max_latent_dim))
self.max_latent_dim = max_latent_dim
self.propagate_pixels = PPM(
@ -429,9 +432,6 @@ class PixelCL(nn.Module):
proj_pixel_one, proj_pixel_two = get_shared_region(proj_pixel_one, proj_pixel_two, cutout_coordinates_one,
cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn,
image_one_cutout.shape, self.cutout_interpolate_mode)
sim_region_img_one, sim_region_img_two = get_shared_region(image_one_cutout, image_two_cutout, cutout_coordinates_one,
cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn,
image_one_cutout.shape, self.cutout_interpolate_mode)
if proj_pixel_one is None or proj_pixel_two is None:
positive_pixel_pairs = 0
else:
@ -445,25 +445,17 @@ class PixelCL(nn.Module):
cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn,
image_one_cutout.shape, self.cutout_interpolate_mode)
# Apply max_latent_dim if needed.
# If max_latent_dim is specified, stochastically extract latents from the shared areas.
_, _, pp_h, pp_w = proj_pixel_one.shape
if self.max_latent_dim and pp_h > self.max_latent_dim:
margin = pp_h - self.max_latent_dim
loc = random.randint(0, margin)
loce = loc + self.max_latent_dim
proj_pixel_one, proj_pixel_two = proj_pixel_one[:, :, loc:loce, :], proj_pixel_two[:, :, loc:loce, :]
target_proj_pixel_one, target_proj_pixel_two = target_proj_pixel_one[:, :, loc:loce, :], target_proj_pixel_two[:, :, loc:loce, :]
sim_region_img_one, sim_region_img_two = sim_region_img_one[:, :, loc:loce, :], sim_region_img_two[:, :, loc:loce, :]
if self.max_latent_dim and pp_w > self.max_latent_dim:
margin = pp_w - self.max_latent_dim
loc = random.randint(0, margin)
loce = loc + self.max_latent_dim
proj_pixel_one, proj_pixel_two = proj_pixel_one[:, :, :, loc:loce], proj_pixel_two[:, :, :, loc:loce]
target_proj_pixel_one, target_proj_pixel_two = target_proj_pixel_one[:, :, :, loc:loce], target_proj_pixel_two[:, :, :, loc:loce]
sim_region_img_one, sim_region_img_two = sim_region_img_one[:, :, :, loc:loce], sim_region_img_two[:, :, :, loc:loce]
# Stash these away for debugging purposes.
self.sim_region_img_one = sim_region_img_one.detach().clone()
self.sim_region_img_two = sim_region_img_two.detach().clone()
if self.max_latent_dim and (pp_h * pp_w) > self.max_latent_dim:
prob = torch.full((self.max_latent_dim,), 1 / (self.max_latent_dim))
latents = [proj_pixel_one, proj_pixel_two, target_proj_pixel_one, target_proj_pixel_two]
extracted = []
for l in latents:
l = l.view(-1, -1, pp_h * pp_w)
l = l[prob.multinomial(num_samples=self.max_latent_dim, replacement=False)]
extracted.append(l.reshape(-1, -1, math.sqrt(self.max_latent_dim), math.sqrt(self.max_latent_dim)))
proj_pixel_one, proj_pixel_two, target_proj_pixel_one, target_proj_pixel_two = extracted
# flatten all the pixel projections
flatten = lambda t: rearrange(t, 'b c h w -> b c (h w)')
@ -503,8 +495,6 @@ class PixelCL(nn.Module):
return
torchvision.utils.save_image(self.aug1, os.path.join(path, "%i_aug1.png" % (step,)))
torchvision.utils.save_image(self.aug2, os.path.join(path, "%i_aug2.png" % (step,)))
torchvision.utils.save_image(self.sim_region_img_one, os.path.join(path, "%i_sim1.png" % (step,)))
torchvision.utils.save_image(self.sim_region_img_two, os.path.join(path, "%i_sim2.png" % (step,)))
@register_model