forked from mrq/DL-Art-School
Pixpro: Rather than using a latent square for pixpro, use an entirely stochastic sampling of the pixels
This commit is contained in:
parent
d1007ccfe7
commit
19475a072f
|
@ -332,7 +332,7 @@ class PixelCL(nn.Module):
|
|||
cutout_ratio_range = (0.6, 0.8),
|
||||
cutout_interpolate_mode = 'nearest',
|
||||
coord_cutout_interpolate_mode = 'bilinear',
|
||||
max_latent_dim = None # This is in latent space, not image space, so dimensionality reduction of your network must be accounted for.
|
||||
max_latent_dim = None # When set, this is the number of stochastically extracted pixels from the latent to extract. Must have an integer square root.
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -364,6 +364,9 @@ class PixelCL(nn.Module):
|
|||
self.distance_thres = distance_thres
|
||||
self.similarity_temperature = similarity_temperature
|
||||
self.alpha = alpha
|
||||
|
||||
# This requirement is due to the way that these are processed, not a hard requirement.
|
||||
assert math.sqrt(max_latent_dim) == int(math.sqrt(max_latent_dim))
|
||||
self.max_latent_dim = max_latent_dim
|
||||
|
||||
self.propagate_pixels = PPM(
|
||||
|
@ -429,9 +432,6 @@ class PixelCL(nn.Module):
|
|||
proj_pixel_one, proj_pixel_two = get_shared_region(proj_pixel_one, proj_pixel_two, cutout_coordinates_one,
|
||||
cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn,
|
||||
image_one_cutout.shape, self.cutout_interpolate_mode)
|
||||
sim_region_img_one, sim_region_img_two = get_shared_region(image_one_cutout, image_two_cutout, cutout_coordinates_one,
|
||||
cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn,
|
||||
image_one_cutout.shape, self.cutout_interpolate_mode)
|
||||
if proj_pixel_one is None or proj_pixel_two is None:
|
||||
positive_pixel_pairs = 0
|
||||
else:
|
||||
|
@ -445,25 +445,17 @@ class PixelCL(nn.Module):
|
|||
cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn,
|
||||
image_one_cutout.shape, self.cutout_interpolate_mode)
|
||||
|
||||
# Apply max_latent_dim if needed.
|
||||
# If max_latent_dim is specified, stochastically extract latents from the shared areas.
|
||||
_, _, pp_h, pp_w = proj_pixel_one.shape
|
||||
if self.max_latent_dim and pp_h > self.max_latent_dim:
|
||||
margin = pp_h - self.max_latent_dim
|
||||
loc = random.randint(0, margin)
|
||||
loce = loc + self.max_latent_dim
|
||||
proj_pixel_one, proj_pixel_two = proj_pixel_one[:, :, loc:loce, :], proj_pixel_two[:, :, loc:loce, :]
|
||||
target_proj_pixel_one, target_proj_pixel_two = target_proj_pixel_one[:, :, loc:loce, :], target_proj_pixel_two[:, :, loc:loce, :]
|
||||
sim_region_img_one, sim_region_img_two = sim_region_img_one[:, :, loc:loce, :], sim_region_img_two[:, :, loc:loce, :]
|
||||
if self.max_latent_dim and pp_w > self.max_latent_dim:
|
||||
margin = pp_w - self.max_latent_dim
|
||||
loc = random.randint(0, margin)
|
||||
loce = loc + self.max_latent_dim
|
||||
proj_pixel_one, proj_pixel_two = proj_pixel_one[:, :, :, loc:loce], proj_pixel_two[:, :, :, loc:loce]
|
||||
target_proj_pixel_one, target_proj_pixel_two = target_proj_pixel_one[:, :, :, loc:loce], target_proj_pixel_two[:, :, :, loc:loce]
|
||||
sim_region_img_one, sim_region_img_two = sim_region_img_one[:, :, :, loc:loce], sim_region_img_two[:, :, :, loc:loce]
|
||||
# Stash these away for debugging purposes.
|
||||
self.sim_region_img_one = sim_region_img_one.detach().clone()
|
||||
self.sim_region_img_two = sim_region_img_two.detach().clone()
|
||||
if self.max_latent_dim and (pp_h * pp_w) > self.max_latent_dim:
|
||||
prob = torch.full((self.max_latent_dim,), 1 / (self.max_latent_dim))
|
||||
latents = [proj_pixel_one, proj_pixel_two, target_proj_pixel_one, target_proj_pixel_two]
|
||||
extracted = []
|
||||
for l in latents:
|
||||
l = l.view(-1, -1, pp_h * pp_w)
|
||||
l = l[prob.multinomial(num_samples=self.max_latent_dim, replacement=False)]
|
||||
extracted.append(l.reshape(-1, -1, math.sqrt(self.max_latent_dim), math.sqrt(self.max_latent_dim)))
|
||||
proj_pixel_one, proj_pixel_two, target_proj_pixel_one, target_proj_pixel_two = extracted
|
||||
|
||||
# flatten all the pixel projections
|
||||
flatten = lambda t: rearrange(t, 'b c h w -> b c (h w)')
|
||||
|
@ -503,8 +495,6 @@ class PixelCL(nn.Module):
|
|||
return
|
||||
torchvision.utils.save_image(self.aug1, os.path.join(path, "%i_aug1.png" % (step,)))
|
||||
torchvision.utils.save_image(self.aug2, os.path.join(path, "%i_aug2.png" % (step,)))
|
||||
torchvision.utils.save_image(self.sim_region_img_one, os.path.join(path, "%i_sim1.png" % (step,)))
|
||||
torchvision.utils.save_image(self.sim_region_img_two, os.path.join(path, "%i_sim2.png" % (step,)))
|
||||
|
||||
|
||||
@register_model
|
||||
|
|
Loading…
Reference in New Issue
Block a user