forked from mrq/DL-Art-School
Adjustments to pixpro to allow training against networks with arbitrarily large structural latents
- The pixpro latent now rescales the latent space instead of using a "coordinate vector", which **might** have performance implications.
- The latent against which the pixel loss is computed can now be a small, randomly sampled patch out of the entire latent, allowing further memory/computational discounts. Since the loss computation does not have a receptive field, this should not alter the loss.
- The instance projection size can now be separate from the pixel projection size.
- PixContrast removed entirely.
- ResUnet with full resolution added.
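The random latent patching described above works roughly like the sketch below. This is a hypothetical, self-contained illustration, not code from this commit: the helper name `crop_latents_to_max_dim` does not exist in the repository, and the commit implements the equivalent slicing inline in `PixelCL.forward`, gated by the new `max_latent_dim` argument. The key point is that the same randomly chosen window is cut out of every latent so that the online, target, and debug tensors stay spatially aligned.

```python
import random
import torch

def crop_latents_to_max_dim(latents, max_latent_dim):
    # `latents` is a list of (B, C, H, W) tensors that must stay aligned;
    # the same randomly chosen window is applied to every tensor in the list.
    _, _, h, w = latents[0].shape
    if max_latent_dim and h > max_latent_dim:
        loc = random.randint(0, h - max_latent_dim)
        latents = [t[:, :, loc:loc + max_latent_dim, :] for t in latents]
    if max_latent_dim and w > max_latent_dim:
        loc = random.randint(0, w - max_latent_dim)
        latents = [t[:, :, :, loc:loc + max_latent_dim] for t in latents]
    return latents

# Example: two aligned 64x64 latents reduced to the same random 16x16 patch.
a, b = torch.randn(1, 256, 64, 64), torch.randn(1, 256, 64, 64)
a, b = crop_latents_to_max_dim([a, b], max_latent_dim=16)
assert a.shape[-2:] == (16, 16) and b.shape[-2:] == (16, 16)
```

Because the pixel loss is computed position-by-position with no receptive field across latent locations, restricting it to such a patch only subsamples the positive pairs rather than changing what is being optimized.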
Parent: 34f8c8641f
Commit: d1007ccfe7
@@ -66,6 +66,59 @@ def cutout_and_resize(image, coordinates, output_size = None, mode = 'nearest'):
     cutout_image = image[:, :, y0:y1, x0:x1]
     return F.interpolate(cutout_image, size = output_size, mode = mode)
 
+def scale_coords(coords, scale):
+    output = [[0,0],[0,0]]
+    for j in range(2):
+        for k in range(2):
+            output[j][k] = int(coords[j][k] / scale)
+    return output
+
+def reverse_cutout_and_resize(image, coordinates, scale_reduction, mode = 'nearest'):
+    blank = torch.zeros_like(image)
+    coordinates = scale_coords(coordinates, scale_reduction)
+    (y0, y1), (x0, x1) = coordinates
+    orig_cutout_shape = (y1-y0, x1-x0)
+    if orig_cutout_shape[0] <= 0 or orig_cutout_shape[1] <= 0:
+        return None
+
+    un_resized_img = F.interpolate(image, size=orig_cutout_shape, mode=mode)
+    blank[:,:,y0:y1,x0:x1] = un_resized_img
+    return blank
+
+def compute_shared_coords(coords1, coords2, scale_reduction):
+    (y1_t, y1_b), (x1_l, x1_r) = scale_coords(coords1, scale_reduction)
+    (y2_t, y2_b), (x2_l, x2_r) = scale_coords(coords2, scale_reduction)
+    shared = ((max(y1_t, y2_t), min(y1_b, y2_b)),
+              (max(x1_l, x2_l), min(x1_r, x2_r)))
+    for s in shared:
+        if s == 0:
+            return None
+    return shared
+
+def get_shared_region(proj_pixel_one, proj_pixel_two, cutout_coordinates_one, cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn, img_orig_shape, interp_mode):
+    # Unflip the pixel projections
+    proj_pixel_one = flip_image_one_fn(proj_pixel_one)
+    proj_pixel_two = flip_image_two_fn(proj_pixel_two)
+
+    # Undo the cutout and resize, taking into account the scale reduction applied by the encoder.
+    scale_reduction = proj_pixel_one.shape[-1] / img_orig_shape[-1]
+    proj_pixel_one = reverse_cutout_and_resize(proj_pixel_one, cutout_coordinates_one, scale_reduction,
+                                               mode=interp_mode)
+    proj_pixel_two = reverse_cutout_and_resize(proj_pixel_two, cutout_coordinates_two, scale_reduction,
+                                               mode=interp_mode)
+    if proj_pixel_one is None or proj_pixel_two is None:
+        print("Could not extract projected image region. The selected cutout coordinates were smaller than the aggregate size of one latent block!")
+        return None
+
+    # Compute the shared coordinates for the two cutouts:
+    shared_coords = compute_shared_coords(cutout_coordinates_one, cutout_coordinates_two, scale_reduction)
+    if shared_coords is None:
+        print("No shared coordinates for this iteration (probably should just recompute those coordinates earlier..")
+        return None
+    (yt, yb), (xl, xr) = shared_coords
+
+    return proj_pixel_one[:, :, yt:yb, xl:xr], proj_pixel_two[:, :, yt:yb, xl:xr]
+
 # augmentation utils
 
 class RandomApply(nn.Module):
@@ -172,8 +225,10 @@ class NetWrapper(nn.Module):
         self,
         *,
         net,
-        projection_size,
-        projection_hidden_size,
+        instance_projection_size,
+        instance_projection_hidden_size,
+        pix_projection_size,
+        pix_projection_hidden_size,
         layer_pixel = -2,
         layer_instance = -2
     ):
@@ -185,8 +240,10 @@ class NetWrapper(nn.Module):
         self.pixel_projector = None
         self.instance_projector = None
 
-        self.projection_size = projection_size
-        self.projection_hidden_size = projection_hidden_size
+        self.instance_projection_size = instance_projection_size
+        self.instance_projection_hidden_size = instance_projection_hidden_size
+        self.pix_projection_size = pix_projection_size
+        self.pix_projection_hidden_size = pix_projection_hidden_size
 
         self.hidden_pixel = None
         self.hidden_instance = None
@@ -218,13 +275,13 @@ class NetWrapper(nn.Module):
     @singleton('pixel_projector')
     def _get_pixel_projector(self, hidden):
         _, dim, *_ = hidden.shape
-        projector = ConvMLP(dim, self.projection_size, self.projection_hidden_size)
+        projector = ConvMLP(dim, self.pix_projection_size, self.pix_projection_hidden_size)
         return projector.to(hidden)
 
     @singleton('instance_projector')
     def _get_instance_projector(self, hidden):
         _, dim = hidden.shape
-        projector = MLP(dim, self.projection_size, self.projection_hidden_size)
+        projector = MLP(dim, self.instance_projection_size, self.instance_projection_hidden_size)
         return projector.to(hidden)
 
     def get_representation(self, x):
@@ -252,7 +309,6 @@ class NetWrapper(nn.Module):
         return pixel_projection, instance_projection
-
 # main class
 
 class PixelCL(nn.Module):
     def __init__(
         self,
@@ -260,8 +316,10 @@ class PixelCL(nn.Module):
         image_size,
         hidden_layer_pixel = -2,
         hidden_layer_instance = -2,
-        projection_size = 256,
-        projection_hidden_size = 2048,
+        instance_projection_size = 256,
+        instance_projection_hidden_size = 2048,
+        pix_projection_size = 256,
+        pix_projection_hidden_size = 2048,
         augment_fn = None,
         augment_fn2 = None,
         prob_rand_hflip = 0.25,
@@ -271,10 +329,10 @@ class PixelCL(nn.Module):
         distance_thres = 0.7,
         similarity_temperature = 0.3,
         alpha = 1.,
-        use_pixpro = True,
         cutout_ratio_range = (0.6, 0.8),
         cutout_interpolate_mode = 'nearest',
-        coord_cutout_interpolate_mode = 'bilinear'
+        coord_cutout_interpolate_mode = 'bilinear',
+        max_latent_dim = None # This is in latent space, not image space, so dimensionality reduction of your network must be accounted for.
     ):
         super().__init__()
 
@@ -292,8 +350,10 @@ class PixelCL(nn.Module):
 
         self.online_encoder = NetWrapper(
             net = net,
-            projection_size = projection_size,
-            projection_hidden_size = projection_hidden_size,
+            instance_projection_size = instance_projection_size,
+            instance_projection_hidden_size = instance_projection_hidden_size,
+            pix_projection_size = pix_projection_size,
+            pix_projection_hidden_size = pix_projection_hidden_size,
             layer_pixel = hidden_layer_pixel,
             layer_instance = hidden_layer_instance
         )
@@ -304,22 +364,20 @@ class PixelCL(nn.Module):
         self.distance_thres = distance_thres
         self.similarity_temperature = similarity_temperature
         self.alpha = alpha
+        self.max_latent_dim = max_latent_dim
 
-        self.use_pixpro = use_pixpro
-
-        if use_pixpro:
-            self.propagate_pixels = PPM(
-                chan = projection_size,
-                num_layers = ppm_num_layers,
-                gamma = ppm_gamma
-            )
+        self.propagate_pixels = PPM(
+            chan = pix_projection_size,
+            num_layers = ppm_num_layers,
+            gamma = ppm_gamma
+        )
 
         self.cutout_ratio_range = cutout_ratio_range
         self.cutout_interpolate_mode = cutout_interpolate_mode
         self.coord_cutout_interpolate_mode = coord_cutout_interpolate_mode
 
         # instance level predictor
-        self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)
+        self.online_predictor = MLP(instance_projection_size, instance_projection_size, instance_projection_hidden_size)
 
         # get device of network and make wrapper same device
         device = get_module_device(net)
@@ -368,106 +426,74 @@ class PixelCL(nn.Module):
         proj_pixel_one, proj_instance_one = self.online_encoder(image_one_cutout)
         proj_pixel_two, proj_instance_two = self.online_encoder(image_two_cutout)
 
-        image_h, image_w = shape[2:]
-
-        proj_image_shape = proj_pixel_one.shape[2:]
-        proj_image_h, proj_image_w = proj_image_shape
-
-        coordinates = torch.meshgrid(
-            torch.arange(image_h, device = device),
-            torch.arange(image_w, device = device)
-        )
-
-        coordinates = torch.stack(coordinates).unsqueeze(0).float()
-        coordinates /= math.sqrt(image_h ** 2 + image_w ** 2)
-        coordinates[:, 0] *= proj_image_h
-        coordinates[:, 1] *= proj_image_w
-
-        proj_coors_one = cutout_and_resize(coordinates, cutout_coordinates_one, output_size = proj_image_shape, mode = self.coord_cutout_interpolate_mode)
-        proj_coors_two = cutout_and_resize(coordinates, cutout_coordinates_two, output_size = proj_image_shape, mode = self.coord_cutout_interpolate_mode)
-
-        proj_coors_one = flip_image_one_fn(proj_coors_one)
-        proj_coors_two = flip_image_two_fn(proj_coors_two)
-
-        proj_coors_one, proj_coors_two = map(lambda t: rearrange(t, 'b c h w -> (b h w) c'), (proj_coors_one, proj_coors_two))
-        pdist = nn.PairwiseDistance(p = 2)
-
-        num_pixels = proj_coors_one.shape[0]
-
-        proj_coors_one_expanded = proj_coors_one[:, None].expand(num_pixels, num_pixels, -1).reshape(num_pixels * num_pixels, 2)
-        proj_coors_two_expanded = proj_coors_two[None, :].expand(num_pixels, num_pixels, -1).reshape(num_pixels * num_pixels, 2)
-
-        distance_matrix = pdist(proj_coors_one_expanded, proj_coors_two_expanded)
-        distance_matrix = distance_matrix.reshape(num_pixels, num_pixels)
-
-        positive_mask_one_two = distance_matrix < self.distance_thres
-        positive_mask_two_one = positive_mask_one_two.t()
+        proj_pixel_one, proj_pixel_two = get_shared_region(proj_pixel_one, proj_pixel_two, cutout_coordinates_one,
+                                                           cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn,
+                                                           image_one_cutout.shape, self.cutout_interpolate_mode)
+        sim_region_img_one, sim_region_img_two = get_shared_region(image_one_cutout, image_two_cutout, cutout_coordinates_one,
+                                                                   cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn,
+                                                                   image_one_cutout.shape, self.cutout_interpolate_mode)
+        if proj_pixel_one is None or proj_pixel_two is None:
+            positive_pixel_pairs = 0
+        else:
+            positive_pixel_pairs = proj_pixel_one.shape[-1] * proj_pixel_one.shape[-2]
 
         with torch.no_grad():
             target_encoder = self._get_target_encoder()
             target_proj_pixel_one, target_proj_instance_one = target_encoder(image_one_cutout)
             target_proj_pixel_two, target_proj_instance_two = target_encoder(image_two_cutout)
+            target_proj_pixel_one, target_proj_pixel_two = get_shared_region(target_proj_pixel_one, target_proj_pixel_two, cutout_coordinates_one,
+                                                                             cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn,
+                                                                             image_one_cutout.shape, self.cutout_interpolate_mode)
+
+        # Apply max_latent_dim if needed.
+        _, _, pp_h, pp_w = proj_pixel_one.shape
+        if self.max_latent_dim and pp_h > self.max_latent_dim:
+            margin = pp_h - self.max_latent_dim
+            loc = random.randint(0, margin)
+            loce = loc + self.max_latent_dim
+            proj_pixel_one, proj_pixel_two = proj_pixel_one[:, :, loc:loce, :], proj_pixel_two[:, :, loc:loce, :]
+            target_proj_pixel_one, target_proj_pixel_two = target_proj_pixel_one[:, :, loc:loce, :], target_proj_pixel_two[:, :, loc:loce, :]
+            sim_region_img_one, sim_region_img_two = sim_region_img_one[:, :, loc:loce, :], sim_region_img_two[:, :, loc:loce, :]
+        if self.max_latent_dim and pp_w > self.max_latent_dim:
+            margin = pp_w - self.max_latent_dim
+            loc = random.randint(0, margin)
+            loce = loc + self.max_latent_dim
+            proj_pixel_one, proj_pixel_two = proj_pixel_one[:, :, :, loc:loce], proj_pixel_two[:, :, :, loc:loce]
+            target_proj_pixel_one, target_proj_pixel_two = target_proj_pixel_one[:, :, :, loc:loce], target_proj_pixel_two[:, :, :, loc:loce]
+            sim_region_img_one, sim_region_img_two = sim_region_img_one[:, :, :, loc:loce], sim_region_img_two[:, :, :, loc:loce]
+        # Stash these away for debugging purposes.
+        self.sim_region_img_one = sim_region_img_one.detach().clone()
+        self.sim_region_img_two = sim_region_img_two.detach().clone()
 
         # flatten all the pixel projections
 
         flatten = lambda t: rearrange(t, 'b c h w -> b c (h w)')
 
         target_proj_pixel_one, target_proj_pixel_two = list(map(flatten, (target_proj_pixel_one, target_proj_pixel_two)))
 
-        # get total number of positive pixel pairs
-
-        positive_pixel_pairs = positive_mask_one_two.sum()
-
         # get instance level loss
 
         pred_instance_one = self.online_predictor(proj_instance_one)
         pred_instance_two = self.online_predictor(proj_instance_two)
 
         loss_instance_one = loss_fn(pred_instance_one, target_proj_instance_two.detach())
         loss_instance_two = loss_fn(pred_instance_two, target_proj_instance_one.detach())
 
         instance_loss = (loss_instance_one + loss_instance_two).mean()
 
         if positive_pixel_pairs == 0:
             return instance_loss, 0
 
-        if not self.use_pixpro:
-            # calculate pix contrast loss
-
-            proj_pixel_one, proj_pixel_two = list(map(flatten, (proj_pixel_one, proj_pixel_two)))
-
-            similarity_one_two = F.cosine_similarity(proj_pixel_one[..., :, None], target_proj_pixel_two[..., None, :], dim = 1) / self.similarity_temperature
-            similarity_two_one = F.cosine_similarity(proj_pixel_two[..., :, None], target_proj_pixel_one[..., None, :], dim = 1) / self.similarity_temperature
-
-            loss_pix_one_two = -torch.log(
-                similarity_one_two.masked_select(positive_mask_one_two[None, ...]).exp().sum() /
-                similarity_one_two.exp().sum()
-            )
-
-            loss_pix_two_one = -torch.log(
-                similarity_two_one.masked_select(positive_mask_two_one[None, ...]).exp().sum() /
-                similarity_two_one.exp().sum()
-            )
-
-            pix_loss = (loss_pix_one_two + loss_pix_two_one) / 2
-        else:
-            # calculate pix pro loss
-
-            propagated_pixels_one = self.propagate_pixels(proj_pixel_one)
-            propagated_pixels_two = self.propagate_pixels(proj_pixel_two)
-
-            propagated_pixels_one, propagated_pixels_two = list(map(flatten, (propagated_pixels_one, propagated_pixels_two)))
-
-            propagated_similarity_one_two = F.cosine_similarity(propagated_pixels_one[..., :, None], target_proj_pixel_two[..., None, :], dim = 1)
-            propagated_similarity_two_one = F.cosine_similarity(propagated_pixels_two[..., :, None], target_proj_pixel_one[..., None, :], dim = 1)
-
-            loss_pixpro_one_two = - propagated_similarity_one_two.masked_select(positive_mask_one_two[None, ...]).mean()
-            loss_pixpro_two_one = - propagated_similarity_two_one.masked_select(positive_mask_two_one[None, ...]).mean()
-
-            pix_loss = (loss_pixpro_one_two + loss_pixpro_two_one) / 2
+        # calculate pix pro loss
+        propagated_pixels_one = self.propagate_pixels(proj_pixel_one)
+        propagated_pixels_two = self.propagate_pixels(proj_pixel_two)
+
+        propagated_pixels_one, propagated_pixels_two = list(map(flatten, (propagated_pixels_one, propagated_pixels_two)))
+
+        propagated_similarity_one_two = F.cosine_similarity(propagated_pixels_one[..., :, None], target_proj_pixel_two[..., None, :], dim = 1)
+        propagated_similarity_two_one = F.cosine_similarity(propagated_pixels_two[..., :, None], target_proj_pixel_one[..., None, :], dim = 1)
+
+        loss_pixpro_one_two = - propagated_similarity_one_two.mean()
+        loss_pixpro_two_one = - propagated_similarity_two_one.mean()
+
+        pix_loss = (loss_pixpro_one_two + loss_pixpro_two_one) / 2
 
         # total loss
 
         loss = pix_loss * self.alpha + instance_loss
         return loss, positive_pixel_pairs
@@ -477,6 +503,8 @@ class PixelCL(nn.Module):
             return
         torchvision.utils.save_image(self.aug1, os.path.join(path, "%i_aug1.png" % (step,)))
         torchvision.utils.save_image(self.aug2, os.path.join(path, "%i_aug2.png" % (step,)))
+        torchvision.utils.save_image(self.sim_region_img_one, os.path.join(path, "%i_sim1.png" % (step,)))
+        torchvision.utils.save_image(self.sim_region_img_two, os.path.join(path, "%i_sim2.png" % (step,)))
 
 
 @register_model
@@ -1,79 +1,16 @@
-# Resnet implementation that adds a u-net style up-conversion component to output values at a
-# specified pixel density.
-#
-# The downsampling part of the network is compatible with the built-in torch resnet for use in
-# transfer learning.
-#
-# Only resnet50 currently supported.
-
 import torch
 import torch.nn as nn
 from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1, conv3x3
 from torchvision.models.utils import load_state_dict_from_url
 import torchvision
 
+from models.arch_util import ConvBnRelu
+from models.pixel_level_contrastive_learning.resnet_unet import ReverseBottleneck
 from trainer.networks import register_model
 from utils.util import checkpoint, opt_get
 
 
-class ReverseBottleneck(nn.Module):
-
-    def __init__(self, inplanes, planes, groups=1, passthrough=False,
-                 base_width=64, dilation=1, norm_layer=None):
-        super().__init__()
-        if norm_layer is None:
-            norm_layer = nn.BatchNorm2d
-        width = int(planes * (base_width / 64.)) * groups
-        self.passthrough = passthrough
-        if passthrough:
-            self.integrate = conv1x1(inplanes*2, inplanes)
-            self.bn_integrate = norm_layer(inplanes)
-        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
-        self.conv1 = conv1x1(inplanes, width)
-        self.bn1 = norm_layer(width)
-        self.conv2 = conv3x3(width, width, groups, dilation)
-        self.bn2 = norm_layer(width)
-        self.residual_upsample = nn.Sequential(
-            nn.Upsample(scale_factor=2, mode='nearest'),
-            conv1x1(width, width),
-            norm_layer(width),
-        )
-        self.conv3 = conv1x1(width, planes)
-        self.bn3 = norm_layer(planes)
-        self.relu = nn.ReLU(inplace=True)
-        self.upsample = nn.Sequential(
-            nn.Upsample(scale_factor=2, mode='nearest'),
-            conv1x1(inplanes, planes),
-            norm_layer(planes),
-        )
-
-    def forward(self, x, passthrough=None):
-        if self.passthrough:
-            x = self.bn_integrate(self.integrate(torch.cat([x, passthrough], dim=1)))
-
-        out = self.conv1(x)
-        out = self.bn1(out)
-        out = self.relu(out)
-
-        out = self.conv2(out)
-        out = self.bn2(out)
-        out = self.relu(out)
-
-        out = self.residual_upsample(out)
-
-        out = self.conv3(out)
-        out = self.bn3(out)
-
-        identity = self.upsample(x)
-
-        out = out + identity
-        out = self.relu(out)
-
-        return out
-
-
-class UResNet50(torchvision.models.resnet.ResNet):
+class UResNet50_2(torchvision.models.resnet.ResNet):
 
     def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                  groups=1, width_per_group=64, replace_stride_with_dilation=None,
@@ -82,6 +19,7 @@ class UResNet50(torchvision.models.resnet.ResNet):
                          replace_stride_with_dilation, norm_layer)
         if norm_layer is None:
             norm_layer = nn.BatchNorm2d
+        self.level_conv = ConvBnRelu(3, 64)
         '''
         # For reference:
         self.layer1 = self._make_layer(block, 64, layers[0])
@@ -95,29 +33,24 @@ class UResNet50(torchvision.models.resnet.ResNet):
         uplayers = []
         inplanes = 2048
         first = True
-        for i in range(2):
-            uplayers.append(ReverseBottleneck(inplanes, inplanes // 2, norm_layer=norm_layer, passthrough=not first))
-            inplanes = inplanes // 2
+        div = [2,2,2,4,1]
+        for i in range(5):
+            uplayers.append(ReverseBottleneck(inplanes, inplanes // div[i], norm_layer=norm_layer, passthrough=not first))
+            inplanes = inplanes // div[i]
             first = False
         self.uplayers = nn.ModuleList(uplayers)
-        self.tail = nn.Sequential(conv1x1(1024, 512),
-                                  norm_layer(512),
+        self.tail = nn.Sequential(conv3x3(128, 64),
+                                  norm_layer(64),
                                   nn.ReLU(),
-                                  conv3x3(512, 512),
-                                  norm_layer(512),
-                                  nn.ReLU(),
-                                  conv1x1(512, out_dim))
+                                  conv1x1(64, out_dim))
 
         del self.fc # Not used in this implementation and just consumes a ton of GPU memory.
 
 
     def _forward_impl(self, x):
-        # Should be the exact same implementation of torchvision.models.resnet.ResNet.forward_impl,
-        # except using checkpoints on the body conv layers.
-        x = self.conv1(x)
-        x = self.bn1(x)
-        x = self.relu(x)
-        x = self.maxpool(x)
+        level = self.level_conv(x)
+        x0 = self.relu(self.bn1(self.conv1(x)))
+        x = self.maxpool(x0)
 
         x1 = checkpoint(self.layer1, x)
         x2 = checkpoint(self.layer2, x1)
@@ -127,18 +60,19 @@ class UResNet50(torchvision.models.resnet.ResNet):
 
         x = checkpoint(self.uplayers[0], x4)
         x = checkpoint(self.uplayers[1], x, x3)
-        #x = checkpoint(self.uplayers[2], x, x2)
-        #x = checkpoint(self.uplayers[3], x, x1)
+        x = checkpoint(self.uplayers[2], x, x2)
+        x = checkpoint(self.uplayers[3], x, x1)
+        x = checkpoint(self.uplayers[4], x, x0)
 
-        return checkpoint(self.tail, torch.cat([x, x2], dim=1))
+        return checkpoint(self.tail, torch.cat([x, level], dim=1))
 
     def forward(self, x):
         return self._forward_impl(x)
 
 
 @register_model
-def register_u_resnet50(opt_net, opt):
-    model = UResNet50(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim'])
+def register_u_resnet50_2(opt_net, opt):
+    model = UResNet50_2(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim'])
     if opt_get(opt_net, ['use_pretrained_base'], False):
         state_dict = load_state_dict_from_url('https://download.pytorch.org/models/resnet50-19c8e357.pth', progress=True)
         model.load_state_dict(state_dict, strict=False)
@@ -146,7 +80,8 @@ def register_u_resnet50(opt_net, opt):
 
 
 if __name__ == '__main__':
-    model = UResNet50(Bottleneck, [3,4,6,3])
+    model = UResNet50_2(Bottleneck, [3,4,6,3])
     samp = torch.rand(1,3,224,224)
-    model(samp)
+    y = model(samp)
+    print(y.shape)
     # For pixpro: attach to "tail.3"
@@ -14,16 +14,16 @@ def main():
     split_img = False
     opt = {}
    opt['n_thread'] = 7
-    opt['compression_level'] = 95 # JPEG compression quality rating.
+    opt['compression_level'] = 90 # JPEG compression quality rating.
     # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
     # compression time. If read raw images during training, use 0 for faster IO speed.
 
     opt['dest'] = 'file'
     opt['input_folder'] = ['F:\\4k6k\\datasets\\ns_images\\imagesets\\pn_coven\\working']
-    opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\pn_coven\\cropped'
-    opt['imgsize'] = 1024
-    opt['bottom_crop'] = .1
-    opt['keep_folder'] = True
+    opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\256_unsupervised'
+    opt['imgsize'] = 256
+    opt['bottom_crop'] = 0.1
+    opt['keep_folder'] = False
 
     save_folder = opt['save_folder']
     if not osp.exists(save_folder):
@@ -58,7 +58,7 @@ class TiledDataset(data.Dataset):
 
         # Perform explicit crops first. These are generally used to get rid of watermarks so we dont even want to
         # consider these areas of the image.
-        if 'bottom_crop' in self.opt.keys():
+        if 'bottom_crop' in self.opt.keys() and self.opt['bottom_crop'] > 0:
             bc = self.opt['bottom_crop']
             if bc > 0 and bc < 1:
                 bc = int(bc * img.shape[0])
@@ -83,9 +83,7 @@ class TiledDataset(data.Dataset):
         pts = os.path.split(pts[0])
         output_folder = osp.join(self.opt['save_folder'], pts[-1])
         os.makedirs(output_folder, exist_ok=True)
-        if not basename.endswith(".jpg"):
-            basename = basename + ".jpg"
-        cv2.imwrite(osp.join(output_folder, basename + ".jpg"), img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']])
+        cv2.imwrite(osp.join(output_folder, basename), img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']])
         return None
 
     def __len__(self):
@@ -295,7 +295,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imagenet_pixpro_resnet.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imagenet_resnet50.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()