Adjustments to pixpro to allow training against networks with arbitrarily large structural latents

- The pixpro latent now rescales the latent space instead of using a "coordinate vector", which
   **might** have performance implications.
- The latent against which the pixel loss is computed can now be a small, randomly sampled patch
   out of the entire latent, allowing further memory/computational discounts. Since the loss
   computation does not have a receptive field, this should not alter the loss.
- The instance projection size can now be separate from the pixel projection size.
- PixContrast removed entirely.
- ResUnet with full resolution added.
This commit is contained in:
James Betker 2021-01-12 09:17:45 -07:00
parent 34f8c8641f
commit d1007ccfe7
4 changed files with 155 additions and 194 deletions

View File

@ -66,6 +66,59 @@ def cutout_and_resize(image, coordinates, output_size = None, mode = 'nearest'):
cutout_image = image[:, :, y0:y1, x0:x1]
return F.interpolate(cutout_image, size = output_size, mode = mode)
def scale_coords(coords, scale):
output = [[0,0],[0,0]]
for j in range(2):
for k in range(2):
output[j][k] = int(coords[j][k] / scale)
return output
def reverse_cutout_and_resize(image, coordinates, scale_reduction, mode = 'nearest'):
blank = torch.zeros_like(image)
coordinates = scale_coords(coordinates, scale_reduction)
(y0, y1), (x0, x1) = coordinates
orig_cutout_shape = (y1-y0, x1-x0)
if orig_cutout_shape[0] <= 0 or orig_cutout_shape[1] <= 0:
return None
un_resized_img = F.interpolate(image, size=orig_cutout_shape, mode=mode)
blank[:,:,y0:y1,x0:x1] = un_resized_img
return blank
def compute_shared_coords(coords1, coords2, scale_reduction):
(y1_t, y1_b), (x1_l, x1_r) = scale_coords(coords1, scale_reduction)
(y2_t, y2_b), (x2_l, x2_r) = scale_coords(coords2, scale_reduction)
shared = ((max(y1_t, y2_t), min(y1_b, y2_b)),
(max(x1_l, x2_l), min(x1_r, x2_r)))
for s in shared:
if s == 0:
return None
return shared
def get_shared_region(proj_pixel_one, proj_pixel_two, cutout_coordinates_one, cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn, img_orig_shape, interp_mode):
# Unflip the pixel projections
proj_pixel_one = flip_image_one_fn(proj_pixel_one)
proj_pixel_two = flip_image_two_fn(proj_pixel_two)
# Undo the cutout and resize, taking into account the scale reduction applied by the encoder.
scale_reduction = proj_pixel_one.shape[-1] / img_orig_shape[-1]
proj_pixel_one = reverse_cutout_and_resize(proj_pixel_one, cutout_coordinates_one, scale_reduction,
mode=interp_mode)
proj_pixel_two = reverse_cutout_and_resize(proj_pixel_two, cutout_coordinates_two, scale_reduction,
mode=interp_mode)
if proj_pixel_one is None or proj_pixel_two is None:
print("Could not extract projected image region. The selected cutout coordinates were smaller than the aggregate size of one latent block!")
return None
# Compute the shared coordinates for the two cutouts:
shared_coords = compute_shared_coords(cutout_coordinates_one, cutout_coordinates_two, scale_reduction)
if shared_coords is None:
print("No shared coordinates for this iteration (probably should just recompute those coordinates earlier..")
return None
(yt, yb), (xl, xr) = shared_coords
return proj_pixel_one[:, :, yt:yb, xl:xr], proj_pixel_two[:, :, yt:yb, xl:xr]
# augmentation utils
class RandomApply(nn.Module):
@ -172,8 +225,10 @@ class NetWrapper(nn.Module):
self,
*,
net,
projection_size,
projection_hidden_size,
instance_projection_size,
instance_projection_hidden_size,
pix_projection_size,
pix_projection_hidden_size,
layer_pixel = -2,
layer_instance = -2
):
@ -185,8 +240,10 @@ class NetWrapper(nn.Module):
self.pixel_projector = None
self.instance_projector = None
self.projection_size = projection_size
self.projection_hidden_size = projection_hidden_size
self.instance_projection_size = instance_projection_size
self.instance_projection_hidden_size = instance_projection_hidden_size
self.pix_projection_size = pix_projection_size
self.pix_projection_hidden_size = pix_projection_hidden_size
self.hidden_pixel = None
self.hidden_instance = None
@ -218,13 +275,13 @@ class NetWrapper(nn.Module):
@singleton('pixel_projector')
def _get_pixel_projector(self, hidden):
_, dim, *_ = hidden.shape
projector = ConvMLP(dim, self.projection_size, self.projection_hidden_size)
projector = ConvMLP(dim, self.pix_projection_size, self.pix_projection_hidden_size)
return projector.to(hidden)
@singleton('instance_projector')
def _get_instance_projector(self, hidden):
_, dim = hidden.shape
projector = MLP(dim, self.projection_size, self.projection_hidden_size)
projector = MLP(dim, self.instance_projection_size, self.instance_projection_hidden_size)
return projector.to(hidden)
def get_representation(self, x):
@ -252,7 +309,6 @@ class NetWrapper(nn.Module):
return pixel_projection, instance_projection
# main class
class PixelCL(nn.Module):
def __init__(
self,
@ -260,8 +316,10 @@ class PixelCL(nn.Module):
image_size,
hidden_layer_pixel = -2,
hidden_layer_instance = -2,
projection_size = 256,
projection_hidden_size = 2048,
instance_projection_size = 256,
instance_projection_hidden_size = 2048,
pix_projection_size = 256,
pix_projection_hidden_size = 2048,
augment_fn = None,
augment_fn2 = None,
prob_rand_hflip = 0.25,
@ -271,10 +329,10 @@ class PixelCL(nn.Module):
distance_thres = 0.7,
similarity_temperature = 0.3,
alpha = 1.,
use_pixpro = True,
cutout_ratio_range = (0.6, 0.8),
cutout_interpolate_mode = 'nearest',
coord_cutout_interpolate_mode = 'bilinear'
coord_cutout_interpolate_mode = 'bilinear',
max_latent_dim = None # This is in latent space, not image space, so dimensionality reduction of your network must be accounted for.
):
super().__init__()
@ -292,8 +350,10 @@ class PixelCL(nn.Module):
self.online_encoder = NetWrapper(
net = net,
projection_size = projection_size,
projection_hidden_size = projection_hidden_size,
instance_projection_size = instance_projection_size,
instance_projection_hidden_size = instance_projection_hidden_size,
pix_projection_size = pix_projection_size,
pix_projection_hidden_size = pix_projection_hidden_size,
layer_pixel = hidden_layer_pixel,
layer_instance = hidden_layer_instance
)
@ -304,22 +364,20 @@ class PixelCL(nn.Module):
self.distance_thres = distance_thres
self.similarity_temperature = similarity_temperature
self.alpha = alpha
self.max_latent_dim = max_latent_dim
self.use_pixpro = use_pixpro
if use_pixpro:
self.propagate_pixels = PPM(
chan = projection_size,
num_layers = ppm_num_layers,
gamma = ppm_gamma
)
self.propagate_pixels = PPM(
chan = pix_projection_size,
num_layers = ppm_num_layers,
gamma = ppm_gamma
)
self.cutout_ratio_range = cutout_ratio_range
self.cutout_interpolate_mode = cutout_interpolate_mode
self.coord_cutout_interpolate_mode = coord_cutout_interpolate_mode
# instance level predictor
self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)
self.online_predictor = MLP(instance_projection_size, instance_projection_size, instance_projection_hidden_size)
# get device of network and make wrapper same device
device = get_module_device(net)
@ -368,106 +426,74 @@ class PixelCL(nn.Module):
proj_pixel_one, proj_instance_one = self.online_encoder(image_one_cutout)
proj_pixel_two, proj_instance_two = self.online_encoder(image_two_cutout)
image_h, image_w = shape[2:]
proj_image_shape = proj_pixel_one.shape[2:]
proj_image_h, proj_image_w = proj_image_shape
coordinates = torch.meshgrid(
torch.arange(image_h, device = device),
torch.arange(image_w, device = device)
)
coordinates = torch.stack(coordinates).unsqueeze(0).float()
coordinates /= math.sqrt(image_h ** 2 + image_w ** 2)
coordinates[:, 0] *= proj_image_h
coordinates[:, 1] *= proj_image_w
proj_coors_one = cutout_and_resize(coordinates, cutout_coordinates_one, output_size = proj_image_shape, mode = self.coord_cutout_interpolate_mode)
proj_coors_two = cutout_and_resize(coordinates, cutout_coordinates_two, output_size = proj_image_shape, mode = self.coord_cutout_interpolate_mode)
proj_coors_one = flip_image_one_fn(proj_coors_one)
proj_coors_two = flip_image_two_fn(proj_coors_two)
proj_coors_one, proj_coors_two = map(lambda t: rearrange(t, 'b c h w -> (b h w) c'), (proj_coors_one, proj_coors_two))
pdist = nn.PairwiseDistance(p = 2)
num_pixels = proj_coors_one.shape[0]
proj_coors_one_expanded = proj_coors_one[:, None].expand(num_pixels, num_pixels, -1).reshape(num_pixels * num_pixels, 2)
proj_coors_two_expanded = proj_coors_two[None, :].expand(num_pixels, num_pixels, -1).reshape(num_pixels * num_pixels, 2)
distance_matrix = pdist(proj_coors_one_expanded, proj_coors_two_expanded)
distance_matrix = distance_matrix.reshape(num_pixels, num_pixels)
positive_mask_one_two = distance_matrix < self.distance_thres
positive_mask_two_one = positive_mask_one_two.t()
proj_pixel_one, proj_pixel_two = get_shared_region(proj_pixel_one, proj_pixel_two, cutout_coordinates_one,
cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn,
image_one_cutout.shape, self.cutout_interpolate_mode)
sim_region_img_one, sim_region_img_two = get_shared_region(image_one_cutout, image_two_cutout, cutout_coordinates_one,
cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn,
image_one_cutout.shape, self.cutout_interpolate_mode)
if proj_pixel_one is None or proj_pixel_two is None:
positive_pixel_pairs = 0
else:
positive_pixel_pairs = proj_pixel_one.shape[-1] * proj_pixel_one.shape[-2]
with torch.no_grad():
target_encoder = self._get_target_encoder()
target_proj_pixel_one, target_proj_instance_one = target_encoder(image_one_cutout)
target_proj_pixel_two, target_proj_instance_two = target_encoder(image_two_cutout)
target_proj_pixel_one, target_proj_pixel_two = get_shared_region(target_proj_pixel_one, target_proj_pixel_two, cutout_coordinates_one,
cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn,
image_one_cutout.shape, self.cutout_interpolate_mode)
# Apply max_latent_dim if needed.
_, _, pp_h, pp_w = proj_pixel_one.shape
if self.max_latent_dim and pp_h > self.max_latent_dim:
margin = pp_h - self.max_latent_dim
loc = random.randint(0, margin)
loce = loc + self.max_latent_dim
proj_pixel_one, proj_pixel_two = proj_pixel_one[:, :, loc:loce, :], proj_pixel_two[:, :, loc:loce, :]
target_proj_pixel_one, target_proj_pixel_two = target_proj_pixel_one[:, :, loc:loce, :], target_proj_pixel_two[:, :, loc:loce, :]
sim_region_img_one, sim_region_img_two = sim_region_img_one[:, :, loc:loce, :], sim_region_img_two[:, :, loc:loce, :]
if self.max_latent_dim and pp_w > self.max_latent_dim:
margin = pp_w - self.max_latent_dim
loc = random.randint(0, margin)
loce = loc + self.max_latent_dim
proj_pixel_one, proj_pixel_two = proj_pixel_one[:, :, :, loc:loce], proj_pixel_two[:, :, :, loc:loce]
target_proj_pixel_one, target_proj_pixel_two = target_proj_pixel_one[:, :, :, loc:loce], target_proj_pixel_two[:, :, :, loc:loce]
sim_region_img_one, sim_region_img_two = sim_region_img_one[:, :, :, loc:loce], sim_region_img_two[:, :, :, loc:loce]
# Stash these away for debugging purposes.
self.sim_region_img_one = sim_region_img_one.detach().clone()
self.sim_region_img_two = sim_region_img_two.detach().clone()
# flatten all the pixel projections
flatten = lambda t: rearrange(t, 'b c h w -> b c (h w)')
target_proj_pixel_one, target_proj_pixel_two = list(map(flatten, (target_proj_pixel_one, target_proj_pixel_two)))
# get total number of positive pixel pairs
positive_pixel_pairs = positive_mask_one_two.sum()
# get instance level loss
pred_instance_one = self.online_predictor(proj_instance_one)
pred_instance_two = self.online_predictor(proj_instance_two)
loss_instance_one = loss_fn(pred_instance_one, target_proj_instance_two.detach())
loss_instance_two = loss_fn(pred_instance_two, target_proj_instance_one.detach())
instance_loss = (loss_instance_one + loss_instance_two).mean()
if positive_pixel_pairs == 0:
return instance_loss, 0
if not self.use_pixpro:
# calculate pix contrast loss
# calculate pix pro loss
propagated_pixels_one = self.propagate_pixels(proj_pixel_one)
propagated_pixels_two = self.propagate_pixels(proj_pixel_two)
proj_pixel_one, proj_pixel_two = list(map(flatten, (proj_pixel_one, proj_pixel_two)))
propagated_pixels_one, propagated_pixels_two = list(map(flatten, (propagated_pixels_one, propagated_pixels_two)))
similarity_one_two = F.cosine_similarity(proj_pixel_one[..., :, None], target_proj_pixel_two[..., None, :], dim = 1) / self.similarity_temperature
similarity_two_one = F.cosine_similarity(proj_pixel_two[..., :, None], target_proj_pixel_one[..., None, :], dim = 1) / self.similarity_temperature
propagated_similarity_one_two = F.cosine_similarity(propagated_pixels_one[..., :, None], target_proj_pixel_two[..., None, :], dim = 1)
propagated_similarity_two_one = F.cosine_similarity(propagated_pixels_two[..., :, None], target_proj_pixel_one[..., None, :], dim = 1)
loss_pix_one_two = -torch.log(
similarity_one_two.masked_select(positive_mask_one_two[None, ...]).exp().sum() /
similarity_one_two.exp().sum()
)
loss_pixpro_one_two = - propagated_similarity_one_two.mean()
loss_pixpro_two_one = - propagated_similarity_two_one.mean()
loss_pix_two_one = -torch.log(
similarity_two_one.masked_select(positive_mask_two_one[None, ...]).exp().sum() /
similarity_two_one.exp().sum()
)
pix_loss = (loss_pix_one_two + loss_pix_two_one) / 2
else:
# calculate pix pro loss
propagated_pixels_one = self.propagate_pixels(proj_pixel_one)
propagated_pixels_two = self.propagate_pixels(proj_pixel_two)
propagated_pixels_one, propagated_pixels_two = list(map(flatten, (propagated_pixels_one, propagated_pixels_two)))
propagated_similarity_one_two = F.cosine_similarity(propagated_pixels_one[..., :, None], target_proj_pixel_two[..., None, :], dim = 1)
propagated_similarity_two_one = F.cosine_similarity(propagated_pixels_two[..., :, None], target_proj_pixel_one[..., None, :], dim = 1)
loss_pixpro_one_two = - propagated_similarity_one_two.masked_select(positive_mask_one_two[None, ...]).mean()
loss_pixpro_two_one = - propagated_similarity_two_one.masked_select(positive_mask_two_one[None, ...]).mean()
pix_loss = (loss_pixpro_one_two + loss_pixpro_two_one) / 2
pix_loss = (loss_pixpro_one_two + loss_pixpro_two_one) / 2
# total loss
loss = pix_loss * self.alpha + instance_loss
return loss, positive_pixel_pairs
@ -477,6 +503,8 @@ class PixelCL(nn.Module):
return
torchvision.utils.save_image(self.aug1, os.path.join(path, "%i_aug1.png" % (step,)))
torchvision.utils.save_image(self.aug2, os.path.join(path, "%i_aug2.png" % (step,)))
torchvision.utils.save_image(self.sim_region_img_one, os.path.join(path, "%i_sim1.png" % (step,)))
torchvision.utils.save_image(self.sim_region_img_two, os.path.join(path, "%i_sim2.png" % (step,)))
@register_model

