Pixpro: rather than using a latent square, use an entirely stochastic sampling of the pixels
parent d1007ccfe7
commit 19475a072f
@@ -332,7 +332,7 @@ class PixelCL(nn.Module):
         cutout_ratio_range = (0.6, 0.8),
         cutout_interpolate_mode = 'nearest',
         coord_cutout_interpolate_mode = 'bilinear',
-        max_latent_dim = None # This is in latent space, not image space, so dimensionality reduction of your network must be accounted for.
+        max_latent_dim = None # When set, this is the number of stochastically extracted pixels from the latent to extract. Must have an integer square root.
     ):
         super().__init__()
 
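Per the new comment, max_latent_dim caps how many latent pixels feed the PixPro loss, and it must be a perfect square so the sampled pixels can be folded back into a square feature map (see the hunk at line 445 below). A minimal construction sketch, assuming the constructor otherwise follows the upstream pixel-level-contrastive-learning PixelCL API this class is derived from (net, image_size, hidden_layer_pixel); only max_latent_dim comes from this diff, and the other values are illustrative:

    import torchvision
    from pixel_level_contrastive_learning import PixelCL  # import path assumed; adjust to this repo's layout

    backbone = torchvision.models.resnet50(pretrained=False)

    learner = PixelCL(
        backbone,
        image_size = 256,
        hidden_layer_pixel = 'layer4',  # assumed from the upstream API; yields an 8x8 latent here
        max_latent_dim = 16             # perfect square: at most 16 latent pixels, reshaped to 4x4
    )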
|
@@ -364,6 +364,9 @@ class PixelCL(nn.Module):
         self.distance_thres = distance_thres
         self.similarity_temperature = similarity_temperature
         self.alpha = alpha
+
+        # This requirement is due to the way that these are processed, not a hard requirement.
+        assert math.sqrt(max_latent_dim) == int(math.sqrt(max_latent_dim))
         self.max_latent_dim = max_latent_dim
 
         self.propagate_pixels = PPM(
@@ -429,9 +432,6 @@ class PixelCL(nn.Module):
         proj_pixel_one, proj_pixel_two = get_shared_region(proj_pixel_one, proj_pixel_two, cutout_coordinates_one,
                                                            cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn,
                                                            image_one_cutout.shape, self.cutout_interpolate_mode)
-        sim_region_img_one, sim_region_img_two = get_shared_region(image_one_cutout, image_two_cutout, cutout_coordinates_one,
-                                                           cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn,
-                                                           image_one_cutout.shape, self.cutout_interpolate_mode)
         if proj_pixel_one is None or proj_pixel_two is None:
             positive_pixel_pairs = 0
         else:
@@ -445,25 +445,17 @@ class PixelCL(nn.Module):
                                                                cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn,
                                                                image_one_cutout.shape, self.cutout_interpolate_mode)
 
-            # Apply max_latent_dim if needed.
+            # If max_latent_dim is specified, stochastically extract latents from the shared areas.
             _, _, pp_h, pp_w = proj_pixel_one.shape
-            if self.max_latent_dim and pp_h > self.max_latent_dim:
-                margin = pp_h - self.max_latent_dim
-                loc = random.randint(0, margin)
-                loce = loc + self.max_latent_dim
-                proj_pixel_one, proj_pixel_two = proj_pixel_one[:, :, loc:loce, :], proj_pixel_two[:, :, loc:loce, :]
-                target_proj_pixel_one, target_proj_pixel_two = target_proj_pixel_one[:, :, loc:loce, :], target_proj_pixel_two[:, :, loc:loce, :]
-                sim_region_img_one, sim_region_img_two = sim_region_img_one[:, :, loc:loce, :], sim_region_img_two[:, :, loc:loce, :]
-            if self.max_latent_dim and pp_w > self.max_latent_dim:
-                margin = pp_w - self.max_latent_dim
-                loc = random.randint(0, margin)
-                loce = loc + self.max_latent_dim
-                proj_pixel_one, proj_pixel_two = proj_pixel_one[:, :, :, loc:loce], proj_pixel_two[:, :, :, loc:loce]
-                target_proj_pixel_one, target_proj_pixel_two = target_proj_pixel_one[:, :, :, loc:loce], target_proj_pixel_two[:, :, :, loc:loce]
-                sim_region_img_one, sim_region_img_two = sim_region_img_one[:, :, :, loc:loce], sim_region_img_two[:, :, :, loc:loce]
-            # Stash these away for debugging purposes.
-            self.sim_region_img_one = sim_region_img_one.detach().clone()
-            self.sim_region_img_two = sim_region_img_two.detach().clone()
+            if self.max_latent_dim and (pp_h * pp_w) > self.max_latent_dim:
+                prob = torch.full((self.max_latent_dim,), 1 / (self.max_latent_dim))
+                latents = [proj_pixel_one, proj_pixel_two, target_proj_pixel_one, target_proj_pixel_two]
+                extracted = []
+                for l in latents:
+                    l = l.view(-1, -1, pp_h * pp_w)
+                    l = l[prob.multinomial(num_samples=self.max_latent_dim, replacement=False)]
+                    extracted.append(l.reshape(-1, -1, math.sqrt(self.max_latent_dim), math.sqrt(self.max_latent_dim)))
+                proj_pixel_one, proj_pixel_two, target_proj_pixel_one, target_proj_pixel_two = extracted
 
             # flatten all the pixel projections
             flatten = lambda t: rearrange(t, 'b c h w -> b c (h w)')
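As reconstructed, the new sampling block would likely need a few adjustments to run: Tensor.view accepts only one -1 dimension, l[indices] indexes the batch dimension rather than the flattened pixel dimension, reshape needs integer sizes while math.sqrt returns a float, and the uniform distribution covers only max_latent_dim entries rather than all pp_h * pp_w positions. Each latent also draws its own indices, which would break the pixel correspondence between the two views. Below is a minimal, self-contained sketch of the intended idea (one shared random subset of latent pixels, folded back into a square map) using only standard torch calls; the helper is illustrative, not this repo's code:

    import math
    import torch

    def sample_latent_pixels(latents, max_latent_dim):
        # latents: spatially aligned (b, c, h, w) tensors; max_latent_dim: a perfect square.
        # Draws one set of pixel indices and applies it to every latent, so the sampled
        # pixels still correspond across the two augmented views.
        b, c, h, w = latents[0].shape
        if not max_latent_dim or h * w <= max_latent_dim:
            return latents
        prob = torch.full((h * w,), 1 / (h * w))                 # uniform over all latent pixels
        sampled = prob.multinomial(num_samples=max_latent_dim, replacement=False)
        sq = int(math.sqrt(max_latent_dim))
        extracted = []
        for l in latents:
            flat = l.reshape(b, c, h * w)[:, :, sampled]         # pick the shared pixel subset
            extracted.append(flat.reshape(b, c, sq, sq))         # fold back into a square map
        return extracted

    # Dummy usage with shapes like the projected latents in the diff.
    p1, p2 = torch.randn(4, 64, 12, 12), torch.randn(4, 64, 12, 12)
    p1, p2 = sample_latent_pixels([p1, p2], max_latent_dim=64)
    print(p1.shape)  # torch.Size([4, 64, 8, 8])

Sharing the indices across the four latents is the main difference from the block above; if per-latent randomness were intended, the downstream positive-pair bookkeeping would have to account for it.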
|
@@ -503,8 +495,6 @@ class PixelCL(nn.Module):
             return
         torchvision.utils.save_image(self.aug1, os.path.join(path, "%i_aug1.png" % (step,)))
         torchvision.utils.save_image(self.aug2, os.path.join(path, "%i_aug2.png" % (step,)))
-        torchvision.utils.save_image(self.sim_region_img_one, os.path.join(path, "%i_sim1.png" % (step,)))
-        torchvision.utils.save_image(self.sim_region_img_two, os.path.join(path, "%i_sim2.png" % (step,)))
 
 
 @register_model