Adjustments to pixpro to allow training against networks with arbitrarily large structural latents

- The pixpro loss now rescales coordinates into the latent space instead of using a "coordinate
  vector", which **might** have performance implications.
- The latent against which the pixel loss is computed can now be a small, randomly sampled patch
  of the full latent, yielding further memory and compute savings. Since the loss computation has
  no receptive field, this should not alter the loss (see the sketch below).
- The instance projection size can now be configured separately from the pixel projection size.
- PixContrast has been removed entirely.
- A full-resolution ResUnet variant has been added.
James Betker 2021-01-12 09:17:45 -07:00
parent 34f8c8641f
commit d1007ccfe7
4 changed files with 155 additions and 194 deletions
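
To make the random latent patch sampling concrete: because the pixel loss below is computed pointwise over the latent, cropping the same randomly placed window out of every aligned latent before the loss only reduces how many pixel pairs contribute; it does not change what any individual pair contributes. A minimal, standalone sketch of the idea (the helper name and shapes are illustrative; the actual logic lives in the max_latent_dim handling inside PixelCL.forward() in this commit):

```python
import random
import torch

def sample_latent_patch(latents, max_latent_dim):
    """Crop the same randomly placed window out of every latent in `latents`."""
    _, _, h, w = latents[0].shape
    ph, pw = min(h, max_latent_dim), min(w, max_latent_dim)
    y = random.randint(0, h - ph)
    x = random.randint(0, w - pw)
    return [l[:, :, y:y + ph, x:x + pw] for l in latents]

# Example: three aligned 64x64 latents cropped to 16x16 before the loss is computed.
latents = [torch.randn(1, 256, 64, 64) for _ in range(3)]
cropped = sample_latent_patch(latents, max_latent_dim=16)
print([tuple(c.shape) for c in cropped])  # [(1, 256, 16, 16), (1, 256, 16, 16), (1, 256, 16, 16)]
```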

View File

@ -66,6 +66,59 @@ def cutout_and_resize(image, coordinates, output_size = None, mode = 'nearest'):
    cutout_image = image[:, :, y0:y1, x0:x1]
    return F.interpolate(cutout_image, size = output_size, mode = mode)
+def scale_coords(coords, scale):
+    output = [[0,0],[0,0]]
+    for j in range(2):
+        for k in range(2):
+            output[j][k] = int(coords[j][k] / scale)
+    return output
+
+def reverse_cutout_and_resize(image, coordinates, scale_reduction, mode = 'nearest'):
+    blank = torch.zeros_like(image)
+    coordinates = scale_coords(coordinates, scale_reduction)
+    (y0, y1), (x0, x1) = coordinates
+    orig_cutout_shape = (y1-y0, x1-x0)
+    if orig_cutout_shape[0] <= 0 or orig_cutout_shape[1] <= 0:
+        return None
+    un_resized_img = F.interpolate(image, size=orig_cutout_shape, mode=mode)
+    blank[:,:,y0:y1,x0:x1] = un_resized_img
+    return blank
+def compute_shared_coords(coords1, coords2, scale_reduction):
+    (y1_t, y1_b), (x1_l, x1_r) = scale_coords(coords1, scale_reduction)
+    (y2_t, y2_b), (x2_l, x2_r) = scale_coords(coords2, scale_reduction)
+    shared = ((max(y1_t, y2_t), min(y1_b, y2_b)),
+              (max(x1_l, x2_l), min(x1_r, x2_r)))
+    # If the two cutouts do not overlap along either axis, there is no shared region.
+    for s_min, s_max in shared:
+        if s_max - s_min <= 0:
+            return None
+    return shared
+def get_shared_region(proj_pixel_one, proj_pixel_two, cutout_coordinates_one, cutout_coordinates_two,
+                      flip_image_one_fn, flip_image_two_fn, img_orig_shape, interp_mode):
+    # Unflip the pixel projections
+    proj_pixel_one = flip_image_one_fn(proj_pixel_one)
+    proj_pixel_two = flip_image_two_fn(proj_pixel_two)
+    # Undo the cutout and resize, taking into account the scale reduction applied by the encoder.
+    scale_reduction = proj_pixel_one.shape[-1] / img_orig_shape[-1]
+    proj_pixel_one = reverse_cutout_and_resize(proj_pixel_one, cutout_coordinates_one, scale_reduction,
+                                               mode=interp_mode)
+    proj_pixel_two = reverse_cutout_and_resize(proj_pixel_two, cutout_coordinates_two, scale_reduction,
+                                               mode=interp_mode)
+    if proj_pixel_one is None or proj_pixel_two is None:
+        print("Could not extract projected image region. The selected cutout coordinates were smaller than the aggregate size of one latent block!")
+        return None
+    # Compute the shared coordinates for the two cutouts:
+    shared_coords = compute_shared_coords(cutout_coordinates_one, cutout_coordinates_two, scale_reduction)
+    if shared_coords is None:
+        print("No shared coordinates for this iteration (should probably just recompute those coordinates earlier).")
+        return None
+    (yt, yb), (xl, xr) = shared_coords
+    return proj_pixel_one[:, :, yt:yb, xl:xr], proj_pixel_two[:, :, yt:yb, xl:xr]
# augmentation utils

class RandomApply(nn.Module):
@ -172,8 +225,10 @@ class NetWrapper(nn.Module):
        self,
        *,
        net,
-        projection_size,
-        projection_hidden_size,
+        instance_projection_size,
+        instance_projection_hidden_size,
+        pix_projection_size,
+        pix_projection_hidden_size,
        layer_pixel = -2,
        layer_instance = -2
    ):
@ -185,8 +240,10 @@ class NetWrapper(nn.Module):
        self.pixel_projector = None
        self.instance_projector = None
-        self.projection_size = projection_size
-        self.projection_hidden_size = projection_hidden_size
+        self.instance_projection_size = instance_projection_size
+        self.instance_projection_hidden_size = instance_projection_hidden_size
+        self.pix_projection_size = pix_projection_size
+        self.pix_projection_hidden_size = pix_projection_hidden_size
        self.hidden_pixel = None
        self.hidden_instance = None
@ -218,13 +275,13 @@ class NetWrapper(nn.Module):
    @singleton('pixel_projector')
    def _get_pixel_projector(self, hidden):
        _, dim, *_ = hidden.shape
-        projector = ConvMLP(dim, self.projection_size, self.projection_hidden_size)
+        projector = ConvMLP(dim, self.pix_projection_size, self.pix_projection_hidden_size)
        return projector.to(hidden)

    @singleton('instance_projector')
    def _get_instance_projector(self, hidden):
        _, dim = hidden.shape
-        projector = MLP(dim, self.projection_size, self.projection_hidden_size)
+        projector = MLP(dim, self.instance_projection_size, self.instance_projection_hidden_size)
        return projector.to(hidden)

    def get_representation(self, x):
@ -252,7 +309,6 @@ class NetWrapper(nn.Module):
        return pixel_projection, instance_projection

# main class

class PixelCL(nn.Module):
    def __init__(
        self,
@ -260,8 +316,10 @@ class PixelCL(nn.Module):
        image_size,
        hidden_layer_pixel = -2,
        hidden_layer_instance = -2,
-        projection_size = 256,
-        projection_hidden_size = 2048,
+        instance_projection_size = 256,
+        instance_projection_hidden_size = 2048,
+        pix_projection_size = 256,
+        pix_projection_hidden_size = 2048,
        augment_fn = None,
        augment_fn2 = None,
        prob_rand_hflip = 0.25,
@ -271,10 +329,10 @@ class PixelCL(nn.Module):
        distance_thres = 0.7,
        similarity_temperature = 0.3,
        alpha = 1.,
-        use_pixpro = True,
        cutout_ratio_range = (0.6, 0.8),
        cutout_interpolate_mode = 'nearest',
-        coord_cutout_interpolate_mode = 'bilinear'
+        coord_cutout_interpolate_mode = 'bilinear',
+        max_latent_dim = None  # This is in latent space, not image space, so dimensionality reduction of your network must be accounted for.
    ):
        super().__init__()
@ -292,8 +350,10 @@ class PixelCL(nn.Module):
        self.online_encoder = NetWrapper(
            net = net,
-            projection_size = projection_size,
-            projection_hidden_size = projection_hidden_size,
+            instance_projection_size = instance_projection_size,
+            instance_projection_hidden_size = instance_projection_hidden_size,
+            pix_projection_size = pix_projection_size,
+            pix_projection_hidden_size = pix_projection_hidden_size,
            layer_pixel = hidden_layer_pixel,
            layer_instance = hidden_layer_instance
        )
@ -304,22 +364,20 @@ class PixelCL(nn.Module):
        self.distance_thres = distance_thres
        self.similarity_temperature = similarity_temperature
        self.alpha = alpha
+        self.max_latent_dim = max_latent_dim

-        self.use_pixpro = use_pixpro
-        if use_pixpro:
-            self.propagate_pixels = PPM(
-                chan = projection_size,
-                num_layers = ppm_num_layers,
-                gamma = ppm_gamma
-            )
+        self.propagate_pixels = PPM(
+            chan = pix_projection_size,
+            num_layers = ppm_num_layers,
+            gamma = ppm_gamma
+        )

        self.cutout_ratio_range = cutout_ratio_range
        self.cutout_interpolate_mode = cutout_interpolate_mode
        self.coord_cutout_interpolate_mode = coord_cutout_interpolate_mode

        # instance level predictor
-        self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)
+        self.online_predictor = MLP(instance_projection_size, instance_projection_size, instance_projection_hidden_size)

        # get device of network and make wrapper same device
        device = get_module_device(net)
@ -368,106 +426,74 @@ class PixelCL(nn.Module):
        proj_pixel_one, proj_instance_one = self.online_encoder(image_one_cutout)
        proj_pixel_two, proj_instance_two = self.online_encoder(image_two_cutout)

-        image_h, image_w = shape[2:]
-        proj_image_shape = proj_pixel_one.shape[2:]
-        proj_image_h, proj_image_w = proj_image_shape
-        coordinates = torch.meshgrid(
-            torch.arange(image_h, device = device),
-            torch.arange(image_w, device = device)
-        )
+        proj_pixel_one, proj_pixel_two = get_shared_region(proj_pixel_one, proj_pixel_two, cutout_coordinates_one,
+                                                           cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn,
+                                                           image_one_cutout.shape, self.cutout_interpolate_mode)
+        sim_region_img_one, sim_region_img_two = get_shared_region(image_one_cutout, image_two_cutout, cutout_coordinates_one,
+                                                                   cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn,
+                                                                   image_one_cutout.shape, self.cutout_interpolate_mode)
+        if proj_pixel_one is None or proj_pixel_two is None:
+            positive_pixel_pairs = 0
+        else:
+            positive_pixel_pairs = proj_pixel_one.shape[-1] * proj_pixel_one.shape[-2]
-        coordinates = torch.stack(coordinates).unsqueeze(0).float()
-        coordinates /= math.sqrt(image_h ** 2 + image_w ** 2)
-        coordinates[:, 0] *= proj_image_h
-        coordinates[:, 1] *= proj_image_w
-        proj_coors_one = cutout_and_resize(coordinates, cutout_coordinates_one, output_size = proj_image_shape, mode = self.coord_cutout_interpolate_mode)
-        proj_coors_two = cutout_and_resize(coordinates, cutout_coordinates_two, output_size = proj_image_shape, mode = self.coord_cutout_interpolate_mode)
-        proj_coors_one = flip_image_one_fn(proj_coors_one)
-        proj_coors_two = flip_image_two_fn(proj_coors_two)
-        proj_coors_one, proj_coors_two = map(lambda t: rearrange(t, 'b c h w -> (b h w) c'), (proj_coors_one, proj_coors_two))
-        pdist = nn.PairwiseDistance(p = 2)
-        num_pixels = proj_coors_one.shape[0]
-        proj_coors_one_expanded = proj_coors_one[:, None].expand(num_pixels, num_pixels, -1).reshape(num_pixels * num_pixels, 2)
-        proj_coors_two_expanded = proj_coors_two[None, :].expand(num_pixels, num_pixels, -1).reshape(num_pixels * num_pixels, 2)
-        distance_matrix = pdist(proj_coors_one_expanded, proj_coors_two_expanded)
-        distance_matrix = distance_matrix.reshape(num_pixels, num_pixels)
-        positive_mask_one_two = distance_matrix < self.distance_thres
-        positive_mask_two_one = positive_mask_one_two.t()
        with torch.no_grad():
            target_encoder = self._get_target_encoder()
            target_proj_pixel_one, target_proj_instance_one = target_encoder(image_one_cutout)
            target_proj_pixel_two, target_proj_instance_two = target_encoder(image_two_cutout)
+            target_proj_pixel_one, target_proj_pixel_two = get_shared_region(target_proj_pixel_one, target_proj_pixel_two, cutout_coordinates_one,
+                                                                             cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn,
+                                                                             image_one_cutout.shape, self.cutout_interpolate_mode)

+        # Apply max_latent_dim if needed.
+        _, _, pp_h, pp_w = proj_pixel_one.shape
+        if self.max_latent_dim and pp_h > self.max_latent_dim:
+            margin = pp_h - self.max_latent_dim
+            loc = random.randint(0, margin)
+            loce = loc + self.max_latent_dim
+            proj_pixel_one, proj_pixel_two = proj_pixel_one[:, :, loc:loce, :], proj_pixel_two[:, :, loc:loce, :]
+            target_proj_pixel_one, target_proj_pixel_two = target_proj_pixel_one[:, :, loc:loce, :], target_proj_pixel_two[:, :, loc:loce, :]
+            sim_region_img_one, sim_region_img_two = sim_region_img_one[:, :, loc:loce, :], sim_region_img_two[:, :, loc:loce, :]
+        if self.max_latent_dim and pp_w > self.max_latent_dim:
+            margin = pp_w - self.max_latent_dim
+            loc = random.randint(0, margin)
+            loce = loc + self.max_latent_dim
+            proj_pixel_one, proj_pixel_two = proj_pixel_one[:, :, :, loc:loce], proj_pixel_two[:, :, :, loc:loce]
+            target_proj_pixel_one, target_proj_pixel_two = target_proj_pixel_one[:, :, :, loc:loce], target_proj_pixel_two[:, :, :, loc:loce]
+            sim_region_img_one, sim_region_img_two = sim_region_img_one[:, :, :, loc:loce], sim_region_img_two[:, :, :, loc:loce]

+        # Stash these away for debugging purposes.
+        self.sim_region_img_one = sim_region_img_one.detach().clone()
+        self.sim_region_img_two = sim_region_img_two.detach().clone()
        # flatten all the pixel projections
        flatten = lambda t: rearrange(t, 'b c h w -> b c (h w)')
        target_proj_pixel_one, target_proj_pixel_two = list(map(flatten, (target_proj_pixel_one, target_proj_pixel_two)))

-        # get total number of positive pixel pairs
-        positive_pixel_pairs = positive_mask_one_two.sum()
        # get instance level loss
        pred_instance_one = self.online_predictor(proj_instance_one)
        pred_instance_two = self.online_predictor(proj_instance_two)
        loss_instance_one = loss_fn(pred_instance_one, target_proj_instance_two.detach())
        loss_instance_two = loss_fn(pred_instance_two, target_proj_instance_one.detach())
        instance_loss = (loss_instance_one + loss_instance_two).mean()

        if positive_pixel_pairs == 0:
            return instance_loss, 0
-        if not self.use_pixpro:
-            # calculate pix contrast loss
-            proj_pixel_one, proj_pixel_two = list(map(flatten, (proj_pixel_one, proj_pixel_two)))
-            similarity_one_two = F.cosine_similarity(proj_pixel_one[..., :, None], target_proj_pixel_two[..., None, :], dim = 1) / self.similarity_temperature
-            similarity_two_one = F.cosine_similarity(proj_pixel_two[..., :, None], target_proj_pixel_one[..., None, :], dim = 1) / self.similarity_temperature
-            loss_pix_one_two = -torch.log(
-                similarity_one_two.masked_select(positive_mask_one_two[None, ...]).exp().sum() /
-                similarity_one_two.exp().sum()
-            )
-            loss_pix_two_one = -torch.log(
-                similarity_two_one.masked_select(positive_mask_two_one[None, ...]).exp().sum() /
-                similarity_two_one.exp().sum()
-            )
-            pix_loss = (loss_pix_one_two + loss_pix_two_one) / 2
-        else:
-            # calculate pix pro loss
-            propagated_pixels_one = self.propagate_pixels(proj_pixel_one)
-            propagated_pixels_two = self.propagate_pixels(proj_pixel_two)
-            propagated_pixels_one, propagated_pixels_two = list(map(flatten, (propagated_pixels_one, propagated_pixels_two)))
-            propagated_similarity_one_two = F.cosine_similarity(propagated_pixels_one[..., :, None], target_proj_pixel_two[..., None, :], dim = 1)
-            propagated_similarity_two_one = F.cosine_similarity(propagated_pixels_two[..., :, None], target_proj_pixel_one[..., None, :], dim = 1)
-            loss_pixpro_one_two = - propagated_similarity_one_two.masked_select(positive_mask_one_two[None, ...]).mean()
-            loss_pixpro_two_one = - propagated_similarity_two_one.masked_select(positive_mask_two_one[None, ...]).mean()
-            pix_loss = (loss_pixpro_one_two + loss_pixpro_two_one) / 2
+        # calculate pix pro loss
+        propagated_pixels_one = self.propagate_pixels(proj_pixel_one)
+        propagated_pixels_two = self.propagate_pixels(proj_pixel_two)
+        propagated_pixels_one, propagated_pixels_two = list(map(flatten, (propagated_pixels_one, propagated_pixels_two)))
+        propagated_similarity_one_two = F.cosine_similarity(propagated_pixels_one[..., :, None], target_proj_pixel_two[..., None, :], dim = 1)
+        propagated_similarity_two_one = F.cosine_similarity(propagated_pixels_two[..., :, None], target_proj_pixel_one[..., None, :], dim = 1)
+        loss_pixpro_one_two = - propagated_similarity_one_two.mean()
+        loss_pixpro_two_one = - propagated_similarity_two_one.mean()
+        pix_loss = (loss_pixpro_one_two + loss_pixpro_two_one) / 2
        # total loss
        loss = pix_loss * self.alpha + instance_loss
        return loss, positive_pixel_pairs
@ -477,6 +503,8 @@ class PixelCL(nn.Module):
            return
        torchvision.utils.save_image(self.aug1, os.path.join(path, "%i_aug1.png" % (step,)))
        torchvision.utils.save_image(self.aug2, os.path.join(path, "%i_aug2.png" % (step,)))
+        torchvision.utils.save_image(self.sim_region_img_one, os.path.join(path, "%i_sim1.png" % (step,)))
+        torchvision.utils.save_image(self.sim_region_img_two, os.path.join(path, "%i_sim2.png" % (step,)))
@register_model

View File

@ -1,79 +1,16 @@
-# Resnet implementation that adds a u-net style up-conversion component to output values at a
-# specified pixel density.
-#
-# The downsampling part of the network is compatible with the built-in torch resnet for use in
-# transfer learning.
-#
-# Only resnet50 currently supported.
import torch
import torch.nn as nn
from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1, conv3x3
from torchvision.models.utils import load_state_dict_from_url
import torchvision

+from models.arch_util import ConvBnRelu
+from models.pixel_level_contrastive_learning.resnet_unet import ReverseBottleneck
from trainer.networks import register_model
from utils.util import checkpoint, opt_get

-class ReverseBottleneck(nn.Module):
-    def __init__(self, inplanes, planes, groups=1, passthrough=False,
-                 base_width=64, dilation=1, norm_layer=None):
-        super().__init__()
-        if norm_layer is None:
-            norm_layer = nn.BatchNorm2d
-        width = int(planes * (base_width / 64.)) * groups
-        self.passthrough = passthrough
-        if passthrough:
-            self.integrate = conv1x1(inplanes*2, inplanes)
-            self.bn_integrate = norm_layer(inplanes)
-        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
-        self.conv1 = conv1x1(inplanes, width)
-        self.bn1 = norm_layer(width)
-        self.conv2 = conv3x3(width, width, groups, dilation)
-        self.bn2 = norm_layer(width)
-        self.residual_upsample = nn.Sequential(
-            nn.Upsample(scale_factor=2, mode='nearest'),
-            conv1x1(width, width),
-            norm_layer(width),
-        )
-        self.conv3 = conv1x1(width, planes)
-        self.bn3 = norm_layer(planes)
-        self.relu = nn.ReLU(inplace=True)
-        self.upsample = nn.Sequential(
-            nn.Upsample(scale_factor=2, mode='nearest'),
-            conv1x1(inplanes, planes),
-            norm_layer(planes),
-        )
-
-    def forward(self, x, passthrough=None):
-        if self.passthrough:
-            x = self.bn_integrate(self.integrate(torch.cat([x, passthrough], dim=1)))
-        out = self.conv1(x)
-        out = self.bn1(out)
-        out = self.relu(out)
-        out = self.conv2(out)
-        out = self.bn2(out)
-        out = self.relu(out)
-        out = self.residual_upsample(out)
-        out = self.conv3(out)
-        out = self.bn3(out)
-        identity = self.upsample(x)
-        out = out + identity
-        out = self.relu(out)
-        return out
-
-class UResNet50(torchvision.models.resnet.ResNet):
+class UResNet50_2(torchvision.models.resnet.ResNet):
    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
@ -82,6 +19,7 @@ class UResNet50(torchvision.models.resnet.ResNet):
                         replace_stride_with_dilation, norm_layer)
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
+        self.level_conv = ConvBnRelu(3, 64)
        '''
        # For reference:
        self.layer1 = self._make_layer(block, 64, layers[0])
@ -95,29 +33,24 @@ class UResNet50(torchvision.models.resnet.ResNet):
        uplayers = []
        inplanes = 2048
        first = True
-        for i in range(2):
-            uplayers.append(ReverseBottleneck(inplanes, inplanes // 2, norm_layer=norm_layer, passthrough=not first))
-            inplanes = inplanes // 2
+        div = [2,2,2,4,1]
+        for i in range(5):
+            uplayers.append(ReverseBottleneck(inplanes, inplanes // div[i], norm_layer=norm_layer, passthrough=not first))
+            inplanes = inplanes // div[i]
            first = False
        self.uplayers = nn.ModuleList(uplayers)

-        self.tail = nn.Sequential(conv1x1(1024, 512),
-                                  norm_layer(512),
-                                  nn.ReLU(),
-                                  conv3x3(512, 512),
-                                  norm_layer(512),
-                                  nn.ReLU(),
-                                  conv1x1(512, out_dim))
+        self.tail = nn.Sequential(conv3x3(128, 64),
+                                  norm_layer(64),
+                                  nn.ReLU(),
+                                  conv1x1(64, out_dim))

        del self.fc # Not used in this implementation and just consumes a ton of GPU memory.
    def _forward_impl(self, x):
-        # Should be the exact same implementation of torchvision.models.resnet.ResNet.forward_impl,
-        # except using checkpoints on the body conv layers.
-        x = self.conv1(x)
-        x = self.bn1(x)
-        x = self.relu(x)
-        x = self.maxpool(x)
+        level = self.level_conv(x)
+        x0 = self.relu(self.bn1(self.conv1(x)))
+        x = self.maxpool(x0)

        x1 = checkpoint(self.layer1, x)
        x2 = checkpoint(self.layer2, x1)
@ -127,18 +60,19 @@ class UResNet50(torchvision.models.resnet.ResNet):
        x = checkpoint(self.uplayers[0], x4)
        x = checkpoint(self.uplayers[1], x, x3)
-        #x = checkpoint(self.uplayers[2], x, x2)
-        #x = checkpoint(self.uplayers[3], x, x1)
+        x = checkpoint(self.uplayers[2], x, x2)
+        x = checkpoint(self.uplayers[3], x, x1)
+        x = checkpoint(self.uplayers[4], x, x0)

-        return checkpoint(self.tail, torch.cat([x, x2], dim=1))
+        return checkpoint(self.tail, torch.cat([x, level], dim=1))

    def forward(self, x):
        return self._forward_impl(x)


@register_model
-def register_u_resnet50(opt_net, opt):
-    model = UResNet50(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim'])
+def register_u_resnet50_2(opt_net, opt):
+    model = UResNet50_2(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim'])
    if opt_get(opt_net, ['use_pretrained_base'], False):
        state_dict = load_state_dict_from_url('https://download.pytorch.org/models/resnet50-19c8e357.pth', progress=True)
        model.load_state_dict(state_dict, strict=False)
@ -146,7 +80,8 @@ def register_u_resnet50(opt_net, opt):
if __name__ == '__main__':
-    model = UResNet50(Bottleneck, [3,4,6,3])
+    model = UResNet50_2(Bottleneck, [3,4,6,3])
    samp = torch.rand(1,3,224,224)
-    model(samp)
+    y = model(samp)
+    print(y.shape)
+    # For pixpro: attach to "tail.3"
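
The trailing comment suggests attaching the pixel-level hook at the 'tail.3' activation of this network. A rough usage sketch follows; it assumes the PixelCL wrapper resolves string layer names the way the upstream pixel-level-contrastive-learning implementation does, and the module paths and hyperparameter values are illustrative rather than taken from this commit:

```python
import torch
from torchvision.models.resnet import Bottleneck
# Hypothetical module paths; point these at wherever UResNet50_2 and PixelCL live in this repo.
from models.pixel_level_contrastive_learning.resnet_unet_2 import UResNet50_2
from models.pixel_level_contrastive_learning.pixel_contrastive_learning import PixelCL

net = UResNet50_2(Bottleneck, [3, 4, 6, 3], out_dim=256)
learner = PixelCL(
    net,
    image_size = 256,
    hidden_layer_pixel = 'tail.3',     # attach the pixel-level hook here, per the comment above
    hidden_layer_instance = -2,
    pix_projection_size = 256,
    instance_projection_size = 256,
    max_latent_dim = 32,               # randomly crop the shared latent region to at most 32x32
)
loss, positive_pixel_pairs = learner(torch.rand(2, 3, 256, 256))
loss.backward()
```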

View File

@ -14,16 +14,16 @@ def main():
    split_img = False
    opt = {}
    opt['n_thread'] = 7
-    opt['compression_level'] = 95  # JPEG compression quality rating.
+    opt['compression_level'] = 90  # JPEG compression quality rating.
    # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
    # compression time. If read raw images during training, use 0 for faster IO speed.
    opt['dest'] = 'file'
    opt['input_folder'] = ['F:\\4k6k\\datasets\\ns_images\\imagesets\\pn_coven\\working']
-    opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\pn_coven\\cropped'
-    opt['imgsize'] = 1024
-    opt['bottom_crop'] = .1
-    opt['keep_folder'] = True
+    opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\256_unsupervised'
+    opt['imgsize'] = 256
+    opt['bottom_crop'] = 0.1
+    opt['keep_folder'] = False

    save_folder = opt['save_folder']
    if not osp.exists(save_folder):
@ -58,7 +58,7 @@ class TiledDataset(data.Dataset):
        # Perform explicit crops first. These are generally used to get rid of watermarks, so we don't even want to
        # consider these areas of the image.
-        if 'bottom_crop' in self.opt.keys():
+        if 'bottom_crop' in self.opt.keys() and self.opt['bottom_crop'] > 0:
            bc = self.opt['bottom_crop']
            if bc > 0 and bc < 1:
                bc = int(bc * img.shape[0])
@ -83,9 +83,7 @@ class TiledDataset(data.Dataset):
        pts = os.path.split(pts[0])
        output_folder = osp.join(self.opt['save_folder'], pts[-1])
        os.makedirs(output_folder, exist_ok=True)
-        if not basename.endswith(".jpg"):
-            basename = basename + ".jpg"
-        cv2.imwrite(osp.join(output_folder, basename + ".jpg"), img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']])
+        cv2.imwrite(osp.join(output_folder, basename), img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']])
        return None
    def __len__(self):

View File

@ -295,7 +295,7 @@ class Trainer:
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imagenet_pixpro_resnet.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imagenet_resnet50.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()