diff --git a/codes/models/pixel_level_contrastive_learning/__init__.py b/codes/models/pixel_level_contrastive_learning/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/codes/models/pixel_level_contrastive_learning/pixpro_lucidrains.py b/codes/models/pixel_level_contrastive_learning/pixpro_lucidrains.py new file mode 100644 index 00000000..9d13da31 --- /dev/null +++ b/codes/models/pixel_level_contrastive_learning/pixpro_lucidrains.py @@ -0,0 +1,487 @@ +import math +import copy +import os +import random +from functools import wraps, partial +from math import floor + +import torch +import torchvision +from torch import nn, einsum +import torch.nn.functional as F + +from kornia import augmentation as augs +from kornia import filters, color + +from einops import rearrange + +# helper functions +from trainer.networks import register_model, create_model + + +def identity(t): + return t + +def default(val, def_val): + return def_val if val is None else val + +def rand_true(prob): + return random.random() < prob + +def singleton(cache_key): + def inner_fn(fn): + @wraps(fn) + def wrapper(self, *args, **kwargs): + instance = getattr(self, cache_key) + if instance is not None: + return instance + + instance = fn(self, *args, **kwargs) + setattr(self, cache_key, instance) + return instance + return wrapper + return inner_fn + +def get_module_device(module): + return next(module.parameters()).device + +def set_requires_grad(model, val): + for p in model.parameters(): + p.requires_grad = val + +def cutout_coordinates(image, ratio_range = (0.6, 0.8)): + _, _, orig_h, orig_w = image.shape + + ratio_lo, ratio_hi = ratio_range + random_ratio = ratio_lo + random.random() * (ratio_hi - ratio_lo) + w, h = floor(random_ratio * orig_w), floor(random_ratio * orig_h) + coor_x = floor((orig_w - w) * random.random()) + coor_y = floor((orig_h - h) * random.random()) + return ((coor_y, coor_y + h), (coor_x, coor_x + w)), random_ratio + +def cutout_and_resize(image, coordinates, output_size = None, mode = 'nearest'): + shape = image.shape + output_size = default(output_size, shape[2:]) + (y0, y1), (x0, x1) = coordinates + cutout_image = image[:, :, y0:y1, x0:x1] + return F.interpolate(cutout_image, size = output_size, mode = mode) + +# augmentation utils + +class RandomApply(nn.Module): + def __init__(self, fn, p): + super().__init__() + self.fn = fn + self.p = p + def forward(self, x): + if random.random() > self.p: + return x + return self.fn(x) + +# exponential moving average + +class EMA(): + def __init__(self, beta): + super().__init__() + self.beta = beta + + def update_average(self, old, new): + if old is None: + return new + return old * self.beta + (1 - self.beta) * new + +def update_moving_average(ema_updater, ma_model, current_model): + for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): + old_weight, up_weight = ma_params.data, current_params.data + ma_params.data = ema_updater.update_average(old_weight, up_weight) + +# loss fn + +def loss_fn(x, y): + x = F.normalize(x, dim=-1, p=2) + y = F.normalize(y, dim=-1, p=2) + return 2 - 2 * (x * y).sum(dim=-1) + +# classes + +class MLP(nn.Module): + def __init__(self, chan, chan_out = 256, inner_dim = 2048): + super().__init__() + self.net = nn.Sequential( + nn.Linear(chan, inner_dim), + nn.BatchNorm1d(inner_dim), + nn.ReLU(), + nn.Linear(inner_dim, chan_out) + ) + + def forward(self, x): + return self.net(x) + +class ConvMLP(nn.Module): + def __init__(self, chan, chan_out = 256, inner_dim = 2048): + super().__init__() + self.net = nn.Sequential( + nn.Conv2d(chan, inner_dim, 1), + nn.BatchNorm2d(inner_dim), + nn.ReLU(), + nn.Conv2d(inner_dim, chan_out, 1) + ) + + def forward(self, x): + return self.net(x) + +class PPM(nn.Module): + def __init__( + self, + *, + chan, + num_layers = 1, + gamma = 2): + super().__init__() + self.gamma = gamma + + if num_layers == 0: + self.transform_net = nn.Identity() + elif num_layers == 1: + self.transform_net = nn.Conv2d(chan, chan, 1) + elif num_layers == 2: + self.transform_net = nn.Sequential( + nn.Conv2d(chan, chan, 1), + nn.BatchNorm2d(chan), + nn.ReLU(), + nn.Conv2d(chan, chan, 1) + ) + else: + raise ValueError('num_layers must be one of 0, 1, or 2') + + def forward(self, x): + xi = x[:, :, :, :, None, None] + xj = x[:, :, None, None, :, :] + similarity = F.relu(F.cosine_similarity(xi, xj, dim = 1)) ** self.gamma + + transform_out = self.transform_net(x) + out = einsum('b x y h w, b c h w -> b c x y', similarity, transform_out) + return out + +# a wrapper class for the base neural network +# will manage the interception of the hidden layer output +# and pipe it into the projecter and predictor nets + +class NetWrapper(nn.Module): + def __init__( + self, + *, + net, + projection_size, + projection_hidden_size, + layer_pixel = -2, + layer_instance = -2 + ): + super().__init__() + self.net = net + self.layer_pixel = layer_pixel + self.layer_instance = layer_instance + + self.pixel_projector = None + self.instance_projector = None + + self.projection_size = projection_size + self.projection_hidden_size = projection_hidden_size + + self.hidden_pixel = None + self.hidden_instance = None + self.hook_registered = False + + def _find_layer(self, layer_id): + if type(layer_id) == str: + modules = dict([*self.net.named_modules()]) + return modules.get(layer_id, None) + elif type(layer_id) == int: + children = [*self.net.children()] + return children[layer_id] + return None + + def _hook(self, attr_name, _, __, output): + setattr(self, attr_name, output) + + def _register_hook(self): + pixel_layer = self._find_layer(self.layer_pixel) + instance_layer = self._find_layer(self.layer_instance) + + assert pixel_layer is not None, f'hidden layer ({self.layer_pixel}) not found' + assert instance_layer is not None, f'hidden layer ({self.layer_instance}) not found' + + pixel_layer.register_forward_hook(partial(self._hook, 'hidden_pixel')) + instance_layer.register_forward_hook(partial(self._hook, 'hidden_instance')) + self.hook_registered = True + + @singleton('pixel_projector') + def _get_pixel_projector(self, hidden): + _, dim, *_ = hidden.shape + projector = ConvMLP(dim, self.projection_size, self.projection_hidden_size) + return projector.to(hidden) + + @singleton('instance_projector') + def _get_instance_projector(self, hidden): + _, dim = hidden.shape + projector = MLP(dim, self.projection_size, self.projection_hidden_size) + return projector.to(hidden) + + def get_representation(self, x): + if not self.hook_registered: + self._register_hook() + + _ = self.net(x) + hidden_pixel = self.hidden_pixel + hidden_instance = self.hidden_instance + self.hidden_pixel = None + self.hidden_instance = None + assert hidden_pixel is not None, f'hidden pixel layer {self.layer_pixel} never emitted an output' + assert hidden_instance is not None, f'hidden instance layer {self.layer_instance} never emitted an output' + return hidden_pixel, hidden_instance + + def forward(self, x): + pixel_representation, instance_representation = self.get_representation(x) + instance_representation = instance_representation.flatten(1) + + pixel_projector = self._get_pixel_projector(pixel_representation) + instance_projector = self._get_instance_projector(instance_representation) + + pixel_projection = pixel_projector(pixel_representation) + instance_projection = instance_projector(instance_representation) + return pixel_projection, instance_projection + +# main class + +class PixelCL(nn.Module): + def __init__( + self, + net, + image_size, + hidden_layer_pixel = -2, + hidden_layer_instance = -2, + projection_size = 256, + projection_hidden_size = 2048, + augment_fn = None, + augment_fn2 = None, + prob_rand_hflip = 0.25, + moving_average_decay = 0.99, + ppm_num_layers = 1, + ppm_gamma = 2, + distance_thres = 0.7, + similarity_temperature = 0.3, + alpha = 1., + use_pixpro = True, + cutout_ratio_range = (0.6, 0.8), + cutout_interpolate_mode = 'nearest', + coord_cutout_interpolate_mode = 'bilinear' + ): + super().__init__() + + DEFAULT_AUG = nn.Sequential( + RandomApply(augs.ColorJitter(0.3, 0.3, 0.3, 0.2), p=0.8), + augs.RandomGrayscale(p=0.2), + RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1) + ) + + self.augment1 = default(augment_fn, DEFAULT_AUG) + self.augment2 = default(augment_fn2, self.augment1) + self.prob_rand_hflip = prob_rand_hflip + + self.online_encoder = NetWrapper( + net = net, + projection_size = projection_size, + projection_hidden_size = projection_hidden_size, + layer_pixel = hidden_layer_pixel, + layer_instance = hidden_layer_instance + ) + + self.target_encoder = None + self.target_ema_updater = EMA(moving_average_decay) + + self.distance_thres = distance_thres + self.similarity_temperature = similarity_temperature + self.alpha = alpha + + self.use_pixpro = use_pixpro + + if use_pixpro: + self.propagate_pixels = PPM( + chan = projection_size, + num_layers = ppm_num_layers, + gamma = ppm_gamma + ) + + self.cutout_ratio_range = cutout_ratio_range + self.cutout_interpolate_mode = cutout_interpolate_mode + self.coord_cutout_interpolate_mode = coord_cutout_interpolate_mode + + # instance level predictor + self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size) + + # get device of network and make wrapper same device + device = get_module_device(net) + self.to(device) + + # send a mock image tensor to instantiate singleton parameters + self.forward(torch.randn(2, 3, image_size, image_size, device=device)) + + @singleton('target_encoder') + def _get_target_encoder(self): + target_encoder = copy.deepcopy(self.online_encoder) + set_requires_grad(target_encoder, False) + return target_encoder + + def reset_moving_average(self): + del self.target_encoder + self.target_encoder = None + + def update_moving_average(self): + assert self.target_encoder is not None, 'target encoder has not been created yet' + update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder) + + def forward(self, x): + shape, device, prob_flip = x.shape, x.device, self.prob_rand_hflip + + rand_flip_fn = lambda t: torch.flip(t, dims = (-1,)) + + flip_image_one, flip_image_two = rand_true(prob_flip), rand_true(prob_flip) + flip_image_one_fn = rand_flip_fn if flip_image_one else identity + flip_image_two_fn = rand_flip_fn if flip_image_two else identity + + cutout_coordinates_one, _ = cutout_coordinates(x, self.cutout_ratio_range) + cutout_coordinates_two, _ = cutout_coordinates(x, self.cutout_ratio_range) + + image_one_cutout = cutout_and_resize(x, cutout_coordinates_one, mode = self.cutout_interpolate_mode) + image_two_cutout = cutout_and_resize(x, cutout_coordinates_two, mode = self.cutout_interpolate_mode) + + image_one_cutout = flip_image_one_fn(image_one_cutout) + image_two_cutout = flip_image_two_fn(image_two_cutout) + + image_one_cutout, image_two_cutout = self.augment1(image_one_cutout), self.augment2(image_two_cutout) + + self.aug1 = image_one_cutout.detach().clone() + self.aug2 = image_two_cutout.detach().clone() + + proj_pixel_one, proj_instance_one = self.online_encoder(image_one_cutout) + proj_pixel_two, proj_instance_two = self.online_encoder(image_two_cutout) + + image_h, image_w = shape[2:] + + proj_image_shape = proj_pixel_one.shape[2:] + proj_image_h, proj_image_w = proj_image_shape + + coordinates = torch.meshgrid( + torch.arange(image_h, device = device), + torch.arange(image_w, device = device) + ) + + coordinates = torch.stack(coordinates).unsqueeze(0).float() + coordinates /= math.sqrt(image_h ** 2 + image_w ** 2) + coordinates[:, 0] *= proj_image_h + coordinates[:, 1] *= proj_image_w + + proj_coors_one = cutout_and_resize(coordinates, cutout_coordinates_one, output_size = proj_image_shape, mode = self.coord_cutout_interpolate_mode) + proj_coors_two = cutout_and_resize(coordinates, cutout_coordinates_two, output_size = proj_image_shape, mode = self.coord_cutout_interpolate_mode) + + proj_coors_one = flip_image_one_fn(proj_coors_one) + proj_coors_two = flip_image_two_fn(proj_coors_two) + + proj_coors_one, proj_coors_two = map(lambda t: rearrange(t, 'b c h w -> (b h w) c'), (proj_coors_one, proj_coors_two)) + pdist = nn.PairwiseDistance(p = 2) + + num_pixels = proj_coors_one.shape[0] + + proj_coors_one_expanded = proj_coors_one[:, None].expand(num_pixels, num_pixels, -1).reshape(num_pixels * num_pixels, 2) + proj_coors_two_expanded = proj_coors_two[None, :].expand(num_pixels, num_pixels, -1).reshape(num_pixels * num_pixels, 2) + + distance_matrix = pdist(proj_coors_one_expanded, proj_coors_two_expanded) + distance_matrix = distance_matrix.reshape(num_pixels, num_pixels) + + positive_mask_one_two = distance_matrix < self.distance_thres + positive_mask_two_one = positive_mask_one_two.t() + + with torch.no_grad(): + target_encoder = self._get_target_encoder() + target_proj_pixel_one, target_proj_instance_one = target_encoder(image_one_cutout) + target_proj_pixel_two, target_proj_instance_two = target_encoder(image_two_cutout) + + # flatten all the pixel projections + + flatten = lambda t: rearrange(t, 'b c h w -> b c (h w)') + + target_proj_pixel_one, target_proj_pixel_two = list(map(flatten, (target_proj_pixel_one, target_proj_pixel_two))) + + # get total number of positive pixel pairs + + positive_pixel_pairs = positive_mask_one_two.sum() + + # get instance level loss + + pred_instance_one = self.online_predictor(proj_instance_one) + pred_instance_two = self.online_predictor(proj_instance_two) + + loss_instance_one = loss_fn(pred_instance_one, target_proj_instance_two.detach()) + loss_instance_two = loss_fn(pred_instance_two, target_proj_instance_one.detach()) + + instance_loss = (loss_instance_one + loss_instance_two).mean() + + if positive_pixel_pairs == 0: + return instance_loss, 0 + + if not self.use_pixpro: + # calculate pix contrast loss + + proj_pixel_one, proj_pixel_two = list(map(flatten, (proj_pixel_one, proj_pixel_two))) + + similarity_one_two = F.cosine_similarity(proj_pixel_one[..., :, None], target_proj_pixel_two[..., None, :], dim = 1) / self.similarity_temperature + similarity_two_one = F.cosine_similarity(proj_pixel_two[..., :, None], target_proj_pixel_one[..., None, :], dim = 1) / self.similarity_temperature + + loss_pix_one_two = -torch.log( + similarity_one_two.masked_select(positive_mask_one_two[None, ...]).exp().sum() / + similarity_one_two.exp().sum() + ) + + loss_pix_two_one = -torch.log( + similarity_two_one.masked_select(positive_mask_two_one[None, ...]).exp().sum() / + similarity_two_one.exp().sum() + ) + + pix_loss = (loss_pix_one_two + loss_pix_two_one) / 2 + else: + # calculate pix pro loss + + propagated_pixels_one = self.propagate_pixels(proj_pixel_one) + propagated_pixels_two = self.propagate_pixels(proj_pixel_two) + + propagated_pixels_one, propagated_pixels_two = list(map(flatten, (propagated_pixels_one, propagated_pixels_two))) + + propagated_similarity_one_two = F.cosine_similarity(propagated_pixels_one[..., :, None], target_proj_pixel_two[..., None, :], dim = 1) + propagated_similarity_two_one = F.cosine_similarity(propagated_pixels_two[..., :, None], target_proj_pixel_one[..., None, :], dim = 1) + + loss_pixpro_one_two = - propagated_similarity_one_two.masked_select(positive_mask_one_two[None, ...]).mean() + loss_pixpro_two_one = - propagated_similarity_two_one.masked_select(positive_mask_two_one[None, ...]).mean() + + pix_loss = (loss_pixpro_one_two + loss_pixpro_two_one) / 2 + + # total loss + + loss = pix_loss * self.alpha + instance_loss + return loss, positive_pixel_pairs + + # Allows visualizing what the augmentor is up to. + def visual_dbg(self, step, path): + if not hasattr(self, 'aug1'): + return + torchvision.utils.save_image(self.aug1, os.path.join(path, "%i_aug1.png" % (step,))) + torchvision.utils.save_image(self.aug2, os.path.join(path, "%i_aug2.png" % (step,))) + + +@register_model +def register_pixel_contrastive_learner(opt_net, opt): + subnet = create_model(opt, opt_net['subnet']) + kwargs = opt_net['kwargs'] + if 'subnet_pretrain_path' in opt_net.keys(): + sd = torch.load(opt_net['subnet_pretrain_path']) + subnet.load_state_dict(sd, strict=False) + return PixelCL(subnet, **kwargs) diff --git a/codes/models/pixel_level_contrastive_learning/resnet_unet.py b/codes/models/pixel_level_contrastive_learning/resnet_unet.py new file mode 100644 index 00000000..e9fdbfaa --- /dev/null +++ b/codes/models/pixel_level_contrastive_learning/resnet_unet.py @@ -0,0 +1,153 @@ +# Resnet implementation that adds a u-net style up-conversion component to output values at a +# specified pixel density. +# +# The downsampling part of the network is compatible with the built-in torch resnet for use in +# transfer learning. +# +# Only resnet50 currently supported. + +import torch +import torch.nn as nn +from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1, conv3x3 +from torchvision.models.utils import load_state_dict_from_url +import torchvision + + +from trainer.networks import register_model +from utils.util import checkpoint + +model_urls = { + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', +} + + +class ReverseBottleneck(nn.Module): + + def __init__(self, inplanes, planes, groups=1, passthrough=False, + base_width=64, dilation=1, norm_layer=None): + super().__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + self.passthrough = passthrough + if passthrough: + self.integrate = conv1x1(inplanes*2, inplanes) + self.bn_integrate = norm_layer(inplanes) + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, groups, dilation) + self.bn2 = norm_layer(width) + self.residual_upsample = nn.Sequential( + nn.Upsample(scale_factor=2, mode='nearest'), + conv1x1(width, width), + norm_layer(width), + ) + self.conv3 = conv1x1(width, planes) + self.bn3 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.upsample = nn.Sequential( + nn.Upsample(scale_factor=2, mode='nearest'), + conv1x1(inplanes, planes), + norm_layer(planes), + ) + + def forward(self, x, passthrough=None): + if self.passthrough: + x = self.bn_integrate(self.integrate(torch.cat([x, passthrough], dim=1))) + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.residual_upsample(out) + + out = self.conv3(out) + out = self.bn3(out) + + identity = self.upsample(x) + + out = out + identity + out = self.relu(out) + + return out + + +class UResNet50(torchvision.models.resnet.ResNet): + + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, + norm_layer=None): + super().__init__(block, layers, num_classes, zero_init_residual, groups, width_per_group, + replace_stride_with_dilation, norm_layer) + if norm_layer is None: + norm_layer = nn.BatchNorm2d + ''' + # For reference: + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + ''' + uplayers = [] + inplanes = 2048 + first = True + for i in range(2): + uplayers.append(ReverseBottleneck(inplanes, inplanes // 2, norm_layer=norm_layer, passthrough=not first)) + inplanes = inplanes // 2 + first = False + self.uplayers = nn.ModuleList(uplayers) + self.tail = nn.Sequential(conv1x1(1024, 512), + norm_layer(512), + nn.ReLU(), + conv3x3(512, 512), + norm_layer(512), + nn.ReLU(), + conv1x1(512, 128)) + + del self.fc # Not used in this implementation and just consumes a ton of GPU memory. + + + def _forward_impl(self, x): + # Should be the exact same implementation of torchvision.models.resnet.ResNet.forward_impl, + # except using checkpoints on the body conv layers. + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x1 = checkpoint(self.layer1, x) + x2 = checkpoint(self.layer2, x1) + x3 = checkpoint(self.layer3, x2) + x4 = checkpoint(self.layer4, x3) + unused = self.avgpool(x4) # This is performed for instance-level pixpro learning, even though it is unused. + + x = checkpoint(self.uplayers[0], x4) + x = checkpoint(self.uplayers[1], x, x3) + #x = checkpoint(self.uplayers[2], x, x2) + #x = checkpoint(self.uplayers[3], x, x1) + + return checkpoint(self.tail, torch.cat([x, x2], dim=1)) + + def forward(self, x): + return self._forward_impl(x) + + +@register_model +def register_u_resnet50(opt_net, opt): + model = UResNet50(Bottleneck, [3, 4, 6, 3]) + return model + + +if __name__ == '__main__': + model = UResNet50(Bottleneck, [3,4,6,3]) + samp = torch.rand(1,3,224,224) + model(samp) + # For pixpro: attach to "tail.3" diff --git a/codes/scripts/byol_extract_wrapped_model.py b/codes/scripts/byol_extract_wrapped_model.py index 0b5147c4..65652f16 100644 --- a/codes/scripts/byol_extract_wrapped_model.py +++ b/codes/scripts/byol_extract_wrapped_model.py @@ -3,8 +3,8 @@ import torch from models.spinenet_arch import SpineNet if __name__ == '__main__': - pretrained_path = '../../experiments/byol_discriminator.pth' - output_path = '../../experiments/byol_discriminator_extracted.pth' + pretrained_path = '../../experiments/resnet_byol_diffframe_115k.pth' + output_path = '../../experiments/resnet_byol_diffframe_115k_.pth' wrap_key = 'online_encoder.net.' sd = torch.load(pretrained_path) diff --git a/codes/scripts/extract_subimages_with_ref.py b/codes/scripts/extract_subimages_with_ref.py index c6dd1182..a7df57c3 100644 --- a/codes/scripts/extract_subimages_with_ref.py +++ b/codes/scripts/extract_subimages_with_ref.py @@ -19,13 +19,13 @@ def main(): # compression time. If read raw images during training, use 0 for faster IO speed. opt['dest'] = 'file' - opt['input_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\images' - opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\512_with_ref_new' - opt['crop_sz'] = [1024, 2048] # the size of each sub-image - opt['step'] = [700, 1200] # step of the sliding crop window - opt['exclusions'] = [[],[],[]] # image names matching these terms wont be included in the processing. - opt['thres_sz'] = 256 # size threshold - opt['resize_final_img'] = [.5, .25] + opt['input_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new' + opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\256_with_ref_v5' + opt['crop_sz'] = [256, 512] # the size of each sub-image + opt['step'] = [256, 512] # step of the sliding crop window + opt['exclusions'] = [[],[]] # image names matching these terms wont be included in the processing. + opt['thres_sz'] = 129 # size threshold + opt['resize_final_img'] = [1, .5] opt['only_resize'] = False opt['vertical_split'] = False opt['input_image_max_size_before_being_halved'] = 5500 # As described, images larger than this dimensional size will be halved before anything else is done. diff --git a/codes/train.py b/codes/train.py index 6d5ca881..2cef4362 100644 --- a/codes/train.py +++ b/codes/train.py @@ -295,7 +295,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_xxfaces_styled_sr/train_xxfaces_styled_sr.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_pixpro_resnet.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args()