View File

@ -1,79 +1,16 @@
# Resnet implementation that adds a u-net style up-conversion component to output values at a
# specified pixel density.
#
# The downsampling part of the network is compatible with the built-in torch resnet for use in
# transfer learning.
#
# Only resnet50 currently supported.
import torch
import torch.nn as nn
from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1, conv3x3
from torchvision.models.utils import load_state_dict_from_url
import torchvision
from models.arch_util import ConvBnRelu
from models.pixel_level_contrastive_learning.resnet_unet import ReverseBottleneck
from trainer.networks import register_model
from utils.util import checkpoint, opt_get
class ReverseBottleneck(nn.Module):
def __init__(self, inplanes, planes, groups=1, passthrough=False,
base_width=64, dilation=1, norm_layer=None):
super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
self.passthrough = passthrough
if passthrough:
self.integrate = conv1x1(inplanes*2, inplanes)
self.bn_integrate = norm_layer(inplanes)
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, groups, dilation)
self.bn2 = norm_layer(width)
self.residual_upsample = nn.Sequential(
nn.Upsample(scale_factor=2, mode='nearest'),
conv1x1(width, width),
norm_layer(width),
)
self.conv3 = conv1x1(width, planes)
self.bn3 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.upsample = nn.Sequential(
nn.Upsample(scale_factor=2, mode='nearest'),
conv1x1(inplanes, planes),
norm_layer(planes),
)
def forward(self, x, passthrough=None):
if self.passthrough:
x = self.bn_integrate(self.integrate(torch.cat([x, passthrough], dim=1)))
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.residual_upsample(out)
out = self.conv3(out)
out = self.bn3(out)
identity = self.upsample(x)
out = out + identity
out = self.relu(out)
return out
class UResNet50(torchvision.models.resnet.ResNet):
class UResNet50_2(torchvision.models.resnet.ResNet):
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
@ -82,6 +19,7 @@ class UResNet50(torchvision.models.resnet.ResNet):
replace_stride_with_dilation, norm_layer)
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self.level_conv = ConvBnRelu(3, 64)
'''
# For reference:
self.layer1 = self._make_layer(block, 64, layers[0])
@ -95,29 +33,24 @@ class UResNet50(torchvision.models.resnet.ResNet):
uplayers = []
inplanes = 2048
first = True
for i in range(2):
uplayers.append(ReverseBottleneck(inplanes, inplanes // 2, norm_layer=norm_layer, passthrough=not first))
inplanes = inplanes // 2
div = [2,2,2,4,1]
for i in range(5):
uplayers.append(ReverseBottleneck(inplanes, inplanes // div[i], norm_layer=norm_layer, passthrough=not first))
inplanes = inplanes // div[i]
first = False
self.uplayers = nn.ModuleList(uplayers)
self.tail = nn.Sequential(conv1x1(1024, 512),
norm_layer(512),
self.tail = nn.Sequential(conv3x3(128, 64),
norm_layer(64),
nn.ReLU(),
conv3x3(512, 512),
norm_layer(512),
nn.ReLU(),
conv1x1(512, out_dim))
conv1x1(64, out_dim))
del self.fc # Not used in this implementation and just consumes a ton of GPU memory.
def _forward_impl(self, x):
# Should be the exact same implementation of torchvision.models.resnet.ResNet.forward_impl,
# except using checkpoints on the body conv layers.
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
level = self.level_conv(x)
x0 = self.relu(self.bn1(self.conv1(x)))
x = self.maxpool(x0)
x1 = checkpoint(self.layer1, x)
x2 = checkpoint(self.layer2, x1)
@ -127,18 +60,19 @@ class UResNet50(torchvision.models.resnet.ResNet):
x = checkpoint(self.uplayers[0], x4)
x = checkpoint(self.uplayers[1], x, x3)
#x = checkpoint(self.uplayers[2], x, x2)
#x = checkpoint(self.uplayers[3], x, x1)
x = checkpoint(self.uplayers[2], x, x2)
x = checkpoint(self.uplayers[3], x, x1)
x = checkpoint(self.uplayers[4], x, x0)
return checkpoint(self.tail, torch.cat([x, x2], dim=1))
return checkpoint(self.tail, torch.cat([x, level], dim=1))
def forward(self, x):
return self._forward_impl(x)
@register_model
def register_u_resnet50(opt_net, opt):
model = UResNet50(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim'])
def register_u_resnet50_2(opt_net, opt):
model = UResNet50_2(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim'])
if opt_get(opt_net, ['use_pretrained_base'], False):
state_dict = load_state_dict_from_url('https://download.pytorch.org/models/resnet50-19c8e357.pth', progress=True)
model.load_state_dict(state_dict, strict=False)
@ -146,7 +80,8 @@ def register_u_resnet50(opt_net, opt):
if __name__ == '__main__':
model = UResNet50(Bottleneck, [3,4,6,3])
model = UResNet50_2(Bottleneck, [3,4,6,3])
samp = torch.rand(1,3,224,224)
model(samp)
y = model(samp)
print(y.shape)
# For pixpro: attach to "tail.3"

View File

@ -14,16 +14,16 @@ def main():
split_img = False
opt = {}
opt['n_thread'] = 7
opt['compression_level'] = 95 # JPEG compression quality rating.
opt['compression_level'] = 90 # JPEG compression quality rating.
# CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
# compression time. If read raw images during training, use 0 for faster IO speed.
opt['dest'] = 'file'
opt['input_folder'] = ['F:\\4k6k\\datasets\\ns_images\\imagesets\\pn_coven\\working']
opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\pn_coven\\cropped'
opt['imgsize'] = 1024
opt['bottom_crop'] = .1
opt['keep_folder'] = True
opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\256_unsupervised'
opt['imgsize'] = 256
opt['bottom_crop'] = 0.1
opt['keep_folder'] = False
save_folder = opt['save_folder']
if not osp.exists(save_folder):
@ -58,7 +58,7 @@ class TiledDataset(data.Dataset):
# Perform explicit crops first. These are generally used to get rid of watermarks so we dont even want to
# consider these areas of the image.
if 'bottom_crop' in self.opt.keys():
if 'bottom_crop' in self.opt.keys() and self.opt['bottom_crop'] > 0:
bc = self.opt['bottom_crop']
if bc > 0 and bc < 1:
bc = int(bc * img.shape[0])
@ -83,9 +83,7 @@ class TiledDataset(data.Dataset):
pts = os.path.split(pts[0])
output_folder = osp.join(self.opt['save_folder'], pts[-1])
os.makedirs(output_folder, exist_ok=True)
if not basename.endswith(".jpg"):
basename = basename + ".jpg"
cv2.imwrite(osp.join(output_folder, basename + ".jpg"), img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']])
cv2.imwrite(osp.join(output_folder, basename), img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']])
return None
def __len__(self):

View File

@ -295,7 +295,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imagenet_pixpro_resnet.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imagenet_resnet50.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()