diff --git a/codes/models/pixel_level_contrastive_learning/pixpro_lucidrains.py b/codes/models/pixel_level_contrastive_learning/pixpro_lucidrains.py index de5e11e8..1536a6a7 100644 --- a/codes/models/pixel_level_contrastive_learning/pixpro_lucidrains.py +++ b/codes/models/pixel_level_contrastive_learning/pixpro_lucidrains.py @@ -66,6 +66,59 @@ def cutout_and_resize(image, coordinates, output_size = None, mode = 'nearest'): cutout_image = image[:, :, y0:y1, x0:x1] return F.interpolate(cutout_image, size = output_size, mode = mode) +def scale_coords(coords, scale): + output = [[0,0],[0,0]] + for j in range(2): + for k in range(2): + output[j][k] = int(coords[j][k] / scale) + return output + +def reverse_cutout_and_resize(image, coordinates, scale_reduction, mode = 'nearest'): + blank = torch.zeros_like(image) + coordinates = scale_coords(coordinates, scale_reduction) + (y0, y1), (x0, x1) = coordinates + orig_cutout_shape = (y1-y0, x1-x0) + if orig_cutout_shape[0] <= 0 or orig_cutout_shape[1] <= 0: + return None + + un_resized_img = F.interpolate(image, size=orig_cutout_shape, mode=mode) + blank[:,:,y0:y1,x0:x1] = un_resized_img + return blank + +def compute_shared_coords(coords1, coords2, scale_reduction): + (y1_t, y1_b), (x1_l, x1_r) = scale_coords(coords1, scale_reduction) + (y2_t, y2_b), (x2_l, x2_r) = scale_coords(coords2, scale_reduction) + shared = ((max(y1_t, y2_t), min(y1_b, y2_b)), + (max(x1_l, x2_l), min(x1_r, x2_r))) + for s in shared: + if s == 0: + return None + return shared + +def get_shared_region(proj_pixel_one, proj_pixel_two, cutout_coordinates_one, cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn, img_orig_shape, interp_mode): + # Unflip the pixel projections + proj_pixel_one = flip_image_one_fn(proj_pixel_one) + proj_pixel_two = flip_image_two_fn(proj_pixel_two) + + # Undo the cutout and resize, taking into account the scale reduction applied by the encoder. + scale_reduction = proj_pixel_one.shape[-1] / img_orig_shape[-1] + proj_pixel_one = reverse_cutout_and_resize(proj_pixel_one, cutout_coordinates_one, scale_reduction, + mode=interp_mode) + proj_pixel_two = reverse_cutout_and_resize(proj_pixel_two, cutout_coordinates_two, scale_reduction, + mode=interp_mode) + if proj_pixel_one is None or proj_pixel_two is None: + print("Could not extract projected image region. The selected cutout coordinates were smaller than the aggregate size of one latent block!") + return None + + # Compute the shared coordinates for the two cutouts: + shared_coords = compute_shared_coords(cutout_coordinates_one, cutout_coordinates_two, scale_reduction) + if shared_coords is None: + print("No shared coordinates for this iteration (probably should just recompute those coordinates earlier..") + return None + (yt, yb), (xl, xr) = shared_coords + + return proj_pixel_one[:, :, yt:yb, xl:xr], proj_pixel_two[:, :, yt:yb, xl:xr] + # augmentation utils class RandomApply(nn.Module): @@ -172,8 +225,10 @@ class NetWrapper(nn.Module): self, *, net, - projection_size, - projection_hidden_size, + instance_projection_size, + instance_projection_hidden_size, + pix_projection_size, + pix_projection_hidden_size, layer_pixel = -2, layer_instance = -2 ): @@ -185,8 +240,10 @@ class NetWrapper(nn.Module): self.pixel_projector = None self.instance_projector = None - self.projection_size = projection_size - self.projection_hidden_size = projection_hidden_size + self.instance_projection_size = instance_projection_size + self.instance_projection_hidden_size = instance_projection_hidden_size + self.pix_projection_size = pix_projection_size + self.pix_projection_hidden_size = pix_projection_hidden_size self.hidden_pixel = None self.hidden_instance = None @@ -218,13 +275,13 @@ class NetWrapper(nn.Module): @singleton('pixel_projector') def _get_pixel_projector(self, hidden): _, dim, *_ = hidden.shape - projector = ConvMLP(dim, self.projection_size, self.projection_hidden_size) + projector = ConvMLP(dim, self.pix_projection_size, self.pix_projection_hidden_size) return projector.to(hidden) @singleton('instance_projector') def _get_instance_projector(self, hidden): _, dim = hidden.shape - projector = MLP(dim, self.projection_size, self.projection_hidden_size) + projector = MLP(dim, self.instance_projection_size, self.instance_projection_hidden_size) return projector.to(hidden) def get_representation(self, x): @@ -252,7 +309,6 @@ class NetWrapper(nn.Module): return pixel_projection, instance_projection # main class - class PixelCL(nn.Module): def __init__( self, @@ -260,8 +316,10 @@ class PixelCL(nn.Module): image_size, hidden_layer_pixel = -2, hidden_layer_instance = -2, - projection_size = 256, - projection_hidden_size = 2048, + instance_projection_size = 256, + instance_projection_hidden_size = 2048, + pix_projection_size = 256, + pix_projection_hidden_size = 2048, augment_fn = None, augment_fn2 = None, prob_rand_hflip = 0.25, @@ -271,10 +329,10 @@ class PixelCL(nn.Module): distance_thres = 0.7, similarity_temperature = 0.3, alpha = 1., - use_pixpro = True, cutout_ratio_range = (0.6, 0.8), cutout_interpolate_mode = 'nearest', - coord_cutout_interpolate_mode = 'bilinear' + coord_cutout_interpolate_mode = 'bilinear', + max_latent_dim = None # This is in latent space, not image space, so dimensionality reduction of your network must be accounted for. ): super().__init__() @@ -292,8 +350,10 @@ class PixelCL(nn.Module): self.online_encoder = NetWrapper( net = net, - projection_size = projection_size, - projection_hidden_size = projection_hidden_size, + instance_projection_size = instance_projection_size, + instance_projection_hidden_size = instance_projection_hidden_size, + pix_projection_size = pix_projection_size, + pix_projection_hidden_size = pix_projection_hidden_size, layer_pixel = hidden_layer_pixel, layer_instance = hidden_layer_instance ) @@ -304,22 +364,20 @@ class PixelCL(nn.Module): self.distance_thres = distance_thres self.similarity_temperature = similarity_temperature self.alpha = alpha + self.max_latent_dim = max_latent_dim - self.use_pixpro = use_pixpro - - if use_pixpro: - self.propagate_pixels = PPM( - chan = projection_size, - num_layers = ppm_num_layers, - gamma = ppm_gamma - ) + self.propagate_pixels = PPM( + chan = pix_projection_size, + num_layers = ppm_num_layers, + gamma = ppm_gamma + ) self.cutout_ratio_range = cutout_ratio_range self.cutout_interpolate_mode = cutout_interpolate_mode self.coord_cutout_interpolate_mode = coord_cutout_interpolate_mode # instance level predictor - self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size) + self.online_predictor = MLP(instance_projection_size, instance_projection_size, instance_projection_hidden_size) # get device of network and make wrapper same device device = get_module_device(net) @@ -368,106 +426,74 @@ class PixelCL(nn.Module): proj_pixel_one, proj_instance_one = self.online_encoder(image_one_cutout) proj_pixel_two, proj_instance_two = self.online_encoder(image_two_cutout) - image_h, image_w = shape[2:] - - proj_image_shape = proj_pixel_one.shape[2:] - proj_image_h, proj_image_w = proj_image_shape - - coordinates = torch.meshgrid( - torch.arange(image_h, device = device), - torch.arange(image_w, device = device) - ) - - coordinates = torch.stack(coordinates).unsqueeze(0).float() - coordinates /= math.sqrt(image_h ** 2 + image_w ** 2) - coordinates[:, 0] *= proj_image_h - coordinates[:, 1] *= proj_image_w - - proj_coors_one = cutout_and_resize(coordinates, cutout_coordinates_one, output_size = proj_image_shape, mode = self.coord_cutout_interpolate_mode) - proj_coors_two = cutout_and_resize(coordinates, cutout_coordinates_two, output_size = proj_image_shape, mode = self.coord_cutout_interpolate_mode) - - proj_coors_one = flip_image_one_fn(proj_coors_one) - proj_coors_two = flip_image_two_fn(proj_coors_two) - - proj_coors_one, proj_coors_two = map(lambda t: rearrange(t, 'b c h w -> (b h w) c'), (proj_coors_one, proj_coors_two)) - pdist = nn.PairwiseDistance(p = 2) - - num_pixels = proj_coors_one.shape[0] - - proj_coors_one_expanded = proj_coors_one[:, None].expand(num_pixels, num_pixels, -1).reshape(num_pixels * num_pixels, 2) - proj_coors_two_expanded = proj_coors_two[None, :].expand(num_pixels, num_pixels, -1).reshape(num_pixels * num_pixels, 2) - - distance_matrix = pdist(proj_coors_one_expanded, proj_coors_two_expanded) - distance_matrix = distance_matrix.reshape(num_pixels, num_pixels) - - positive_mask_one_two = distance_matrix < self.distance_thres - positive_mask_two_one = positive_mask_one_two.t() + proj_pixel_one, proj_pixel_two = get_shared_region(proj_pixel_one, proj_pixel_two, cutout_coordinates_one, + cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn, + image_one_cutout.shape, self.cutout_interpolate_mode) + sim_region_img_one, sim_region_img_two = get_shared_region(image_one_cutout, image_two_cutout, cutout_coordinates_one, + cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn, + image_one_cutout.shape, self.cutout_interpolate_mode) + if proj_pixel_one is None or proj_pixel_two is None: + positive_pixel_pairs = 0 + else: + positive_pixel_pairs = proj_pixel_one.shape[-1] * proj_pixel_one.shape[-2] with torch.no_grad(): target_encoder = self._get_target_encoder() target_proj_pixel_one, target_proj_instance_one = target_encoder(image_one_cutout) target_proj_pixel_two, target_proj_instance_two = target_encoder(image_two_cutout) + target_proj_pixel_one, target_proj_pixel_two = get_shared_region(target_proj_pixel_one, target_proj_pixel_two, cutout_coordinates_one, + cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn, + image_one_cutout.shape, self.cutout_interpolate_mode) + + # Apply max_latent_dim if needed. + _, _, pp_h, pp_w = proj_pixel_one.shape + if self.max_latent_dim and pp_h > self.max_latent_dim: + margin = pp_h - self.max_latent_dim + loc = random.randint(0, margin) + loce = loc + self.max_latent_dim + proj_pixel_one, proj_pixel_two = proj_pixel_one[:, :, loc:loce, :], proj_pixel_two[:, :, loc:loce, :] + target_proj_pixel_one, target_proj_pixel_two = target_proj_pixel_one[:, :, loc:loce, :], target_proj_pixel_two[:, :, loc:loce, :] + sim_region_img_one, sim_region_img_two = sim_region_img_one[:, :, loc:loce, :], sim_region_img_two[:, :, loc:loce, :] + if self.max_latent_dim and pp_w > self.max_latent_dim: + margin = pp_w - self.max_latent_dim + loc = random.randint(0, margin) + loce = loc + self.max_latent_dim + proj_pixel_one, proj_pixel_two = proj_pixel_one[:, :, :, loc:loce], proj_pixel_two[:, :, :, loc:loce] + target_proj_pixel_one, target_proj_pixel_two = target_proj_pixel_one[:, :, :, loc:loce], target_proj_pixel_two[:, :, :, loc:loce] + sim_region_img_one, sim_region_img_two = sim_region_img_one[:, :, :, loc:loce], sim_region_img_two[:, :, :, loc:loce] + # Stash these away for debugging purposes. + self.sim_region_img_one = sim_region_img_one.detach().clone() + self.sim_region_img_two = sim_region_img_two.detach().clone() # flatten all the pixel projections - flatten = lambda t: rearrange(t, 'b c h w -> b c (h w)') - target_proj_pixel_one, target_proj_pixel_two = list(map(flatten, (target_proj_pixel_one, target_proj_pixel_two))) - # get total number of positive pixel pairs - - positive_pixel_pairs = positive_mask_one_two.sum() - # get instance level loss - pred_instance_one = self.online_predictor(proj_instance_one) pred_instance_two = self.online_predictor(proj_instance_two) - loss_instance_one = loss_fn(pred_instance_one, target_proj_instance_two.detach()) loss_instance_two = loss_fn(pred_instance_two, target_proj_instance_one.detach()) - instance_loss = (loss_instance_one + loss_instance_two).mean() if positive_pixel_pairs == 0: return instance_loss, 0 - if not self.use_pixpro: - # calculate pix contrast loss + # calculate pix pro loss + propagated_pixels_one = self.propagate_pixels(proj_pixel_one) + propagated_pixels_two = self.propagate_pixels(proj_pixel_two) - proj_pixel_one, proj_pixel_two = list(map(flatten, (proj_pixel_one, proj_pixel_two))) + propagated_pixels_one, propagated_pixels_two = list(map(flatten, (propagated_pixels_one, propagated_pixels_two))) - similarity_one_two = F.cosine_similarity(proj_pixel_one[..., :, None], target_proj_pixel_two[..., None, :], dim = 1) / self.similarity_temperature - similarity_two_one = F.cosine_similarity(proj_pixel_two[..., :, None], target_proj_pixel_one[..., None, :], dim = 1) / self.similarity_temperature + propagated_similarity_one_two = F.cosine_similarity(propagated_pixels_one[..., :, None], target_proj_pixel_two[..., None, :], dim = 1) + propagated_similarity_two_one = F.cosine_similarity(propagated_pixels_two[..., :, None], target_proj_pixel_one[..., None, :], dim = 1) - loss_pix_one_two = -torch.log( - similarity_one_two.masked_select(positive_mask_one_two[None, ...]).exp().sum() / - similarity_one_two.exp().sum() - ) + loss_pixpro_one_two = - propagated_similarity_one_two.mean() + loss_pixpro_two_one = - propagated_similarity_two_one.mean() - loss_pix_two_one = -torch.log( - similarity_two_one.masked_select(positive_mask_two_one[None, ...]).exp().sum() / - similarity_two_one.exp().sum() - ) - - pix_loss = (loss_pix_one_two + loss_pix_two_one) / 2 - else: - # calculate pix pro loss - - propagated_pixels_one = self.propagate_pixels(proj_pixel_one) - propagated_pixels_two = self.propagate_pixels(proj_pixel_two) - - propagated_pixels_one, propagated_pixels_two = list(map(flatten, (propagated_pixels_one, propagated_pixels_two))) - - propagated_similarity_one_two = F.cosine_similarity(propagated_pixels_one[..., :, None], target_proj_pixel_two[..., None, :], dim = 1) - propagated_similarity_two_one = F.cosine_similarity(propagated_pixels_two[..., :, None], target_proj_pixel_one[..., None, :], dim = 1) - - loss_pixpro_one_two = - propagated_similarity_one_two.masked_select(positive_mask_one_two[None, ...]).mean() - loss_pixpro_two_one = - propagated_similarity_two_one.masked_select(positive_mask_two_one[None, ...]).mean() - - pix_loss = (loss_pixpro_one_two + loss_pixpro_two_one) / 2 + pix_loss = (loss_pixpro_one_two + loss_pixpro_two_one) / 2 # total loss - loss = pix_loss * self.alpha + instance_loss return loss, positive_pixel_pairs @@ -477,6 +503,8 @@ class PixelCL(nn.Module): return torchvision.utils.save_image(self.aug1, os.path.join(path, "%i_aug1.png" % (step,))) torchvision.utils.save_image(self.aug2, os.path.join(path, "%i_aug2.png" % (step,))) + torchvision.utils.save_image(self.sim_region_img_one, os.path.join(path, "%i_sim1.png" % (step,))) + torchvision.utils.save_image(self.sim_region_img_two, os.path.join(path, "%i_sim2.png" % (step,))) @register_model diff --git a/codes/models/pixel_level_contrastive_learning/resnet_unet_2.py b/codes/models/pixel_level_contrastive_learning/resnet_unet_2.py index 46bc747f..dd7ed05c 100644 --- a/codes/models/pixel_level_contrastive_learning/resnet_unet_2.py +++ b/codes/models/pixel_level_contrastive_learning/resnet_unet_2.py @@ -1,79 +1,16 @@ -# Resnet implementation that adds a u-net style up-conversion component to output values at a -# specified pixel density. -# -# The downsampling part of the network is compatible with the built-in torch resnet for use in -# transfer learning. -# -# Only resnet50 currently supported. - import torch import torch.nn as nn from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1, conv3x3 from torchvision.models.utils import load_state_dict_from_url import torchvision - +from models.arch_util import ConvBnRelu +from models.pixel_level_contrastive_learning.resnet_unet import ReverseBottleneck from trainer.networks import register_model from utils.util import checkpoint, opt_get -class ReverseBottleneck(nn.Module): - - def __init__(self, inplanes, planes, groups=1, passthrough=False, - base_width=64, dilation=1, norm_layer=None): - super().__init__() - if norm_layer is None: - norm_layer = nn.BatchNorm2d - width = int(planes * (base_width / 64.)) * groups - self.passthrough = passthrough - if passthrough: - self.integrate = conv1x1(inplanes*2, inplanes) - self.bn_integrate = norm_layer(inplanes) - # Both self.conv2 and self.downsample layers downsample the input when stride != 1 - self.conv1 = conv1x1(inplanes, width) - self.bn1 = norm_layer(width) - self.conv2 = conv3x3(width, width, groups, dilation) - self.bn2 = norm_layer(width) - self.residual_upsample = nn.Sequential( - nn.Upsample(scale_factor=2, mode='nearest'), - conv1x1(width, width), - norm_layer(width), - ) - self.conv3 = conv1x1(width, planes) - self.bn3 = norm_layer(planes) - self.relu = nn.ReLU(inplace=True) - self.upsample = nn.Sequential( - nn.Upsample(scale_factor=2, mode='nearest'), - conv1x1(inplanes, planes), - norm_layer(planes), - ) - - def forward(self, x, passthrough=None): - if self.passthrough: - x = self.bn_integrate(self.integrate(torch.cat([x, passthrough], dim=1))) - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.residual_upsample(out) - - out = self.conv3(out) - out = self.bn3(out) - - identity = self.upsample(x) - - out = out + identity - out = self.relu(out) - - return out - - -class UResNet50(torchvision.models.resnet.ResNet): +class UResNet50_2(torchvision.models.resnet.ResNet): def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None, @@ -82,6 +19,7 @@ class UResNet50(torchvision.models.resnet.ResNet): replace_stride_with_dilation, norm_layer) if norm_layer is None: norm_layer = nn.BatchNorm2d + self.level_conv = ConvBnRelu(3, 64) ''' # For reference: self.layer1 = self._make_layer(block, 64, layers[0]) @@ -95,29 +33,24 @@ class UResNet50(torchvision.models.resnet.ResNet): uplayers = [] inplanes = 2048 first = True - for i in range(2): - uplayers.append(ReverseBottleneck(inplanes, inplanes // 2, norm_layer=norm_layer, passthrough=not first)) - inplanes = inplanes // 2 + div = [2,2,2,4,1] + for i in range(5): + uplayers.append(ReverseBottleneck(inplanes, inplanes // div[i], norm_layer=norm_layer, passthrough=not first)) + inplanes = inplanes // div[i] first = False self.uplayers = nn.ModuleList(uplayers) - self.tail = nn.Sequential(conv1x1(1024, 512), - norm_layer(512), + self.tail = nn.Sequential(conv3x3(128, 64), + norm_layer(64), nn.ReLU(), - conv3x3(512, 512), - norm_layer(512), - nn.ReLU(), - conv1x1(512, out_dim)) + conv1x1(64, out_dim)) del self.fc # Not used in this implementation and just consumes a ton of GPU memory. def _forward_impl(self, x): - # Should be the exact same implementation of torchvision.models.resnet.ResNet.forward_impl, - # except using checkpoints on the body conv layers. - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - x = self.maxpool(x) + level = self.level_conv(x) + x0 = self.relu(self.bn1(self.conv1(x))) + x = self.maxpool(x0) x1 = checkpoint(self.layer1, x) x2 = checkpoint(self.layer2, x1) @@ -127,18 +60,19 @@ class UResNet50(torchvision.models.resnet.ResNet): x = checkpoint(self.uplayers[0], x4) x = checkpoint(self.uplayers[1], x, x3) - #x = checkpoint(self.uplayers[2], x, x2) - #x = checkpoint(self.uplayers[3], x, x1) + x = checkpoint(self.uplayers[2], x, x2) + x = checkpoint(self.uplayers[3], x, x1) + x = checkpoint(self.uplayers[4], x, x0) - return checkpoint(self.tail, torch.cat([x, x2], dim=1)) + return checkpoint(self.tail, torch.cat([x, level], dim=1)) def forward(self, x): return self._forward_impl(x) @register_model -def register_u_resnet50(opt_net, opt): - model = UResNet50(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim']) +def register_u_resnet50_2(opt_net, opt): + model = UResNet50_2(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim']) if opt_get(opt_net, ['use_pretrained_base'], False): state_dict = load_state_dict_from_url('https://download.pytorch.org/models/resnet50-19c8e357.pth', progress=True) model.load_state_dict(state_dict, strict=False) @@ -146,7 +80,8 @@ def register_u_resnet50(opt_net, opt): if __name__ == '__main__': - model = UResNet50(Bottleneck, [3,4,6,3]) + model = UResNet50_2(Bottleneck, [3,4,6,3]) samp = torch.rand(1,3,224,224) - model(samp) + y = model(samp) + print(y.shape) # For pixpro: attach to "tail.3" diff --git a/codes/scripts/extract_square_images.py b/codes/scripts/extract_square_images.py index fef469d4..e02114c2 100644 --- a/codes/scripts/extract_square_images.py +++ b/codes/scripts/extract_square_images.py @@ -14,16 +14,16 @@ def main(): split_img = False opt = {} opt['n_thread'] = 7 - opt['compression_level'] = 95 # JPEG compression quality rating. + opt['compression_level'] = 90 # JPEG compression quality rating. # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer # compression time. If read raw images during training, use 0 for faster IO speed. opt['dest'] = 'file' opt['input_folder'] = ['F:\\4k6k\\datasets\\ns_images\\imagesets\\pn_coven\\working'] - opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\pn_coven\\cropped' - opt['imgsize'] = 1024 - opt['bottom_crop'] = .1 - opt['keep_folder'] = True + opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\256_unsupervised' + opt['imgsize'] = 256 + opt['bottom_crop'] = 0.1 + opt['keep_folder'] = False save_folder = opt['save_folder'] if not osp.exists(save_folder): @@ -58,7 +58,7 @@ class TiledDataset(data.Dataset): # Perform explicit crops first. These are generally used to get rid of watermarks so we dont even want to # consider these areas of the image. - if 'bottom_crop' in self.opt.keys(): + if 'bottom_crop' in self.opt.keys() and self.opt['bottom_crop'] > 0: bc = self.opt['bottom_crop'] if bc > 0 and bc < 1: bc = int(bc * img.shape[0]) @@ -83,9 +83,7 @@ class TiledDataset(data.Dataset): pts = os.path.split(pts[0]) output_folder = osp.join(self.opt['save_folder'], pts[-1]) os.makedirs(output_folder, exist_ok=True) - if not basename.endswith(".jpg"): - basename = basename + ".jpg" - cv2.imwrite(osp.join(output_folder, basename + ".jpg"), img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']]) + cv2.imwrite(osp.join(output_folder, basename), img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']]) return None def __len__(self): diff --git a/codes/train.py b/codes/train.py index 59d821c9..34b1ae76 100644 --- a/codes/train.py +++ b/codes/train.py @@ -295,7 +295,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imagenet_pixpro_resnet.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imagenet_resnet50.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args()