diff --git a/codes/models/global_convs/gc_resnet.py b/codes/models/global_convs/gc_resnet.py deleted file mode 100644 index e69de29b..00000000 diff --git a/codes/models/pixel_level_contrastive_learning/__init__.py b/codes/models/pixel_level_contrastive_learning/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/codes/models/pixel_level_contrastive_learning/pixpro_lucidrains.py b/codes/models/pixel_level_contrastive_learning/pixpro_lucidrains.py deleted file mode 100644 index 75c28a46..00000000 --- a/codes/models/pixel_level_contrastive_learning/pixpro_lucidrains.py +++ /dev/null @@ -1,507 +0,0 @@ -import math -import copy -import os -import random -from functools import wraps, partial -from math import floor - -import torch -import torchvision -from torch import nn, einsum -import torch.nn.functional as F - -from kornia import augmentation as augs -from kornia import filters, color - -from einops import rearrange - -# helper functions -from trainer.networks import register_model, create_model - - -def identity(t): - return t - -def default(val, def_val): - return def_val if val is None else val - -def rand_true(prob): - return random.random() < prob - -def singleton(cache_key): - def inner_fn(fn): - @wraps(fn) - def wrapper(self, *args, **kwargs): - instance = getattr(self, cache_key) - if instance is not None: - return instance - - instance = fn(self, *args, **kwargs) - setattr(self, cache_key, instance) - return instance - return wrapper - return inner_fn - -def get_module_device(module): - return next(module.parameters()).device - -def set_requires_grad(model, val): - for p in model.parameters(): - p.requires_grad = val - -def cutout_coordinates(image, ratio_range = (0.6, 0.8)): - _, _, orig_h, orig_w = image.shape - - ratio_lo, ratio_hi = ratio_range - random_ratio = ratio_lo + random.random() * (ratio_hi - ratio_lo) - w, h = floor(random_ratio * orig_w), floor(random_ratio * orig_h) - coor_x = floor((orig_w - w) * random.random()) - coor_y = floor((orig_h - h) * random.random()) - return ((coor_y, coor_y + h), (coor_x, coor_x + w)), random_ratio - -def cutout_and_resize(image, coordinates, output_size = None, mode = 'nearest'): - shape = image.shape - output_size = default(output_size, shape[2:]) - (y0, y1), (x0, x1) = coordinates - cutout_image = image[:, :, y0:y1, x0:x1] - return F.interpolate(cutout_image, size = output_size, mode = mode) - -def scale_coords(coords, scale): - output = [[0,0],[0,0]] - for j in range(2): - for k in range(2): - output[j][k] = int(coords[j][k] / scale) - return output - -def reverse_cutout_and_resize(image, coordinates, scale_reduction, mode = 'nearest'): - blank = torch.zeros_like(image) - coordinates = scale_coords(coordinates, scale_reduction) - (y0, y1), (x0, x1) = coordinates - orig_cutout_shape = (y1-y0, x1-x0) - if orig_cutout_shape[0] <= 0 or orig_cutout_shape[1] <= 0: - return None - - un_resized_img = F.interpolate(image, size=orig_cutout_shape, mode=mode) - blank[:,:,y0:y1,x0:x1] = un_resized_img - return blank - -def compute_shared_coords(coords1, coords2, scale_reduction): - (y1_t, y1_b), (x1_l, x1_r) = scale_coords(coords1, scale_reduction) - (y2_t, y2_b), (x2_l, x2_r) = scale_coords(coords2, scale_reduction) - shared = ((max(y1_t, y2_t), min(y1_b, y2_b)), - (max(x1_l, x2_l), min(x1_r, x2_r))) - for s in shared: - if s == 0: - return None - return shared - -def get_shared_region(proj_pixel_one, proj_pixel_two, cutout_coordinates_one, cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn, img_orig_shape, 
interp_mode): - # Unflip the pixel projections - proj_pixel_one = flip_image_one_fn(proj_pixel_one) - proj_pixel_two = flip_image_two_fn(proj_pixel_two) - - # Undo the cutout and resize, taking into account the scale reduction applied by the encoder. - scale_reduction = proj_pixel_one.shape[-1] / img_orig_shape[-1] - proj_pixel_one = reverse_cutout_and_resize(proj_pixel_one, cutout_coordinates_one, scale_reduction, - mode=interp_mode) - proj_pixel_two = reverse_cutout_and_resize(proj_pixel_two, cutout_coordinates_two, scale_reduction, - mode=interp_mode) - if proj_pixel_one is None or proj_pixel_two is None: - print("Could not extract projected image region. The selected cutout coordinates were smaller than the aggregate size of one latent block!") - return None - - # Compute the shared coordinates for the two cutouts: - shared_coords = compute_shared_coords(cutout_coordinates_one, cutout_coordinates_two, scale_reduction) - if shared_coords is None: - print("No shared coordinates for this iteration (probably should just recompute those coordinates earlier..") - return None - (yt, yb), (xl, xr) = shared_coords - - return proj_pixel_one[:, :, yt:yb, xl:xr], proj_pixel_two[:, :, yt:yb, xl:xr] - -# augmentation utils - -class RandomApply(nn.Module): - def __init__(self, fn, p): - super().__init__() - self.fn = fn - self.p = p - def forward(self, x): - if random.random() > self.p: - return x - return self.fn(x) - -# exponential moving average - -class EMA(): - def __init__(self, beta): - super().__init__() - self.beta = beta - - def update_average(self, old, new): - if old is None: - return new - return old * self.beta + (1 - self.beta) * new - -def update_moving_average(ema_updater, ma_model, current_model): - for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): - old_weight, up_weight = ma_params.data, current_params.data - ma_params.data = ema_updater.update_average(old_weight, up_weight) - -# loss fn - -def loss_fn(x, y): - x = F.normalize(x, dim=-1, p=2) - y = F.normalize(y, dim=-1, p=2) - return 2 - 2 * (x * y).sum(dim=-1) - -# classes - -class MLP(nn.Module): - def __init__(self, chan, chan_out = 256, inner_dim = 2048): - super().__init__() - self.net = nn.Sequential( - nn.Linear(chan, inner_dim), - nn.BatchNorm1d(inner_dim), - nn.ReLU(), - nn.Linear(inner_dim, chan_out) - ) - - def forward(self, x): - return self.net(x) - -class ConvMLP(nn.Module): - def __init__(self, chan, chan_out = 256, inner_dim = 2048): - super().__init__() - self.net = nn.Sequential( - nn.Conv2d(chan, inner_dim, 1), - nn.BatchNorm2d(inner_dim), - nn.ReLU(), - nn.Conv2d(inner_dim, chan_out, 1) - ) - - def forward(self, x): - return self.net(x) - -class PPM(nn.Module): - def __init__( - self, - *, - chan, - num_layers = 1, - gamma = 2): - super().__init__() - self.gamma = gamma - - if num_layers == 0: - self.transform_net = nn.Identity() - elif num_layers == 1: - self.transform_net = nn.Conv2d(chan, chan, 1) - elif num_layers == 2: - self.transform_net = nn.Sequential( - nn.Conv2d(chan, chan, 1), - nn.BatchNorm2d(chan), - nn.ReLU(), - nn.Conv2d(chan, chan, 1) - ) - else: - raise ValueError('num_layers must be one of 0, 1, or 2') - - def forward(self, x): - xi = x[:, :, :, :, None, None] - xj = x[:, :, None, None, :, :] - similarity = F.relu(F.cosine_similarity(xi, xj, dim = 1)) ** self.gamma - - transform_out = self.transform_net(x) - out = einsum('b x y h w, b c h w -> b c x y', similarity, transform_out) - return out - -# a wrapper class for the base neural network -# will 
manage the interception of the hidden layer output -# and pipe it into the projecter and predictor nets - -class NetWrapper(nn.Module): - def __init__( - self, - *, - net, - instance_projection_size, - instance_projection_hidden_size, - pix_projection_size, - pix_projection_hidden_size, - layer_pixel = -2, - layer_instance = -2 - ): - super().__init__() - self.net = net - self.layer_pixel = layer_pixel - self.layer_instance = layer_instance - - self.pixel_projector = None - self.instance_projector = None - - self.instance_projection_size = instance_projection_size - self.instance_projection_hidden_size = instance_projection_hidden_size - self.pix_projection_size = pix_projection_size - self.pix_projection_hidden_size = pix_projection_hidden_size - - self.hidden_pixel = None - self.hidden_instance = None - self.hook_registered = False - - def _find_layer(self, layer_id): - if type(layer_id) == str: - modules = dict([*self.net.named_modules()]) - return modules.get(layer_id, None) - elif type(layer_id) == int: - children = [*self.net.children()] - return children[layer_id] - return None - - def _hook(self, attr_name, _, __, output): - setattr(self, attr_name, output) - - def _register_hook(self): - pixel_layer = self._find_layer(self.layer_pixel) - instance_layer = self._find_layer(self.layer_instance) - - assert pixel_layer is not None, f'hidden layer ({self.layer_pixel}) not found' - assert instance_layer is not None, f'hidden layer ({self.layer_instance}) not found' - - pixel_layer.register_forward_hook(partial(self._hook, 'hidden_pixel')) - instance_layer.register_forward_hook(partial(self._hook, 'hidden_instance')) - self.hook_registered = True - - @singleton('pixel_projector') - def _get_pixel_projector(self, hidden): - _, dim, *_ = hidden.shape - projector = ConvMLP(dim, self.pix_projection_size, self.pix_projection_hidden_size) - return projector.to(hidden) - - @singleton('instance_projector') - def _get_instance_projector(self, hidden): - _, dim = hidden.shape - projector = MLP(dim, self.instance_projection_size, self.instance_projection_hidden_size) - return projector.to(hidden) - - def get_representation(self, x): - if not self.hook_registered: - self._register_hook() - - _ = self.net(x) - hidden_pixel = self.hidden_pixel - hidden_instance = self.hidden_instance - self.hidden_pixel = None - self.hidden_instance = None - assert hidden_pixel is not None, f'hidden pixel layer {self.layer_pixel} never emitted an output' - assert hidden_instance is not None, f'hidden instance layer {self.layer_instance} never emitted an output' - return hidden_pixel, hidden_instance - - def forward(self, x): - pixel_representation, instance_representation = self.get_representation(x) - instance_representation = instance_representation.flatten(1) - - pixel_projector = self._get_pixel_projector(pixel_representation) - instance_projector = self._get_instance_projector(instance_representation) - - pixel_projection = pixel_projector(pixel_representation) - instance_projection = instance_projector(instance_representation) - return pixel_projection, instance_projection - -# main class -class PixelCL(nn.Module): - def __init__( - self, - net, - image_size, - hidden_layer_pixel = -2, - hidden_layer_instance = -2, - instance_projection_size = 256, - instance_projection_hidden_size = 2048, - pix_projection_size = 256, - pix_projection_hidden_size = 2048, - augment_fn = None, - augment_fn2 = None, - prob_rand_hflip = 0.25, - moving_average_decay = 0.99, - ppm_num_layers = 1, - ppm_gamma = 2, - distance_thres = 
0.7, - similarity_temperature = 0.3, - cutout_ratio_range = (0.6, 0.8), - cutout_interpolate_mode = 'nearest', - coord_cutout_interpolate_mode = 'bilinear', - max_latent_dim = None # When set, this is the number of stochastically extracted pixels from the latent to extract. Must have an integer square root. - ): - super().__init__() - - DEFAULT_AUG = nn.Sequential( - RandomApply(augs.ColorJitter(0.6, 0.6, 0.6, 0.2), p=0.8), - augs.RandomGrayscale(p=0.2), - RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1), - augs.RandomSolarize(p=0.5), - # Normalize left out because it should be done at the model level. - ) - - self.augment1 = default(augment_fn, DEFAULT_AUG) - self.augment2 = default(augment_fn2, self.augment1) - self.prob_rand_hflip = prob_rand_hflip - - self.online_encoder = NetWrapper( - net = net, - instance_projection_size = instance_projection_size, - instance_projection_hidden_size = instance_projection_hidden_size, - pix_projection_size = pix_projection_size, - pix_projection_hidden_size = pix_projection_hidden_size, - layer_pixel = hidden_layer_pixel, - layer_instance = hidden_layer_instance - ) - - self.target_encoder = None - self.target_ema_updater = EMA(moving_average_decay) - - self.distance_thres = distance_thres - self.similarity_temperature = similarity_temperature - - # This requirement is due to the way that these are processed, not a hard requirement. - assert math.sqrt(max_latent_dim) == int(math.sqrt(max_latent_dim)) - self.max_latent_dim = max_latent_dim - - self.propagate_pixels = PPM( - chan = pix_projection_size, - num_layers = ppm_num_layers, - gamma = ppm_gamma - ) - - self.cutout_ratio_range = cutout_ratio_range - self.cutout_interpolate_mode = cutout_interpolate_mode - self.coord_cutout_interpolate_mode = coord_cutout_interpolate_mode - - # instance level predictor - self.online_predictor = MLP(instance_projection_size, instance_projection_size, instance_projection_hidden_size) - - # get device of network and make wrapper same device - device = get_module_device(net) - self.to(device) - - # send a mock image tensor to instantiate singleton parameters - self.forward(torch.randn(2, 3, image_size, image_size, device=device)) - - @singleton('target_encoder') - def _get_target_encoder(self): - target_encoder = copy.deepcopy(self.online_encoder) - set_requires_grad(target_encoder, False) - return target_encoder - - def reset_moving_average(self): - del self.target_encoder - self.target_encoder = None - - def update_moving_average(self): - assert self.target_encoder is not None, 'target encoder has not been created yet' - update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder) - - def forward(self, x): - shape, device, prob_flip = x.shape, x.device, self.prob_rand_hflip - - rand_flip_fn = lambda t: torch.flip(t, dims = (-1,)) - - flip_image_one, flip_image_two = rand_true(prob_flip), rand_true(prob_flip) - flip_image_one_fn = rand_flip_fn if flip_image_one else identity - flip_image_two_fn = rand_flip_fn if flip_image_two else identity - - cutout_coordinates_one, _ = cutout_coordinates(x, self.cutout_ratio_range) - cutout_coordinates_two, _ = cutout_coordinates(x, self.cutout_ratio_range) - - image_one_cutout = cutout_and_resize(x, cutout_coordinates_one, mode = self.cutout_interpolate_mode) - image_two_cutout = cutout_and_resize(x, cutout_coordinates_two, mode = self.cutout_interpolate_mode) - - image_one_cutout = flip_image_one_fn(image_one_cutout) - image_two_cutout = flip_image_two_fn(image_two_cutout) - - 
image_one_cutout, image_two_cutout = self.augment1(image_one_cutout), self.augment2(image_two_cutout) - - self.aug1 = image_one_cutout.detach().clone() - self.aug2 = image_two_cutout.detach().clone() - - proj_pixel_one, proj_instance_one = self.online_encoder(image_one_cutout) - proj_pixel_two, proj_instance_two = self.online_encoder(image_two_cutout) - - proj_pixel_one, proj_pixel_two = get_shared_region(proj_pixel_one, proj_pixel_two, cutout_coordinates_one, - cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn, - image_one_cutout.shape, self.cutout_interpolate_mode) - if proj_pixel_one is None or proj_pixel_two is None: - positive_pixel_pairs = 0 - else: - positive_pixel_pairs = proj_pixel_one.shape[-1] * proj_pixel_one.shape[-2] - - with torch.no_grad(): - target_encoder = self._get_target_encoder() - target_proj_pixel_one, target_proj_instance_one = target_encoder(image_one_cutout) - target_proj_pixel_two, target_proj_instance_two = target_encoder(image_two_cutout) - target_proj_pixel_one, target_proj_pixel_two = get_shared_region(target_proj_pixel_one, target_proj_pixel_two, cutout_coordinates_one, - cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn, - image_one_cutout.shape, self.cutout_interpolate_mode) - - # If max_latent_dim is specified, stochastically extract latents from the shared areas. - b, c, pp_h, pp_w = proj_pixel_one.shape - if self.max_latent_dim and (pp_h * pp_w) > self.max_latent_dim: - prob = torch.full((self.max_latent_dim,), 1 / (self.max_latent_dim)) - latents = [proj_pixel_one, proj_pixel_two, target_proj_pixel_one, target_proj_pixel_two] - extracted = [] - for l in latents: - l = l.reshape(b, c, pp_h * pp_w) - l = l[:, :, prob.multinomial(num_samples=self.max_latent_dim, replacement=False)] - # For compatibility with the existing pixpro code, reshape this stochastic sampling back into a 2d "square". - # Note that the actual structure no longer matters going forwards. Pixels are only compared to themselves and others without regards - # to the original image structure. 
- sqdim = int(math.sqrt(self.max_latent_dim)) - extracted.append(l.reshape(b, c, sqdim, sqdim)) - proj_pixel_one, proj_pixel_two, target_proj_pixel_one, target_proj_pixel_two = extracted - - # flatten all the pixel projections - flatten = lambda t: rearrange(t, 'b c h w -> b c (h w)') - target_proj_pixel_one, target_proj_pixel_two = list(map(flatten, (target_proj_pixel_one, target_proj_pixel_two))) - - # get instance level loss - pred_instance_one = self.online_predictor(proj_instance_one) - pred_instance_two = self.online_predictor(proj_instance_two) - loss_instance_one = loss_fn(pred_instance_one, target_proj_instance_two.detach()) - loss_instance_two = loss_fn(pred_instance_two, target_proj_instance_one.detach()) - instance_loss = (loss_instance_one + loss_instance_two).mean() - - if positive_pixel_pairs == 0: - return instance_loss, 0 - - # calculate pix pro loss - propagated_pixels_one = self.propagate_pixels(proj_pixel_one) - propagated_pixels_two = self.propagate_pixels(proj_pixel_two) - - propagated_pixels_one, propagated_pixels_two = list(map(flatten, (propagated_pixels_one, propagated_pixels_two))) - - propagated_similarity_one_two = F.cosine_similarity(propagated_pixels_one[..., :, None], target_proj_pixel_two[..., None, :], dim = 1) - propagated_similarity_two_one = F.cosine_similarity(propagated_pixels_two[..., :, None], target_proj_pixel_one[..., None, :], dim = 1) - - loss_pixpro_one_two = - propagated_similarity_one_two.mean() - loss_pixpro_two_one = - propagated_similarity_two_one.mean() - - pix_loss = (loss_pixpro_one_two + loss_pixpro_two_one) / 2 - - return instance_loss, pix_loss, positive_pixel_pairs - - # Allows visualizing what the augmentor is up to. - def visual_dbg(self, step, path): - if not hasattr(self, 'aug1'): - return - torchvision.utils.save_image(self.aug1, os.path.join(path, "%i_aug1.png" % (step,))) - torchvision.utils.save_image(self.aug2, os.path.join(path, "%i_aug2.png" % (step,))) - - -@register_model -def register_pixel_contrastive_learner(opt_net, opt): - subnet = create_model(opt, opt_net['subnet']) - kwargs = opt_net['kwargs'] - if 'subnet_pretrain_path' in opt_net.keys(): - sd = torch.load(opt_net['subnet_pretrain_path']) - subnet.load_state_dict(sd, strict=False) - return PixelCL(subnet, **kwargs) diff --git a/codes/models/pixel_level_contrastive_learning/resnet_unet.py b/codes/models/pixel_level_contrastive_learning/resnet_unet.py deleted file mode 100644 index 46bc747f..00000000 --- a/codes/models/pixel_level_contrastive_learning/resnet_unet.py +++ /dev/null @@ -1,152 +0,0 @@ -# Resnet implementation that adds a u-net style up-conversion component to output values at a -# specified pixel density. -# -# The downsampling part of the network is compatible with the built-in torch resnet for use in -# transfer learning. -# -# Only resnet50 currently supported. 
- -import torch -import torch.nn as nn -from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1, conv3x3 -from torchvision.models.utils import load_state_dict_from_url -import torchvision - - -from trainer.networks import register_model -from utils.util import checkpoint, opt_get - - -class ReverseBottleneck(nn.Module): - - def __init__(self, inplanes, planes, groups=1, passthrough=False, - base_width=64, dilation=1, norm_layer=None): - super().__init__() - if norm_layer is None: - norm_layer = nn.BatchNorm2d - width = int(planes * (base_width / 64.)) * groups - self.passthrough = passthrough - if passthrough: - self.integrate = conv1x1(inplanes*2, inplanes) - self.bn_integrate = norm_layer(inplanes) - # Both self.conv2 and self.downsample layers downsample the input when stride != 1 - self.conv1 = conv1x1(inplanes, width) - self.bn1 = norm_layer(width) - self.conv2 = conv3x3(width, width, groups, dilation) - self.bn2 = norm_layer(width) - self.residual_upsample = nn.Sequential( - nn.Upsample(scale_factor=2, mode='nearest'), - conv1x1(width, width), - norm_layer(width), - ) - self.conv3 = conv1x1(width, planes) - self.bn3 = norm_layer(planes) - self.relu = nn.ReLU(inplace=True) - self.upsample = nn.Sequential( - nn.Upsample(scale_factor=2, mode='nearest'), - conv1x1(inplanes, planes), - norm_layer(planes), - ) - - def forward(self, x, passthrough=None): - if self.passthrough: - x = self.bn_integrate(self.integrate(torch.cat([x, passthrough], dim=1))) - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.residual_upsample(out) - - out = self.conv3(out) - out = self.bn3(out) - - identity = self.upsample(x) - - out = out + identity - out = self.relu(out) - - return out - - -class UResNet50(torchvision.models.resnet.ResNet): - - def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, - groups=1, width_per_group=64, replace_stride_with_dilation=None, - norm_layer=None, out_dim=128): - super().__init__(block, layers, num_classes, zero_init_residual, groups, width_per_group, - replace_stride_with_dilation, norm_layer) - if norm_layer is None: - norm_layer = nn.BatchNorm2d - ''' - # For reference: - self.layer1 = self._make_layer(block, 64, layers[0]) - self.layer2 = self._make_layer(block, 128, layers[1], stride=2, - dilate=replace_stride_with_dilation[0]) - self.layer3 = self._make_layer(block, 256, layers[2], stride=2, - dilate=replace_stride_with_dilation[1]) - self.layer4 = self._make_layer(block, 512, layers[3], stride=2, - dilate=replace_stride_with_dilation[2]) - ''' - uplayers = [] - inplanes = 2048 - first = True - for i in range(2): - uplayers.append(ReverseBottleneck(inplanes, inplanes // 2, norm_layer=norm_layer, passthrough=not first)) - inplanes = inplanes // 2 - first = False - self.uplayers = nn.ModuleList(uplayers) - self.tail = nn.Sequential(conv1x1(1024, 512), - norm_layer(512), - nn.ReLU(), - conv3x3(512, 512), - norm_layer(512), - nn.ReLU(), - conv1x1(512, out_dim)) - - del self.fc # Not used in this implementation and just consumes a ton of GPU memory. - - - def _forward_impl(self, x): - # Should be the exact same implementation of torchvision.models.resnet.ResNet.forward_impl, - # except using checkpoints on the body conv layers. 
- x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - x = self.maxpool(x) - - x1 = checkpoint(self.layer1, x) - x2 = checkpoint(self.layer2, x1) - x3 = checkpoint(self.layer3, x2) - x4 = checkpoint(self.layer4, x3) - unused = self.avgpool(x4) # This is performed for instance-level pixpro learning, even though it is unused. - - x = checkpoint(self.uplayers[0], x4) - x = checkpoint(self.uplayers[1], x, x3) - #x = checkpoint(self.uplayers[2], x, x2) - #x = checkpoint(self.uplayers[3], x, x1) - - return checkpoint(self.tail, torch.cat([x, x2], dim=1)) - - def forward(self, x): - return self._forward_impl(x) - - -@register_model -def register_u_resnet50(opt_net, opt): - model = UResNet50(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim']) - if opt_get(opt_net, ['use_pretrained_base'], False): - state_dict = load_state_dict_from_url('https://download.pytorch.org/models/resnet50-19c8e357.pth', progress=True) - model.load_state_dict(state_dict, strict=False) - return model - - -if __name__ == '__main__': - model = UResNet50(Bottleneck, [3,4,6,3]) - samp = torch.rand(1,3,224,224) - model(samp) - # For pixpro: attach to "tail.3" diff --git a/codes/models/pixel_level_contrastive_learning/resnet_unet_2.py b/codes/models/pixel_level_contrastive_learning/resnet_unet_2.py deleted file mode 100644 index dd7ed05c..00000000 --- a/codes/models/pixel_level_contrastive_learning/resnet_unet_2.py +++ /dev/null @@ -1,87 +0,0 @@ -import torch -import torch.nn as nn -from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1, conv3x3 -from torchvision.models.utils import load_state_dict_from_url -import torchvision - -from models.arch_util import ConvBnRelu -from models.pixel_level_contrastive_learning.resnet_unet import ReverseBottleneck -from trainer.networks import register_model -from utils.util import checkpoint, opt_get - - -class UResNet50_2(torchvision.models.resnet.ResNet): - - def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, - groups=1, width_per_group=64, replace_stride_with_dilation=None, - norm_layer=None, out_dim=128): - super().__init__(block, layers, num_classes, zero_init_residual, groups, width_per_group, - replace_stride_with_dilation, norm_layer) - if norm_layer is None: - norm_layer = nn.BatchNorm2d - self.level_conv = ConvBnRelu(3, 64) - ''' - # For reference: - self.layer1 = self._make_layer(block, 64, layers[0]) - self.layer2 = self._make_layer(block, 128, layers[1], stride=2, - dilate=replace_stride_with_dilation[0]) - self.layer3 = self._make_layer(block, 256, layers[2], stride=2, - dilate=replace_stride_with_dilation[1]) - self.layer4 = self._make_layer(block, 512, layers[3], stride=2, - dilate=replace_stride_with_dilation[2]) - ''' - uplayers = [] - inplanes = 2048 - first = True - div = [2,2,2,4,1] - for i in range(5): - uplayers.append(ReverseBottleneck(inplanes, inplanes // div[i], norm_layer=norm_layer, passthrough=not first)) - inplanes = inplanes // div[i] - first = False - self.uplayers = nn.ModuleList(uplayers) - self.tail = nn.Sequential(conv3x3(128, 64), - norm_layer(64), - nn.ReLU(), - conv1x1(64, out_dim)) - - del self.fc # Not used in this implementation and just consumes a ton of GPU memory. 
- - - def _forward_impl(self, x): - level = self.level_conv(x) - x0 = self.relu(self.bn1(self.conv1(x))) - x = self.maxpool(x0) - - x1 = checkpoint(self.layer1, x) - x2 = checkpoint(self.layer2, x1) - x3 = checkpoint(self.layer3, x2) - x4 = checkpoint(self.layer4, x3) - unused = self.avgpool(x4) # This is performed for instance-level pixpro learning, even though it is unused. - - x = checkpoint(self.uplayers[0], x4) - x = checkpoint(self.uplayers[1], x, x3) - x = checkpoint(self.uplayers[2], x, x2) - x = checkpoint(self.uplayers[3], x, x1) - x = checkpoint(self.uplayers[4], x, x0) - - return checkpoint(self.tail, torch.cat([x, level], dim=1)) - - def forward(self, x): - return self._forward_impl(x) - - -@register_model -def register_u_resnet50_2(opt_net, opt): - model = UResNet50_2(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim']) - if opt_get(opt_net, ['use_pretrained_base'], False): - state_dict = load_state_dict_from_url('https://download.pytorch.org/models/resnet50-19c8e357.pth', progress=True) - model.load_state_dict(state_dict, strict=False) - return model - - -if __name__ == '__main__': - model = UResNet50_2(Bottleneck, [3,4,6,3]) - samp = torch.rand(1,3,224,224) - y = model(samp) - print(y.shape) - # For pixpro: attach to "tail.3" diff --git a/codes/models/pixel_level_contrastive_learning/resnet_unet_3.py b/codes/models/pixel_level_contrastive_learning/resnet_unet_3.py deleted file mode 100644 index 49c87036..00000000 --- a/codes/models/pixel_level_contrastive_learning/resnet_unet_3.py +++ /dev/null @@ -1,86 +0,0 @@ -import torch -import torch.nn as nn -from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1, conv3x3 -from torchvision.models.utils import load_state_dict_from_url -import torchvision - -from models.arch_util import ConvBnRelu -from models.pixel_level_contrastive_learning.resnet_unet import ReverseBottleneck -from trainer.networks import register_model -from utils.util import checkpoint, opt_get - - -class UResNet50_3(torchvision.models.resnet.ResNet): - - def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, - groups=1, width_per_group=64, replace_stride_with_dilation=None, - norm_layer=None, out_dim=128): - super().__init__(block, layers, num_classes, zero_init_residual, groups, width_per_group, - replace_stride_with_dilation, norm_layer) - if norm_layer is None: - norm_layer = nn.BatchNorm2d - ''' - # For reference: - self.layer1 = self._make_layer(block, 64, layers[0]) - self.layer2 = self._make_layer(block, 128, layers[1], stride=2, - dilate=replace_stride_with_dilation[0]) - self.layer3 = self._make_layer(block, 256, layers[2], stride=2, - dilate=replace_stride_with_dilation[1]) - self.layer4 = self._make_layer(block, 512, layers[3], stride=2, - dilate=replace_stride_with_dilation[2]) - ''' - uplayers = [] - inplanes = 2048 - first = True - for i in range(3): - uplayers.append(ReverseBottleneck(inplanes, inplanes // 2, norm_layer=norm_layer, passthrough=not first)) - inplanes = inplanes // 2 - first = False - self.uplayers = nn.ModuleList(uplayers) - - # These two variables are separated out and renamed so that I can re-use parameters from a pretrained resnet_unet2. - self.last_uplayer = ReverseBottleneck(256, 128, norm_layer=norm_layer, passthrough=True) - self.tail3 = nn.Sequential(conv1x1(192, 128), - norm_layer(128), - nn.ReLU(), - conv1x1(128, out_dim)) - - del self.fc # Not used in this implementation and just consumes a ton of GPU memory. 
- - - def _forward_impl(self, x): - x0 = self.relu(self.bn1(self.conv1(x))) - x = self.maxpool(x0) - - x1 = checkpoint(self.layer1, x) - x2 = checkpoint(self.layer2, x1) - x3 = checkpoint(self.layer3, x2) - x4 = checkpoint(self.layer4, x3) - unused = self.avgpool(x4) # This is performed for instance-level pixpro learning, even though it is unused. - - x = checkpoint(self.uplayers[0], x4) - x = checkpoint(self.uplayers[1], x, x3) - x = checkpoint(self.uplayers[2], x, x2) - x = checkpoint(self.last_uplayer, x, x1) - - return checkpoint(self.tail3, torch.cat([x, x0], dim=1)) - - def forward(self, x): - return self._forward_impl(x) - - -@register_model -def register_u_resnet50_3(opt_net, opt): - model = UResNet50_3(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim']) - if opt_get(opt_net, ['use_pretrained_base'], False): - state_dict = load_state_dict_from_url('https://download.pytorch.org/models/resnet50-19c8e357.pth', progress=True) - model.load_state_dict(state_dict, strict=False) - return model - - -if __name__ == '__main__': - model = UResNet50_3(Bottleneck, [3,4,6,3]) - samp = torch.rand(1,3,224,224) - y = model(samp) - print(y.shape) - # For pixpro: attach to "tail.3" diff --git a/codes/models/styled_sr/__init__.py b/codes/models/styled_sr/__init__.py deleted file mode 100644 index 1789cfbb..00000000 --- a/codes/models/styled_sr/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from models.styled_sr.discriminator import StyleSrGanDivergenceLoss - - -def create_stylesr_loss(opt_loss, env): - type = opt_loss['type'] - if type == 'style_sr_gan_divergence_loss': - return StyleSrGanDivergenceLoss(opt_loss, env) - else: - raise NotImplementedError diff --git a/codes/models/styled_sr/discriminator.py b/codes/models/styled_sr/discriminator.py deleted file mode 100644 index 44fd83f6..00000000 --- a/codes/models/styled_sr/discriminator.py +++ /dev/null @@ -1,344 +0,0 @@ -# Heavily based on the lucidrains stylegan2 discriminator implementation. 
-import math -import os -from functools import partial -from math import log2 -from random import random - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchvision -from torch.autograd import grad as torch_grad -import trainer.losses as L -from vector_quantize_pytorch import VectorQuantize - -from models.styled_sr.stylegan2_base import attn_and_ff, PermuteToFrom, Blur, leaky_relu, exists -from models.styled_sr.transfer_primitives import TransferConv2d, TransferLinear -from trainer.networks import register_model -from utils.util import checkpoint, opt_get - - -class DiscriminatorBlock(nn.Module): - def __init__(self, input_channels, filters, downsample=True, transfer_mode=False): - super().__init__() - self.filters = filters - self.conv_res = TransferConv2d(input_channels, filters, 1, stride=(2 if downsample else 1), transfer_mode=transfer_mode) - - self.net = nn.Sequential( - TransferConv2d(input_channels, filters, 3, padding=1, transfer_mode=transfer_mode), - leaky_relu(), - TransferConv2d(filters, filters, 3, padding=1, transfer_mode=transfer_mode), - leaky_relu() - ) - - self.downsample = nn.Sequential( - Blur(), - TransferConv2d(filters, filters, 3, padding=1, stride=2, transfer_mode=transfer_mode) - ) if downsample else None - - def forward(self, x): - res = self.conv_res(x) - x = self.net(x) - if exists(self.downsample): - x = self.downsample(x) - x = (x + res) * (1 / math.sqrt(2)) - return x - - -class StyleSrDiscriminator(nn.Module): - def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[], - transparent=False, fmap_max=512, input_filters=3, quantize=False, do_checkpointing=False, mlp=False, - transfer_mode=False): - super().__init__() - num_layers = int(log2(image_size) - 1) - - blocks = [] - filters = [input_filters] + [(64) * (2 ** i) for i in range(num_layers + 1)] - - set_fmap_max = partial(min, fmap_max) - filters = list(map(set_fmap_max, filters)) - chan_in_out = list(zip(filters[:-1], filters[1:])) - - blocks = [] - attn_blocks = [] - quantize_blocks = [] - - for ind, (in_chan, out_chan) in enumerate(chan_in_out): - num_layer = ind + 1 - is_not_last = ind != (len(chan_in_out) - 1) - - block = DiscriminatorBlock(in_chan, out_chan, downsample=is_not_last, transfer_mode=transfer_mode) - blocks.append(block) - - attn_fn = attn_and_ff(out_chan) if num_layer in attn_layers else None - - attn_blocks.append(attn_fn) - - if quantize: - quantize_fn = PermuteToFrom(VectorQuantize(out_chan, fq_dict_size)) if num_layer in fq_layers else None - quantize_blocks.append(quantize_fn) - else: - quantize_blocks.append(None) - - self.blocks = nn.ModuleList(blocks) - self.attn_blocks = nn.ModuleList(attn_blocks) - self.quantize_blocks = nn.ModuleList(quantize_blocks) - self.do_checkpointing = do_checkpointing - - chan_last = filters[-1] - latent_dim = 2 * 2 * chan_last - - self.final_conv = TransferConv2d(chan_last, chan_last, 3, padding=1, transfer_mode=transfer_mode) - self.flatten = nn.Flatten() - if mlp: - self.to_logit = nn.Sequential(TransferLinear(latent_dim, 100, transfer_mode=transfer_mode), - leaky_relu(), - TransferLinear(100, 1, transfer_mode=transfer_mode)) - else: - self.to_logit = TransferLinear(latent_dim, 1, transfer_mode=transfer_mode) - - self._init_weights() - - self.transfer_mode = transfer_mode - if transfer_mode: - for p in self.parameters(): - if not hasattr(p, 'FOR_TRANSFER_LEARNING'): - p.DO_NOT_TRAIN = True - - def forward(self, x): - b, *_ = x.shape - - quantize_loss = torch.zeros(1).to(x) - 
- for (block, attn_block, q_block) in zip(self.blocks, self.attn_blocks, self.quantize_blocks): - if self.do_checkpointing: - x = checkpoint(block, x) - else: - x = block(x) - - if exists(attn_block): - x = attn_block(x) - - if exists(q_block): - x, _, loss = q_block(x) - quantize_loss += loss - - x = self.final_conv(x) - x = self.flatten(x) - x = self.to_logit(x) - if exists(q_block): - return x.squeeze(), quantize_loss - else: - return x.squeeze() - - def _init_weights(self): - for m in self.modules(): - if type(m) in {TransferConv2d, TransferLinear}: - nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') - - # Configures the network as partially pre-trained. This means: - # 1) The top (high-resolution) `num_blocks` will have their weights re-initialized. - # 2) The head (linear layers) will also have their weights re-initialized - # 3) All intermediate blocks will be frozen until step `frozen_until_step` - # These settings will be applied after the weights have been loaded (network_loaded()) - def configure_partial_training(self, bypass_blocks=0, num_blocks=2, frozen_until_step=0): - self.bypass_blocks = bypass_blocks - self.num_blocks = num_blocks - self.frozen_until_step = frozen_until_step - - # Called after the network weights are loaded. - def network_loaded(self): - if not hasattr(self, 'frozen_until_step'): - return - - if self.bypass_blocks > 0: - self.blocks = self.blocks[self.bypass_blocks:] - self.blocks[0] = DiscriminatorBlock(3, self.blocks[0].filters, downsample=True).to(next(self.parameters()).device) - - reset_blocks = [self.to_logit] - for i in range(self.num_blocks): - reset_blocks.append(self.blocks[i]) - for bl in reset_blocks: - for m in bl.modules(): - if type(m) in {TransferConv2d, TransferLinear}: - nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') - for p in m.parameters(recurse=True): - p._NEW_BLOCK = True - for p in self.parameters(): - if not hasattr(p, '_NEW_BLOCK'): - p.DO_NOT_TRAIN_UNTIL = self.frozen_until_step - - -# helper classes -def DiffAugment(x, types=[]): - for p in types: - for f in AUGMENT_FNS[p]: - x = f(x) - return x.contiguous() - - -def random_hflip(tensor, prob): - if prob > random(): - return tensor - return torch.flip(tensor, dims=(3,)) - - -def rand_brightness(x): - x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) - return x - - -def rand_saturation(x): - x_mean = x.mean(dim=1, keepdim=True) - x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean - return x - - -def rand_contrast(x): - x_mean = x.mean(dim=[1, 2, 3], keepdim=True) - x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean - return x - - -def rand_translation(x, ratio=0.125): - shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) - translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) - translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) - grid_batch, grid_x, grid_y = torch.meshgrid( - torch.arange(x.size(0), dtype=torch.long, device=x.device), - torch.arange(x.size(2), dtype=torch.long, device=x.device), - torch.arange(x.size(3), dtype=torch.long, device=x.device), - ) - grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) - grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) - x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) - x = x_pad.permute(0, 2, 3, 
1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2) - return x - - -def rand_cutout(x, ratio=0.5): - cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) - offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) - offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) - grid_batch, grid_x, grid_y = torch.meshgrid( - torch.arange(x.size(0), dtype=torch.long, device=x.device), - torch.arange(cutout_size[0], dtype=torch.long, device=x.device), - torch.arange(cutout_size[1], dtype=torch.long, device=x.device), - ) - grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) - grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) - mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) - mask[grid_batch, grid_x, grid_y] = 0 - x = x * mask.unsqueeze(1) - return x - - -AUGMENT_FNS = { - 'color': [rand_brightness, rand_saturation, rand_contrast], - 'translation': [rand_translation], - 'cutout': [rand_cutout], -} - - -class DiscAugmentor(nn.Module): - def __init__(self, D, image_size, types, prob): - super().__init__() - self.D = D - self.prob = prob - self.types = types - - def forward(self, images, real_images=False): - if random() < self.prob: - images = random_hflip(images, prob=0.5) - images = DiffAugment(images, types=self.types) - - if real_images: - self.hq_aug = images.detach().clone() - else: - self.gen_aug = images.detach().clone() - - # Save away for use elsewhere (e.g. unet loss) - self.aug_images = images - - return self.D(images) - - def network_loaded(self): - self.D.network_loaded() - - # Allows visualizing what the augmentor is up to. 
- def visual_dbg(self, step, path): - torchvision.utils.save_image(self.gen_aug, os.path.join(path, "%i_gen_aug.png" % (step))) - torchvision.utils.save_image(self.hq_aug, os.path.join(path, "%i_hq_aug.png" % (step))) - - -def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs): - if fp16: - with amp.scale_loss(loss, optimizer, loss_id) as scaled_loss: - scaled_loss.backward(**kwargs) - else: - loss.backward(**kwargs) - - -def gradient_penalty(images, output, weight=10, return_structured_grads=False): - batch_size = images.shape[0] - gradients = torch_grad(outputs=output, inputs=images, - grad_outputs=torch.ones(output.size(), device=images.device), - create_graph=True, retain_graph=True, only_inputs=True)[0] - - flat_grad = gradients.reshape(batch_size, -1) - penalty = weight * ((flat_grad.norm(2, dim=1) - 1) ** 2).mean() - if return_structured_grads: - return penalty, gradients - else: - return penalty - - -class StyleSrGanDivergenceLoss(L.ConfigurableLoss): - def __init__(self, opt, env): - super().__init__(opt, env) - self.real = opt['real'] - self.fake = opt['fake'] - self.discriminator = opt['discriminator'] - self.for_gen = opt['gen_loss'] - self.gp_frequency = opt['gradient_penalty_frequency'] - self.noise = opt['noise'] if 'noise' in opt.keys() else 0 - - def forward(self, net, state): - real_input = state[self.real] - fake_input = state[self.fake] - if self.noise != 0: - fake_input = fake_input + torch.rand_like(fake_input) * self.noise - real_input = real_input + torch.rand_like(real_input) * self.noise - - D = self.env['discriminators'][self.discriminator] - fake = D(fake_input, real_images=False) - if self.for_gen: - return fake.mean() - else: - real_input.requires_grad_() # <-- Needed to compute gradients on the input. - real = D(real_input, real_images=True) - divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean() - - # Apply gradient penalty. TODO: migrate this elsewhere. 
- if self.env['step'] % self.gp_frequency == 0: - gp = gradient_penalty(real_input, real) - self.metrics.append(("gradient_penalty", gp.clone().detach())) - divergence_loss = divergence_loss + gp - - real_input.requires_grad_(requires_grad=False) - return divergence_loss - - -@register_model -def register_styledsr_discriminator(opt_net, opt): - attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else [] - disc = StyleSrDiscriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn, - do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False), - quantize=opt_get(opt_net, ['quantize'], False), - mlp=opt_get(opt_net, ['mlp_head'], True), - transfer_mode=opt_get(opt_net, ['transfer_mode'], False) - ) - if 'use_partial_pretrained' in opt_net.keys(): - disc.configure_partial_training(opt_net['bypass_blocks'], opt_net['partial_training_blocks'], opt_net['intermediate_blocks_frozen_until']) - return DiscAugmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability']) diff --git a/codes/models/styled_sr/styled_sr.py b/codes/models/styled_sr/styled_sr.py deleted file mode 100644 index 07447dd7..00000000 --- a/codes/models/styled_sr/styled_sr.py +++ /dev/null @@ -1,199 +0,0 @@ -from random import random - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from models.arch_util import kaiming_init -from models.styled_sr.stylegan2_base import StyleVectorizer, GeneratorBlock -from models.styled_sr.transfer_primitives import TransferConvGnLelu, TransferConv2d, TransferLinear -from trainer.networks import register_model -from utils.util import checkpoint, opt_get - - -def rrdb_init_weights(module, scale=1): - for m in module.modules(): - if isinstance(m, TransferConv2d): - kaiming_init(m, a=0, mode='fan_in', bias=0) - m.weight.data *= scale - elif isinstance(m, TransferLinear): - kaiming_init(m, a=0, mode='fan_in', bias=0) - m.weight.data *= scale - - -class EncoderRRDB(nn.Module): - def __init__(self, mid_channels=64, output_channels=32, growth_channels=32, init_weight=.1, transfer_mode=False): - super(EncoderRRDB, self).__init__() - for i in range(5): - out_channels = output_channels if i == 4 else growth_channels - self.add_module( - f'conv{i+1}', - TransferConv2d(mid_channels + i * growth_channels, out_channels, 3, 1, 1, transfer_mode=transfer_mode)) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - for i in range(5): - rrdb_init_weights(getattr(self, f'conv{i+1}'), init_weight) - - def forward(self, x): - x1 = self.lrelu(self.conv1(x)) - x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) - x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) - x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) - x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) - return x5 - - -class StyledSrEncoder(nn.Module): - def __init__(self, fea_out=256, initial_stride=1, transfer_mode=False): - super().__init__() - # Current assumes fea_out=256. 
- self.initial_conv = TransferConvGnLelu(3, 32, kernel_size=7, stride=initial_stride, norm=False, activation=False, bias=True, transfer_mode=transfer_mode) - self.rrdbs = nn.ModuleList([ - EncoderRRDB(32, transfer_mode=transfer_mode), - EncoderRRDB(64, transfer_mode=transfer_mode), - EncoderRRDB(96, transfer_mode=transfer_mode), - EncoderRRDB(128, transfer_mode=transfer_mode), - EncoderRRDB(160, transfer_mode=transfer_mode), - EncoderRRDB(192, transfer_mode=transfer_mode), - EncoderRRDB(224, transfer_mode=transfer_mode)]) - - def forward(self, x): - fea = self.initial_conv(x) - for rrdb in self.rrdbs: - fea = torch.cat([fea, checkpoint(rrdb, fea)], dim=1) - return fea - - -class Generator(nn.Module): - def __init__(self, image_size, latent_dim, initial_stride=1, start_level=3, upsample_levels=2, transfer_mode=False): - super().__init__() - total_levels = upsample_levels + 1 # The first level handles the raw encoder output and doesn't upsample. - self.image_size = image_size - self.scale = 2 ** upsample_levels - self.latent_dim = latent_dim - self.num_layers = total_levels - self.transfer_mode = transfer_mode - filters = [ - 512, # 4x4 - 512, # 8x8 - 512, # 16x16 - 256, # 32x32 - 128, # 64x64 - 64, # 128x128 - 32, # 256x256 - 16, # 512x512 - 8, # 1024x1024 - ] - - # I'm making a guess here that the encoder does not need transfer learning, hence fixed transfer_mode=False. This should be vetted. - self.encoder = StyledSrEncoder(filters[start_level], initial_stride, transfer_mode=False) - - in_out_pairs = list(zip(filters[:-1], filters[1:])) - self.blocks = nn.ModuleList([]) - for ind in range(start_level, start_level+total_levels): - in_chan, out_chan = in_out_pairs[ind] - not_first = ind != start_level - not_last = ind != (start_level+total_levels-1) - block = GeneratorBlock( - latent_dim, - in_chan, - out_chan, - upsample=not_first, - upsample_rgb=not_last, - transfer_learning_mode=transfer_mode - ) - self.blocks.append(block) - - def forward(self, lr, styles): - b, c, h, w = lr.shape - if self.transfer_mode: - with torch.no_grad(): - x = self.encoder(lr) - else: - x = self.encoder(lr) - - styles = styles.transpose(0, 1) - input_noise = torch.rand(b, h * self.scale, w * self.scale, 1).to(lr.device) - if h != x.shape[-2]: - rgb = F.interpolate(lr, size=x.shape[2:], mode="area") - else: - rgb = lr - - for style, block in zip(styles, self.blocks): - x, rgb = checkpoint(block, x, rgb, style, input_noise) - - return rgb - - -class StyledSrGenerator(nn.Module): - def __init__(self, image_size, initial_stride=1, latent_dim=512, style_depth=8, lr_mlp=.1, transfer_mode=False): - super().__init__() - # Assume the vectorizer doesnt need transfer_mode=True. Re-evaluate this later. 
- self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp, transfer_mode=False) - self.gen = Generator(image_size=image_size, latent_dim=latent_dim, initial_stride=initial_stride, transfer_mode=transfer_mode) - self.l2 = nn.MSELoss() - self.mixed_prob = .9 - self._init_weights() - self.transfer_mode = transfer_mode - self.initial_stride = initial_stride - if transfer_mode: - for p in self.parameters(): - if not hasattr(p, 'FOR_TRANSFER_LEARNING'): - p.DO_NOT_TRAIN = True - - - def _init_weights(self): - for m in self.modules(): - if type(m) in {TransferConv2d, TransferLinear} and hasattr(m, 'weight'): - nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') - - for block in self.gen.blocks: - nn.init.zeros_(block.to_noise1.weight) - nn.init.zeros_(block.to_noise2.weight) - nn.init.zeros_(block.to_noise1.bias) - nn.init.zeros_(block.to_noise2.bias) - - def forward(self, x): - b, f, h, w = x.shape - - # Synthesize style latents from noise. - style = torch.randn(b*2, self.gen.latent_dim).to(x.device) - if self.transfer_mode: - with torch.no_grad(): - w = self.vectorizer(style) - else: - w = self.vectorizer(style) - - # Randomly distribute styles across layers - w_styles = w[:,None,:].expand(-1, self.gen.num_layers, -1).clone() - for j in range(b): - cutoff = int(torch.rand(()).numpy() * self.gen.num_layers) - if cutoff == self.gen.num_layers or random() > self.mixed_prob: - w_styles[j] = w_styles[j*2] - else: - w_styles[j, :cutoff] = w_styles[j*2, :cutoff] - w_styles[j, cutoff:] = w_styles[j*2+1, cutoff:] - w_styles = w_styles[:b] - - out = self.gen(x, w_styles) - - # Compute an L2 loss on the areal interpolation of the generated image back down to LR * initial_stride; used - # for regularization. - out_down = F.interpolate(out, size=(x.shape[-2] // self.initial_stride, x.shape[-1] // self.initial_stride), mode="area") - if self.initial_stride > 1: - x = F.interpolate(x, scale_factor=1/self.initial_stride, mode="area") - l2_reg = self.l2(x, out_down) - - return out, l2_reg, w_styles - - -if __name__ == '__main__': - gen = StyledSrGenerator(128, 2) - out = gen(torch.rand(1,3,64,64)) - print([o.shape for o in out]) - - -@register_model -def register_styled_sr(opt_net, opt): - return StyledSrGenerator(128, - initial_stride=opt_get(opt_net, ['initial_stride'], 1), - transfer_mode=opt_get(opt_net, ['transfer_mode'], False)) diff --git a/codes/models/styled_sr/stylegan2_base.py b/codes/models/styled_sr/stylegan2_base.py deleted file mode 100644 index dff9be14..00000000 --- a/codes/models/styled_sr/stylegan2_base.py +++ /dev/null @@ -1,411 +0,0 @@ -import math -import multiprocessing -from contextlib import contextmanager, ExitStack - -import torch -import torch.nn.functional as F -from kornia.filters import filter2D -from linear_attention_transformer import ImageLinearAttention -from torch import nn, Tensor -from torch.autograd import grad as torch_grad -from torch.nn import Parameter, init -from torch.nn.modules.conv import _ConvNd - -from models.styled_sr.transfer_primitives import TransferLinear - -assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.' 
- -num_cores = multiprocessing.cpu_count() - -# constants -EPS = 1e-8 - - -class NanException(Exception): - pass - - -class EMA(): - def __init__(self, beta): - super().__init__() - self.beta = beta - - def update_average(self, old, new): - if not exists(old): - return new - return old * self.beta + (1 - self.beta) * new - - -class Flatten(nn.Module): - def forward(self, x): - return x.reshape(x.shape[0], -1) - - -class Residual(nn.Module): - def __init__(self, fn): - super().__init__() - self.fn = fn - - def forward(self, x): - return self.fn(x) + x - - -class Rezero(nn.Module): - def __init__(self, fn): - super().__init__() - self.fn = fn - self.g = nn.Parameter(torch.zeros(1)) - - def forward(self, x): - return self.fn(x) * self.g - - -class PermuteToFrom(nn.Module): - def __init__(self, fn): - super().__init__() - self.fn = fn - - def forward(self, x): - x = x.permute(0, 2, 3, 1) - out, loss = self.fn(x) - out = out.permute(0, 3, 1, 2) - return out, loss - - -class Blur(nn.Module): - def __init__(self): - super().__init__() - f = torch.Tensor([1, 2, 1]) - self.register_buffer('f', f) - - def forward(self, x): - f = self.f - f = f[None, None, :] * f[None, :, None] - return filter2D(x, f, normalized=True) - - -# one layer of self-attention and feedforward, for images - -attn_and_ff = lambda chan: nn.Sequential(*[ - Residual(Rezero(ImageLinearAttention(chan, norm_queries=True))), - Residual(Rezero(nn.Sequential(nn.Conv2d(chan, chan * 2, 1), leaky_relu(), nn.Conv2d(chan * 2, chan, 1)))) -]) - - -# helpers - -def exists(val): - return val is not None - - -@contextmanager -def null_context(): - yield - - -def combine_contexts(contexts): - @contextmanager - def multi_contexts(): - with ExitStack() as stack: - yield [stack.enter_context(ctx()) for ctx in contexts] - - return multi_contexts - - -def default(value, d): - return value if exists(value) else d - - -def cycle(iterable): - while True: - for i in iterable: - yield i - - -def cast_list(el): - return el if isinstance(el, list) else [el] - - -def is_empty(t): - if isinstance(t, torch.Tensor): - return t.nelement() == 0 - return not exists(t) - - -def raise_if_nan(t): - if torch.isnan(t): - raise NanException - - -def gradient_accumulate_contexts(gradient_accumulate_every, is_ddp, ddps): - if is_ddp: - num_no_syncs = gradient_accumulate_every - 1 - head = [combine_contexts(map(lambda ddp: ddp.no_sync, ddps))] * num_no_syncs - tail = [null_context] - contexts = head + tail - else: - contexts = [null_context] * gradient_accumulate_every - - for context in contexts: - with context(): - yield - - -def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs): - if fp16: - with amp.scale_loss(loss, optimizer, loss_id) as scaled_loss: - scaled_loss.backward(**kwargs) - else: - loss.backward(**kwargs) - -def calc_pl_lengths(styles, images): - device = images.device - num_pixels = images.shape[2] * images.shape[3] - pl_noise = torch.randn(images.shape, device=device) / math.sqrt(num_pixels) - outputs = (images * pl_noise).sum() - - pl_grads = torch_grad(outputs=outputs, inputs=styles, - grad_outputs=torch.ones(outputs.shape, device=device), - create_graph=True, retain_graph=True, only_inputs=True)[0] - - return (pl_grads ** 2).sum(dim=2).mean(dim=1).sqrt() - - -def image_noise(n, im_size, device): - return torch.FloatTensor(n, im_size, im_size, 1).uniform_(0., 1.).cuda(device) - - -def leaky_relu(p=0.2): - return nn.LeakyReLU(p, inplace=True) - - -def evaluate_in_chunks(max_batch_size, model, *args): - split_args = list(zip(*list(map(lambda x: 
x.split(max_batch_size, dim=0), args)))) - chunked_outputs = [model(*i) for i in split_args] - if len(chunked_outputs) == 1: - return chunked_outputs[0] - return torch.cat(chunked_outputs, dim=0) - - -def set_requires_grad(model, bool): - for p in model.parameters(): - p.requires_grad = bool - - -def slerp(val, low, high): - low_norm = low / torch.norm(low, dim=1, keepdim=True) - high_norm = high / torch.norm(high, dim=1, keepdim=True) - omega = torch.acos((low_norm * high_norm).sum(1)) - so = torch.sin(omega) - res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high - return res - - -class EqualLinear(nn.Module): - def __init__(self, in_dim, out_dim, lr_mul=1, bias=True, transfer_mode=False): - super().__init__() - self.weight = nn.Parameter(torch.randn(out_dim, in_dim)) - if bias: - self.bias = nn.Parameter(torch.zeros(out_dim)) - - self.lr_mul = lr_mul - - self.transfer_mode = transfer_mode - if transfer_mode: - self.transfer_scale = nn.Parameter(torch.ones(out_features, in_features)) - self.transfer_scale.FOR_TRANSFER_LEARNING = True - self.transfer_shift = nn.Parameter(torch.zeros(out_features, in_features)) - self.transfer_shift.FOR_TRANSFER_LEARNING = True - - def forward(self, input): - if self.transfer_mode: - weight = self.weight * self.transfer_scale + self.transfer_shift - else: - weight = self.weight - return F.linear(input, weight * self.lr_mul, bias=self.bias * self.lr_mul) - - -class StyleVectorizer(nn.Module): - def __init__(self, emb, depth, lr_mul=0.1, transfer_mode=False): - super().__init__() - - layers = [] - for i in range(depth): - layers.extend([EqualLinear(emb, emb, lr_mul, transfer_mode=transfer_mode), leaky_relu()]) - - self.net = nn.Sequential(*layers) - - def forward(self, x): - x = F.normalize(x, dim=1) - return self.net(x) - - -class RGBBlock(nn.Module): - def __init__(self, latent_dim, input_channel, upsample, rgba=False, transfer_mode=False): - super().__init__() - self.input_channel = input_channel - self.to_style = nn.Linear(latent_dim, input_channel) - - out_filters = 3 if not rgba else 4 - self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False, transfer_mode=transfer_mode) - - self.upsample = nn.Sequential( - nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), - Blur() - ) if upsample else None - - def forward(self, x, prev_rgb, istyle): - b, c, h, w = x.shape - style = self.to_style(istyle) - x = self.conv(x, style) - - if exists(prev_rgb): - x = x + prev_rgb - - if exists(self.upsample): - x = self.upsample(x) - - return x - - -class AdaptiveInstanceNorm(nn.Module): - def __init__(self, in_channel, style_dim): - super().__init__() - from models.archs.arch_util import ConvGnLelu - self.style2scale = ConvGnLelu(style_dim, in_channel, kernel_size=1, norm=False, activation=False, bias=True) - self.style2bias = ConvGnLelu(style_dim, in_channel, kernel_size=1, norm=False, activation=False, bias=True, weight_init_factor=0) - self.norm = nn.InstanceNorm2d(in_channel) - - def forward(self, input, style): - gamma = self.style2scale(style) - beta = self.style2bias(style) - out = self.norm(input) - out = gamma * out + beta - return out - - -class NoiseInjection(nn.Module): - def __init__(self, channel): - super().__init__() - self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1)) - - def forward(self, image, noise): - return image + self.weight * noise - - -class EqualLR: - def __init__(self, name): - self.name = name - - def compute_weight(self, module): - weight = 
getattr(module, self.name + '_orig') - fan_in = weight.data.size(1) * weight.data[0][0].numel() - - return weight * math.sqrt(2 / fan_in) - - @staticmethod - def apply(module, name): - fn = EqualLR(name) - - weight = getattr(module, name) - del module._parameters[name] - module.register_parameter(name + '_orig', nn.Parameter(weight.data)) - module.register_forward_pre_hook(fn) - - return fn - - def __call__(self, module, input): - weight = self.compute_weight(module) - setattr(module, self.name, weight) - - -def equal_lr(module, name='weight'): - EqualLR.apply(module, name) - return module - - -class Conv2DMod(nn.Module): - def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, transfer_mode=False, **kwargs): - super().__init__() - self.filters = out_chan - self.demod = demod - self.kernel = kernel - self.stride = stride - self.dilation = dilation - self.weight = nn.Parameter(torch.randn((out_chan, in_chan, kernel, kernel))) - nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') - self.transfer_mode = transfer_mode - if transfer_mode: - self.transfer_scale = nn.Parameter(torch.ones(out_chan, in_chan, 1, 1)) - self.transfer_scale.FOR_TRANSFER_LEARNING = True - self.transfer_shift = nn.Parameter(torch.zeros(out_chan, in_chan, 1, 1)) - self.transfer_shift.FOR_TRANSFER_LEARNING = True - - def _get_same_padding(self, size, kernel, dilation, stride): - return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2 - - def forward(self, x, y): - b, c, h, w = x.shape - - if self.transfer_mode: - weight = self.weight * self.transfer_scale + self.transfer_shift - else: - weight = self.weight - - w1 = y[:, None, :, None, None] - w2 = weight[None, :, :, :, :] - weights = w2 * (w1 + 1) - - if self.demod: - d = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + EPS) - weights = weights * d - - x = x.reshape(1, -1, h, w) - - _, _, *ws = weights.shape - weights = weights.reshape(b * self.filters, *ws) - - padding = self._get_same_padding(h, self.kernel, self.dilation, self.stride) - x = F.conv2d(x, weights, padding=padding, groups=b) - - x = x.reshape(-1, self.filters, h, w) - return x - - -class GeneratorBlock(nn.Module): - def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False, - transfer_learning_mode=False): - super().__init__() - self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None - - self.to_style1 = TransferLinear(latent_dim, input_channels, transfer_mode=transfer_learning_mode) - self.to_noise1 = TransferLinear(1, filters, transfer_mode=transfer_learning_mode) - self.conv1 = Conv2DMod(input_channels, filters, 3, transfer_mode=transfer_learning_mode) - - self.to_style2 = TransferLinear(latent_dim, filters, transfer_mode=transfer_learning_mode) - self.to_noise2 = TransferLinear(1, filters, transfer_mode=transfer_learning_mode) - self.conv2 = Conv2DMod(filters, filters, 3, transfer_mode=transfer_learning_mode) - - self.activation = leaky_relu() - self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba, transfer_mode=transfer_learning_mode) - - self.transfer_learning_mode = transfer_learning_mode - - def forward(self, x, prev_rgb, istyle, inoise): - if exists(self.upsample): - x = self.upsample(x) - - inoise = inoise[:, :x.shape[2], :x.shape[3], :] - noise1 = self.to_noise1(inoise).permute((0, 3, 1, 2)) - noise2 = self.to_noise2(inoise).permute((0, 3, 1, 2)) - - style1 = self.to_style1(istyle) - x = self.conv1(x, style1) - x = 
self.activation(x + noise1) - - style2 = self.to_style2(istyle) - x = self.conv2(x, style2) - x = self.activation(x + noise2) - - rgb = self.to_rgb(x, prev_rgb, istyle) - return x, rgb diff --git a/codes/models/styled_sr/transfer_primitives.py b/codes/models/styled_sr/transfer_primitives.py deleted file mode 100644 index 93af5391..00000000 --- a/codes/models/styled_sr/transfer_primitives.py +++ /dev/null @@ -1,136 +0,0 @@ -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor -from torch.nn import Parameter, init -from torch.nn.modules.conv import _ConvNd -from torch.nn.modules.utils import _ntuple - -_pair = _ntuple(2) - -class TransferConv2d(_ConvNd): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size, - stride = 1, - padding = 0, - dilation = 1, - groups: int = 1, - bias: bool = True, - padding_mode: str = 'zeros', - transfer_mode: bool = False - ): - kernel_size = _pair(kernel_size) - stride = _pair(stride) - padding = _pair(padding) - dilation = _pair(dilation) - super().__init__( - in_channels, out_channels, kernel_size, stride, padding, dilation, - False, _pair(0), groups, bias, padding_mode) - - self.transfer_mode = transfer_mode - if transfer_mode: - self.transfer_scale = nn.Parameter(torch.ones(out_channels, in_channels, 1, 1)) - self.transfer_scale.FOR_TRANSFER_LEARNING = True - self.transfer_shift = nn.Parameter(torch.zeros(out_channels, in_channels, 1, 1)) - self.transfer_shift.FOR_TRANSFER_LEARNING = True - - def _conv_forward(self, input, weight): - if self.transfer_mode: - weight = weight * self.transfer_scale + self.transfer_shift - else: - weight = weight - - if self.padding_mode != 'zeros': - return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - weight, self.bias, self.stride, - _pair(0), self.dilation, self.groups) - return F.conv2d(input, weight, self.bias, self.stride, - self.padding, self.dilation, self.groups) - - def forward(self, input: Tensor) -> Tensor: - return self._conv_forward(input, self.weight) - - -class TransferLinear(nn.Module): - __constants__ = ['in_features', 'out_features'] - in_features: int - out_features: int - weight: Tensor - - def __init__(self, in_features: int, out_features: int, bias: bool = True, transfer_mode: bool = False) -> None: - super().__init__() - self.in_features = in_features - self.out_features = out_features - self.weight = Parameter(torch.Tensor(out_features, in_features)) - if bias: - self.bias = Parameter(torch.Tensor(out_features)) - else: - self.register_parameter('bias', None) - self.reset_parameters() - self.transfer_mode = transfer_mode - if transfer_mode: - self.transfer_scale = nn.Parameter(torch.ones(out_features, in_features)) - self.transfer_scale.FOR_TRANSFER_LEARNING = True - self.transfer_shift = nn.Parameter(torch.zeros(out_features, in_features)) - self.transfer_shift.FOR_TRANSFER_LEARNING = True - - def reset_parameters(self) -> None: - init.kaiming_uniform_(self.weight, a=math.sqrt(5)) - if self.bias is not None: - fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) - bound = 1 / math.sqrt(fan_in) - init.uniform_(self.bias, -bound, bound) - - def forward(self, input: Tensor) -> Tensor: - if self.transfer_mode: - weight = self.weight * self.transfer_scale + self.transfer_shift - else: - weight = self.weight - return F.linear(input, weight, self.bias) - - def extra_repr(self) -> str: - return 'in_features={}, out_features={}, bias={}'.format( - self.in_features, 
self.out_features, self.bias is not None - ) - - -class TransferConvGnLelu(nn.Module): - def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1, transfer_mode=False): - super().__init__() - padding_map = {1: 0, 3: 1, 5: 2, 7: 3} - assert kernel_size in padding_map.keys() - self.conv = TransferConv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias, transfer_mode=transfer_mode) - if norm: - self.gn = nn.GroupNorm(num_groups, filters_out) - else: - self.gn = None - if activation: - self.lelu = nn.LeakyReLU(negative_slope=.2) - else: - self.lelu = None - - # Init params. - for m in self.modules(): - if isinstance(m, TransferConv2d): - nn.init.kaiming_normal_(m.weight, a=.1, mode='fan_out', - nonlinearity='leaky_relu' if self.lelu else 'linear') - m.weight.data *= weight_init_factor - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - - def forward(self, x): - x = self.conv(x) - if self.gn: - x = self.gn(x) - if self.lelu: - return self.lelu(x) - else: - return x \ No newline at end of file diff --git a/codes/models/tecogan/__init__.py b/codes/models/tecogan/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/codes/models/tecogan/flownet2.py b/codes/models/tecogan/flownet2.py deleted file mode 100644 index c5c108b3..00000000 --- a/codes/models/tecogan/flownet2.py +++ /dev/null @@ -1,15 +0,0 @@ -import munch -import torch - -from trainer.networks import register_model - - -@register_model -def register_flownet2(opt_net): - from models.flownet2.models import FlowNet2 - ld = 'load_path' in opt_net.keys() - args = munch.Munch({'fp16': False, 'rgb_max': 1.0, 'checkpoint': not ld}) - netG = FlowNet2(args) - if ld: - sd = torch.load(opt_net['load_path']) - netG.load_state_dict(sd['state_dict']) \ No newline at end of file diff --git a/codes/models/tecogan/teco_resgen.py b/codes/models/tecogan/teco_resgen.py deleted file mode 100644 index 2f640148..00000000 --- a/codes/models/tecogan/teco_resgen.py +++ /dev/null @@ -1,79 +0,0 @@ -import os - -import torch -import torch.nn as nn -import torchvision - -from trainer.networks import register_model -from utils.util import sequential_checkpoint -from models.arch_util import ConvGnSilu, make_layer - - -class TecoResblock(nn.Module): - def __init__(self, nf): - super(TecoResblock, self).__init__() - self.nf = nf - self.conv1 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False, weight_init_factor=.1) - self.conv2 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False, weight_init_factor=.1) - - def forward(self, x): - identity = x - x = self.conv1(x) - x = self.conv2(x) - return identity + x - - -class TecoUpconv(nn.Module): - def __init__(self, nf, scale): - super(TecoUpconv, self).__init__() - self.nf = nf - self.scale = scale - self.conv1 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) - self.conv2 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) - self.conv3 = ConvGnSilu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) - self.final_conv = ConvGnSilu(nf, 3, kernel_size=1, norm=False, activation=False, bias=False) - - def forward(self, x): - x = self.conv1(x) - x = self.conv2(x) - x = nn.functional.interpolate(x, scale_factor=self.scale, mode="nearest") - x = self.conv3(x) - return 
self.final_conv(x) - - -# Extremely simple resnet based generator that is very similar to the one used in the tecogan paper. -# Main differences: -# - Uses SiLU instead of ReLU -# - Reference input is in HR space (just makes more sense) -# - Doesn't use transposed convolutions - just uses interpolation instead. -# - Upsample block is slightly more complicated. -class TecoGen(nn.Module): - def __init__(self, nf, scale): - super(TecoGen, self).__init__() - self.nf = nf - self.scale = scale - fea_conv = ConvGnSilu(6, nf, kernel_size=7, stride=self.scale, bias=True, norm=False, activation=True) - res_layers = [TecoResblock(nf) for i in range(15)] - upsample = TecoUpconv(nf, scale) - everything = [fea_conv] + res_layers + [upsample] - self.core = nn.Sequential(*everything) - - def forward(self, x, ref=None): - x = nn.functional.interpolate(x, scale_factor=self.scale, mode="bicubic") - if ref is None: - ref = torch.zeros_like(x) - join = torch.cat([x, ref], dim=1) - join = sequential_checkpoint(self.core, 6, join) - self.join = join.detach().clone() + .5 - return x + join - - def visual_dbg(self, step, path): - torchvision.utils.save_image(self.join.cpu().float(), os.path.join(path, "%i_join.png" % (step,))) - - def get_debug_values(self, step, net_name): - return {'branch_std': self.join.std()} - - -@register_model -def register_tecogen(opt_net, opt): - return TecoGen(opt_net['nf'], opt_net['scale']) \ No newline at end of file diff --git a/codes/models/transformers/__init__.py b/codes/models/transformers/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/codes/models/transformers/igpt/__init__.py b/codes/models/transformers/igpt/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/codes/models/transformers/igpt/gpt2.py b/codes/models/transformers/igpt/gpt2.py deleted file mode 100644 index 5e0d0eeb..00000000 --- a/codes/models/transformers/igpt/gpt2.py +++ /dev/null @@ -1,155 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np -from trainer.inject import Injector -from trainer.networks import register_model -from utils.util import checkpoint - - -def create_injector(opt, env): - type = opt['type'] - if type == 'igpt_resolve': - return ResolveInjector(opt, env) - return None - - -class ResolveInjector(Injector): - def __init__(self, opt, env): - super().__init__(opt, env) - self.gen = opt['generator'] - self.samples = opt['num_samples'] - self.temperature = opt['temperature'] - - def forward(self, state): - gen = self.env['generators'][self.opt['generator']].module - img = state[self.input] - b, c, h, w = img.shape - qimg = gen.quantize(img) - s, b = qimg.shape - qimg = qimg[:s//2, :] - output = qimg.repeat(1, self.samples) - - pad = torch.zeros(1, self.samples, dtype=torch.long).cuda() # to pad prev output - with torch.no_grad(): - for _ in range(s//2): - logits, _ = gen(torch.cat((output, pad), dim=0), already_quantized=True) - logits = logits[-1, :, :] / self.temperature - probs = F.softmax(logits, dim=-1) - pred = torch.multinomial(probs, num_samples=1).transpose(1, 0) - output = torch.cat((output, pred), dim=0) - output = gen.unquantize(output.reshape(h, w, -1)) - return {self.output: output.permute(2,3,0,1).contiguous()} - - -class Block(nn.Module): - def __init__(self, embed_dim, num_heads): - super(Block, self).__init__() - self.ln_1 = nn.LayerNorm(embed_dim) - self.ln_2 = nn.LayerNorm(embed_dim) - self.attn = nn.MultiheadAttention(embed_dim, num_heads) - self.mlp = nn.Sequential( - 
nn.Linear(embed_dim, embed_dim * 4), - nn.GELU(), - nn.Linear(embed_dim * 4, embed_dim), - ) - - def forward(self, x): - attn_mask = torch.full( - (len(x), len(x)), -float("Inf"), device=x.device, dtype=x.dtype - ) - attn_mask = torch.triu(attn_mask, diagonal=1) - - x = self.ln_1(x) - a, _ = self.attn(x, x, x, attn_mask=attn_mask, need_weights=False) - x = x + a - m = self.mlp(self.ln_2(x)) - x = x + m - return x - - -class iGPT2(nn.Module): - def __init__( - self, embed_dim, num_heads, num_layers, num_positions, num_vocab, centroids_file - ): - super().__init__() - - self.centroids = nn.Parameter( - torch.from_numpy(np.load(centroids_file)), requires_grad=False - ) - self.embed_dim = embed_dim - - # start of sequence token - self.sos = torch.nn.Parameter(torch.zeros(embed_dim)) - nn.init.normal_(self.sos) - - self.token_embeddings = nn.Embedding(num_vocab, embed_dim) - self.position_embeddings = nn.Embedding(num_positions, embed_dim) - - self.layers = nn.ModuleList() - for _ in range(num_layers): - self.layers.append(Block(embed_dim, num_heads)) - - self.ln_f = nn.LayerNorm(embed_dim) - self.head = nn.Linear(embed_dim, num_vocab, bias=False) - self.clf_head = nn.Linear(embed_dim, 10) # Fixed num_classes, this is not a classifier. - - def squared_euclidean_distance(self, a, b): - b = torch.transpose(b, 0, 1) - a2 = torch.sum(torch.square(a), dim=1, keepdims=True) - b2 = torch.sum(torch.square(b), dim=0, keepdims=True) - ab = torch.matmul(a, b) - d = a2 - 2 * ab + b2 - return d - - def quantize(self, x): - b, c, h, w = x.shape - # [B, C, H, W] => [B, H, W, C] - x = x.permute(0, 2, 3, 1).contiguous() - x = x.view(-1, c) # flatten to pixels - d = self.squared_euclidean_distance(x, self.centroids) - x = torch.argmin(d, 1) - x = x.view(b, h, w) - - # Reshape output to [seq_len, batch]. 
- x = x.view(x.shape[0], -1) # flatten images into sequences - x = x.transpose(0, 1).contiguous() # to shape [seq len, batch] - return x - - def unquantize(self, x): - return self.centroids[x] - - def forward(self, x, already_quantized=False): - """ - Expect input as shape [b, c, h, w] - """ - - if not already_quantized: - x = self.quantize(x) - length, batch = x.shape - - h = self.token_embeddings(x) - - # prepend sos token - sos = torch.ones(1, batch, self.embed_dim, device=x.device) * self.sos - h = torch.cat([sos, h[:-1, :, :]], axis=0) - - # add positional embeddings - positions = torch.arange(length, device=x.device).unsqueeze(-1) - h = h + self.position_embeddings(positions).expand_as(h) - - # transformer - for layer in self.layers: - h = checkpoint(layer, h) - - h = self.ln_f(h) - - logits = self.head(h) - - return logits, x - - -@register_model -def register_igpt2(opt_net, opt): - return iGPT2(opt_net['embed_dim'], opt_net['num_heads'], opt_net['num_layers'], opt_net['num_pixels'] ** 2, - opt_net['num_vocab'], centroids_file=opt_net['centroids_file']) diff --git a/codes/scripts/cifar100_untangle.py b/codes/scripts/cifar100_untangle.py deleted file mode 100644 index 1f8f67f3..00000000 --- a/codes/scripts/cifar100_untangle.py +++ /dev/null @@ -1,42 +0,0 @@ -import numpy -import torch -from torch.utils.data import DataLoader - -from data.torch_dataset import TorchDataset -from models.classifiers.cifar_resnet_branched import ResNet -from models.classifiers.cifar_resnet_branched import BasicBlock - -if __name__ == '__main__': - dopt = { - 'flip': True, - 'crop_sz': None, - 'dataset': 'cifar100', - 'image_size': 32, - 'normalize': False, - 'kwargs': { - 'root': 'E:\\4k6k\\datasets\\images\\cifar100', - 'download': True - } - } - set = TorchDataset(dopt) - loader = DataLoader(set, num_workers=0, batch_size=32) - model = ResNet(BasicBlock, [2, 2, 2, 2]) - model.load_state_dict(torch.load('C:\\Users\\jbetk\\Downloads\\cifar_hardw_10000.pth')) - model.eval() - - bins = [[] for _ in range(8)] - for i, batch in enumerate(loader): - logits, selector = model(batch['hq'], coarse_label=None, return_selector=True) - for k, s in enumerate(selector): - for j, b in enumerate(s): - if b: - bins[j].append(batch['labels'][k].item()) - if i > 10: - break - - import matplotlib.pyplot as plt - fig, axs = plt.subplots(3,3) - for i in range(8): - axs[i%3, i//3].hist(numpy.asarray(bins[i])) - plt.show() - print('hi') \ No newline at end of file diff --git a/codes/scripts/compute_fdpl_perceptual_weights.py b/codes/scripts/compute_fdpl_perceptual_weights.py deleted file mode 100644 index 3b411721..00000000 --- a/codes/scripts/compute_fdpl_perceptual_weights.py +++ /dev/null @@ -1,70 +0,0 @@ -import torch -import numpy as np -from utils import options as option -from data import create_dataloader, create_dataset -import math -from tqdm import tqdm -from utils.fdpl_util import dct_2d, extract_patches_2d -import matplotlib.pyplot as plt -from mpl_toolkits.axes_grid1 import make_axes_locatable -from utils.colors import rgb2ycbcr -import torch.nn.functional as F - -input_config = "../../options/train_imgset_pixgan_srg4_fdpl.yml" -output_file = "fdpr_diff_means.pt" -device = 'cuda' -patch_size=128 - -if __name__ == '__main__': - opt = option.parse(input_config, is_train=True) - opt['dist'] = False - - # Create a dataset to load from (this dataset loads HR/LR images and performs any distortions specified by the YML. 
- dataset_opt = opt['datasets']['train'] - train_set = create_dataset(dataset_opt) - train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size'])) - total_iters = int(opt['train']['niter']) - total_epochs = int(math.ceil(total_iters / train_size)) - train_loader = create_dataloader(train_set, dataset_opt, opt, None) - print('Number of train images: {:,d}, iters: {:,d}'.format( - len(train_set), train_size)) - - # calculate the perceptual weights - master_diff = np.zeros((patch_size, patch_size)) - num_patches = 0 - all_diff_patches = [] - tq = tqdm(train_loader) - sampled = 0 - for train_data in tq: - if sampled > 200: - break - sampled += 1 - - im = rgb2ycbcr(train_data['hq'].double()) - im_LR = rgb2ycbcr(F.interpolate(train_data['lq'].double(), - size=im.shape[2:], - mode="bicubic", align_corners=False)) - patches_hr = extract_patches_2d(img=im, patch_shape=(patch_size,patch_size), batch_first=True) - patches_hr = dct_2d(patches_hr, norm='ortho') - patches_lr = extract_patches_2d(img=im_LR, patch_shape=(patch_size,patch_size), batch_first=True) - patches_lr = dct_2d(patches_lr, norm='ortho') - b, p, c, w, h = patches_hr.shape - diffs = torch.abs(patches_lr - patches_hr) / ((torch.abs(patches_lr) + torch.abs(patches_hr)) / 2 + .00000001) - num_patches += b * p - all_diff_patches.append(torch.sum(diffs, dim=(0, 1))) - - diff_patches = torch.stack(all_diff_patches, dim=0) - diff_means = torch.sum(diff_patches, dim=0) / num_patches - - torch.save(diff_means, output_file) - print(diff_means) - - for i in range(3): - fig, ax = plt.subplots() - divider = make_axes_locatable(ax) - cax = divider.append_axes('right', size='5%', pad=0.05) - im = ax.imshow(diff_means[i].numpy()) - ax.set_title("mean_diff for channel %i" % (i,)) - fig.colorbar(im, cax=cax, orientation='vertical') - plt.show() - diff --git a/codes/scripts/create_lmdb.py b/codes/scripts/create_lmdb.py deleted file mode 100644 index 7b6d5de4..00000000 --- a/codes/scripts/create_lmdb.py +++ /dev/null @@ -1,411 +0,0 @@ -"""Create lmdb files for [General images (291 images/DIV2K) | Vimeo90K | REDS] training datasets""" - -import sys -import os.path as osp -import glob -import pickle -from multiprocessing import Pool -import numpy as np -import lmdb -import cv2 - -sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__)))) -import data.util as data_util # noqa: E402 -import utils.util as util # noqa: E402 - - -def main(): - dataset = 'DIV2K_demo' # vimeo90K | REDS | general (e.g., DIV2K, 291) | DIV2K_demo |test - mode = 'hq' # used for vimeo90k and REDS datasets - # vimeo90k: GT | LR | flow - # REDS: train_sharp, train_sharp_bicubic, train_blur_bicubic, train_blur, train_blur_comp - # train_sharp_flowx4 - if dataset == 'vimeo90k': - vimeo90k(mode) - elif dataset == 'REDS': - REDS(mode) - elif dataset == 'general': - opt = {} - opt['img_folder'] = '../../datasets/DIV2K/DIV2K800_sub' - opt['lmdb_save_path'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb' - opt['name'] = 'DIV2K800_sub_GT' - general_image_folder(opt) - elif dataset == 'DIV2K_demo': - opt = {} - ## GT - opt['img_folder'] = '../../datasets/DIV2K/DIV2K800_sub' - opt['lmdb_save_path'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb' - opt['name'] = 'DIV2K800_sub_GT' - general_image_folder(opt) - ## LR - opt['img_folder'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4' - opt['lmdb_save_path'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb' - opt['name'] = 'DIV2K800_sub_bicLRx4' - general_image_folder(opt) - elif dataset == 'test': - 
test_lmdb('../../datasets/REDS/train_sharp_wval.lmdb', 'REDS') - - -def read_image_worker(path, key): - img = cv2.imread(path, cv2.IMREAD_UNCHANGED) - return (key, img) - - -def general_image_folder(opt): - """Create lmdb for general image folders - Users should define the keys, such as: '0321_s035' for DIV2K sub-images - If all the images have the same resolution, it will only store one copy of resolution info. - Otherwise, it will store every resolution info. - """ - #### configurations - read_all_imgs = False # whether real all images to memory with multiprocessing - # Set False for use limited memory - BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False - n_thread = 40 - ######################################################## - img_folder = opt['img_folder'] - lmdb_save_path = opt['lmdb_save_path'] - meta_info = {'name': opt['name']} - if not lmdb_save_path.endswith('.lmdb'): - raise ValueError("lmdb_save_path must end with \'lmdb\'.") - if osp.exists(lmdb_save_path): - print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path)) - sys.exit(1) - - #### read all the image paths to a list - print('Reading image path list ...') - all_img_list = sorted(glob.glob(osp.join(img_folder, '*'))) - keys = [] - for img_path in all_img_list: - keys.append(osp.splitext(osp.basename(img_path))[0]) - - if read_all_imgs: - #### read all images to memory (multiprocessing) - dataset = {} # store all image data. list cannot keep the order, use dict - print('Read images with multiprocessing, #thread: {} ...'.format(n_thread)) - pbar = util.ProgressBar(len(all_img_list)) - - def mycallback(arg): - '''get the image data and update pbar''' - key = arg[0] - dataset[key] = arg[1] - pbar.update('Reading {}'.format(key)) - - pool = Pool(n_thread) - for path, key in zip(all_img_list, keys): - pool.apply_async(read_image_worker, args=(path, key), callback=mycallback) - pool.close() - pool.join() - print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list))) - - #### create lmdb environment - data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes - print('data size per image is: ', data_size_per_img) - data_size = data_size_per_img * len(all_img_list) - env = lmdb.open(lmdb_save_path, map_size=data_size * 10) - - #### write data to lmdb - pbar = util.ProgressBar(len(all_img_list)) - txn = env.begin(write=True) - resolutions = [] - for idx, (path, key) in enumerate(zip(all_img_list, keys)): - pbar.update('Write {}'.format(key)) - key_byte = key.encode('ascii') - data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED) - if data.ndim == 2: - H, W = data.shape - C = 1 - else: - H, W, C = data.shape - txn.put(key_byte, data) - resolutions.append('{:d}_{:d}_{:d}'.format(C, H, W)) - if not read_all_imgs and idx % BATCH == 0: - txn.commit() - txn = env.begin(write=True) - txn.commit() - env.close() - print('Finish writing lmdb.') - - #### create meta information - # check whether all the images are the same size - assert len(keys) == len(resolutions) - if len(set(resolutions)) <= 1: - meta_info['resolution'] = [resolutions[0]] - meta_info['keys'] = keys - print('All images have the same resolution. Simplify the meta info.') - else: - meta_info['resolution'] = resolutions - meta_info['keys'] = keys - print('Not all images have the same resolution. 
Save meta info for each image.') - - pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb")) - print('Finish creating lmdb meta info.') - - -def vimeo90k(mode): - """Create lmdb for the Vimeo90K dataset, each image with a fixed size - GT: [3, 256, 448] - Now only need the 4th frame, e.g., 00001_0001_4 - LR: [3, 64, 112] - 1st - 7th frames, e.g., 00001_0001_1, ..., 00001_0001_7 - key: - Use the folder and subfolder names, w/o the frame index, e.g., 00001_0001 - - flow: downsampled flow: [3, 360, 320], keys: 00001_0001_4_[p3, p2, p1, n1, n2, n3] - Each flow is calculated with GT images by PWCNet and then downsampled by 1/4 - Flow map is quantized by mmcv and saved in png format - """ - #### configurations - read_all_imgs = False # whether real all images to memory with multiprocessing - # Set False for use limited memory - BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False - if mode == 'hq': - img_folder = '../../datasets/vimeo90k/vimeo_septuplet/sequences' - lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb' - txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt' - H_dst, W_dst = 256, 448 - elif mode == 'LR': - img_folder = '../../datasets/vimeo90k/vimeo_septuplet_matlabLRx4/sequences' - lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb' - txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt' - H_dst, W_dst = 64, 112 - elif mode == 'flow': - img_folder = '../../datasets/vimeo90k/vimeo_septuplet/sequences_flowx4' - lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_flowx4.lmdb' - txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt' - H_dst, W_dst = 128, 112 - else: - raise ValueError('Wrong dataset mode: {}'.format(mode)) - n_thread = 40 - ######################################################## - if not lmdb_save_path.endswith('.lmdb'): - raise ValueError("lmdb_save_path must end with \'lmdb\'.") - if osp.exists(lmdb_save_path): - print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path)) - sys.exit(1) - - #### read all the image paths to a list - print('Reading image path list ...') - with open(txt_file) as f: - train_l = f.readlines() - train_l = [v.strip() for v in train_l] - all_img_list = [] - keys = [] - for line in train_l: - folder = line.split('/')[0] - sub_folder = line.split('/')[1] - all_img_list.extend(glob.glob(osp.join(img_folder, folder, sub_folder, '*'))) - if mode == 'flow': - for j in range(1, 4): - keys.append('{}_{}_4_n{}'.format(folder, sub_folder, j)) - keys.append('{}_{}_4_p{}'.format(folder, sub_folder, j)) - else: - for j in range(7): - keys.append('{}_{}_{}'.format(folder, sub_folder, j + 1)) - all_img_list = sorted(all_img_list) - keys = sorted(keys) - if mode == 'hq': # only read the 4th frame for the GT mode - print('Only keep the 4th frame.') - all_img_list = [v for v in all_img_list if v.endswith('im4.png')] - keys = [v for v in keys if v.endswith('_4')] - - if read_all_imgs: - #### read all images to memory (multiprocessing) - dataset = {} # store all image data. 
list cannot keep the order, use dict - print('Read images with multiprocessing, #thread: {} ...'.format(n_thread)) - pbar = util.ProgressBar(len(all_img_list)) - - def mycallback(arg): - """get the image data and update pbar""" - key = arg[0] - dataset[key] = arg[1] - pbar.update('Reading {}'.format(key)) - - pool = Pool(n_thread) - for path, key in zip(all_img_list, keys): - pool.apply_async(read_image_worker, args=(path, key), callback=mycallback) - pool.close() - pool.join() - print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list))) - - #### write data to lmdb - data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes - print('data size per image is: ', data_size_per_img) - data_size = data_size_per_img * len(all_img_list) - env = lmdb.open(lmdb_save_path, map_size=data_size * 10) - txn = env.begin(write=True) - pbar = util.ProgressBar(len(all_img_list)) - for idx, (path, key) in enumerate(zip(all_img_list, keys)): - pbar.update('Write {}'.format(key)) - key_byte = key.encode('ascii') - data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED) - if 'flow' in mode: - H, W = data.shape - assert H == H_dst and W == W_dst, 'different shape.' - else: - H, W, C = data.shape - assert H == H_dst and W == W_dst and C == 3, 'different shape.' - txn.put(key_byte, data) - if not read_all_imgs and idx % BATCH == 0: - txn.commit() - txn = env.begin(write=True) - txn.commit() - env.close() - print('Finish writing lmdb.') - - #### create meta information - meta_info = {} - if mode == 'hq': - meta_info['name'] = 'Vimeo90K_train_GT' - elif mode == 'lq': - meta_info['name'] = 'Vimeo90K_train_LR' - elif mode == 'flow': - meta_info['name'] = 'Vimeo90K_train_flowx4' - channel = 1 if 'flow' in mode else 3 - meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst, W_dst) - key_set = set() - for key in keys: - if mode == 'flow': - a, b, _, _ = key.split('_') - else: - a, b, _ = key.split('_') - key_set.add('{}_{}'.format(a, b)) - meta_info['keys'] = list(key_set) - pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb")) - print('Finish creating lmdb meta info.') - - -def REDS(mode): - """Create lmdb for the REDS dataset, each image with a fixed size - GT: [3, 720, 1280], key: 000_00000000 - LR: [3, 180, 320], key: 000_00000000 - key: 000_00000000 - - flow: downsampled flow: [3, 360, 320], keys: 000_00000005_[p2, p1, n1, n2] - Each flow is calculated with the GT images by PWCNet and then downsampled by 1/4 - Flow map is quantized by mmcv and saved in png format - """ - #### configurations - read_all_imgs = False # whether real all images to memory with multiprocessing - # Set False for use limited memory - BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False - if mode == 'train_sharp': - img_folder = '../../datasets/REDS/train_sharp' - lmdb_save_path = '../../datasets/REDS/train_sharp_wval.lmdb' - H_dst, W_dst = 720, 1280 - elif mode == 'train_sharp_bicubic': - img_folder = '../../datasets/REDS/train_sharp_bicubic' - lmdb_save_path = '../../datasets/REDS/train_sharp_bicubic_wval.lmdb' - H_dst, W_dst = 180, 320 - elif mode == 'train_blur_bicubic': - img_folder = '../../datasets/REDS/train_blur_bicubic' - lmdb_save_path = '../../datasets/REDS/train_blur_bicubic_wval.lmdb' - H_dst, W_dst = 180, 320 - elif mode == 'train_blur': - img_folder = '../../datasets/REDS/train_blur' - lmdb_save_path = '../../datasets/REDS/train_blur_wval.lmdb' - H_dst, W_dst = 720, 1280 - elif mode == 'train_blur_comp': - img_folder = 
'../../datasets/REDS/train_blur_comp' - lmdb_save_path = '../../datasets/REDS/train_blur_comp_wval.lmdb' - H_dst, W_dst = 720, 1280 - elif mode == 'train_sharp_flowx4': - img_folder = '../../datasets/REDS/train_sharp_flowx4' - lmdb_save_path = '../../datasets/REDS/train_sharp_flowx4.lmdb' - H_dst, W_dst = 360, 320 - n_thread = 40 - ######################################################## - if not lmdb_save_path.endswith('.lmdb'): - raise ValueError("lmdb_save_path must end with \'lmdb\'.") - if osp.exists(lmdb_save_path): - print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path)) - sys.exit(1) - - #### read all the image paths to a list - print('Reading image path list ...') - all_img_list = data_util._get_paths_from_images(img_folder) - keys = [] - for img_path in all_img_list: - split_rlt = img_path.split('/') - folder = split_rlt[-2] - img_name = split_rlt[-1].split('.png')[0] - keys.append(folder + '_' + img_name) - - if read_all_imgs: - #### read all images to memory (multiprocessing) - dataset = {} # store all image data. list cannot keep the order, use dict - print('Read images with multiprocessing, #thread: {} ...'.format(n_thread)) - pbar = util.ProgressBar(len(all_img_list)) - - def mycallback(arg): - '''get the image data and update pbar''' - key = arg[0] - dataset[key] = arg[1] - pbar.update('Reading {}'.format(key)) - - pool = Pool(n_thread) - for path, key in zip(all_img_list, keys): - pool.apply_async(read_image_worker, args=(path, key), callback=mycallback) - pool.close() - pool.join() - print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list))) - - #### create lmdb environment - data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes - print('data size per image is: ', data_size_per_img) - data_size = data_size_per_img * len(all_img_list) - env = lmdb.open(lmdb_save_path, map_size=data_size * 10) - - #### write data to lmdb - pbar = util.ProgressBar(len(all_img_list)) - txn = env.begin(write=True) - for idx, (path, key) in enumerate(zip(all_img_list, keys)): - pbar.update('Write {}'.format(key)) - key_byte = key.encode('ascii') - data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED) - if 'flow' in mode: - H, W = data.shape - assert H == H_dst and W == W_dst, 'different shape.' - else: - H, W, C = data.shape - assert H == H_dst and W == W_dst and C == 3, 'different shape.' 
- txn.put(key_byte, data) - if not read_all_imgs and idx % BATCH == 0: - txn.commit() - txn = env.begin(write=True) - txn.commit() - env.close() - print('Finish writing lmdb.') - - #### create meta information - meta_info = {} - meta_info['name'] = 'REDS_{}_wval'.format(mode) - channel = 1 if 'flow' in mode else 3 - meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst, W_dst) - meta_info['keys'] = keys - pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb")) - print('Finish creating lmdb meta info.') - - -def test_lmdb(dataroot, dataset='REDS'): - env = lmdb.open(dataroot, readonly=True, lock=False, readahead=False, meminit=False) - meta_info = pickle.load(open(osp.join(dataroot, 'meta_info.pkl'), "rb")) - print('Name: ', meta_info['name']) - print('Resolution: ', meta_info['resolution']) - print('# keys: ', len(meta_info['keys'])) - # read one image - if dataset == 'vimeo90k': - key = '00001_0001_4' - else: - key = '000_00000000' - print('Reading {} for test.'.format(key)) - with env.begin(write=False) as txn: - buf = txn.get(key.encode('ascii')) - img_flat = np.frombuffer(buf, dtype=np.uint8) - C, H, W = [int(s) for s in meta_info['resolution'].split('_')] - img = img_flat.reshape(H, W, C) - cv2.imwrite('test.png', img) - - -if __name__ == "__main__": - main() diff --git a/codes/scripts/recover_tensorboard_log.py b/codes/scripts/recover_tensorboard_log.py deleted file mode 100644 index 3deef011..00000000 --- a/codes/scripts/recover_tensorboard_log.py +++ /dev/null @@ -1,22 +0,0 @@ -from torch.utils.tensorboard import SummaryWriter - -if __name__ == "__main__": - writer = SummaryWriter("../experiments/recovered_tb") - f = open("../experiments/recovered_tb.txt", encoding="utf8") - console = f.readlines() - search_terms = [ - ("iter", ", iter: ", ", lr:"), - ("l_g_total", " l_g_total: ", " switch_temperature:"), - ("l_d_fake", "l_d_fake: ", " D_fake:") - ] - iter = 0 - for line in console: - if " - INFO: [epoch:" not in line: - continue - for name, start, end in search_terms: - val = line[line.find(start)+len(start):line.find(end)].replace(",", "") - if name == "iter": - iter = int(val) - else: - writer.add_scalar(name, float(val), iter) - writer.close() \ No newline at end of file diff --git a/codes/scripts/rename.py b/codes/scripts/rename.py deleted file mode 100644 index ded86ed4..00000000 --- a/codes/scripts/rename.py +++ /dev/null @@ -1,19 +0,0 @@ -import os -import glob - - -def main(): - folder = 'datasets/div2k/DIV2K_valid_LR_bicubic/X4' - DIV2K(folder) - print('Finished.') - - -def DIV2K(path): - img_path_l = glob.glob(os.path.join(path, '*')) - for img_path in img_path_l: - new_path = img_path.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '') - os.rename(img_path, new_path) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/codes/scripts/test_psnr_approximator.py b/codes/scripts/test_psnr_approximator.py deleted file mode 100644 index 680b9cb8..00000000 --- a/codes/scripts/test_psnr_approximator.py +++ /dev/null @@ -1,83 +0,0 @@ -import os.path as osp -import logging -import time -import argparse - -import os - -import torchvision - -import utils -import utils.options as option -import utils.util as util -from trainer.ExtensibleTrainer import ExtensibleTrainer -from data import create_dataset, create_dataloader -from tqdm import tqdm -import torch - -if __name__ == "__main__": - #### options - torch.backends.cudnn.benchmark = True - srg_analyze = False - parser = argparse.ArgumentParser() - 
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_psnr_approximator.yml') - opt = option.parse(parser.parse_args().opt, is_train=False) - opt = option.dict_to_nonedict(opt) - utils.util.loaded_options = opt - - util.mkdirs( - (path for key, path in opt['path'].items() - if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) - util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO, - screen=True, tofile=True) - logger = logging.getLogger('base') - logger.info(option.dict2str(opt)) - - #### Create test dataset and dataloader - test_loaders = [] - for phase, dataset_opt in sorted(opt['datasets'].items()): - dataset_opt['n_workers'] = 0 - test_set = create_dataset(dataset_opt) - test_loader = create_dataloader(test_set, dataset_opt, opt) - logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set))) - test_loaders.append(test_loader) - - model = ExtensibleTrainer(opt) - for test_loader in test_loaders: - test_set_name = test_loader.dataset.opt['name'] - logger.info('\nTesting [{:s}]...'.format(test_set_name)) - test_start_time = time.time() - dataset_dir = osp.join(opt['path']['results_root'], test_set_name) - util.mkdir(dataset_dir) - - dst_path = "F:\\playground" - [os.makedirs(osp.join(dst_path, str(i)), exist_ok=True) for i in range(10)] - - corruptions = ['none', 'color_quantization', 'gaussian_blur', 'motion_blur', 'smooth_blur', 'noise', - 'jpeg-medium', 'jpeg-broad', 'jpeg-normal', 'saturation', 'lq_resampling', - 'lq_resampling4x'] - c_counter = 0 - test_set.corruptor.num_corrupts = 0 - test_set.corruptor.random_corruptions = [] - test_set.corruptor.fixed_corruptions = [corruptions[0]] - corruption_mse = [(0,0) for _ in corruptions] - - tq = tqdm(test_loader) - batch_size = opt['datasets']['train']['batch_size'] - for data in tq: - need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True - model.feed_data(data, need_GT=need_GT) - model.test() - est_psnr = torch.mean(model.eval_state['psnr_approximate'][0], dim=[1,2,3]) - for i in range(est_psnr.shape[0]): - im_path = data['GT_path'][i] - torchvision.utils.save_image(model.eval_state['lq'][0][i], osp.join(dst_path, str(int(est_psnr[i]*10)), osp.basename(im_path))) - #shutil.copy(im_path, osp.join(dst_path, str(int(est_psnr[i]*10)))) - - last_mse, last_ctr = corruption_mse[c_counter % len(corruptions)] - corruption_mse[c_counter % len(corruptions)] = (last_mse + torch.sum(est_psnr).item(), last_ctr + 1) - c_counter += 1 - test_set.corruptor.fixed_corruptions = [corruptions[c_counter % len(corruptions)]] - if c_counter % 100 == 0: - for i, (mse, ctr) in enumerate(corruption_mse): - print("%s: %f" % (corruptions[i], mse / (ctr * batch_size))) \ No newline at end of file diff --git a/codes/utils/fdpl_util.py b/codes/utils/fdpl_util.py deleted file mode 100644 index 2ec8cd11..00000000 --- a/codes/utils/fdpl_util.py +++ /dev/null @@ -1,136 +0,0 @@ -import numpy as np -import torch -import torch.nn as nn - -# note: all dct related functions are either exactly as or based on those -# at https://github.com/zh217/torch-dct -def dct(x, norm=None): - """ - Discrete Cosine Transform, Type II (a.k.a. 
the DCT) - For the meaning of the parameter `norm`, see: - https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html - :param x: the input signal - :param norm: the normalization, None or 'ortho' - :return: the DCT-II of the signal over the last dimension - """ - x_shape = x.shape - N = x_shape[-1] - x = x.contiguous().view(-1, N) - - v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1) - - Vc = torch.rfft(v, 1, onesided=False) - - k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N) - W_r = torch.cos(k) - W_i = torch.sin(k) - - V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i - - if norm == 'ortho': - V[:, 0] /= np.sqrt(N) * 2 - V[:, 1:] /= np.sqrt(N / 2) * 2 - - V = 2 * V.view(*x_shape) - - return V - -def idct(X, norm=None): - """ - The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III - Our definition of idct is that idct(dct(x)) == x - For the meaning of the parameter `norm`, see: - https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html - :param X: the input signal - :param norm: the normalization, None or 'ortho' - :return: the inverse DCT-II of the signal over the last dimension - """ - - x_shape = X.shape - N = x_shape[-1] - - X_v = X.contiguous().view(-1, x_shape[-1]) / 2 - - if norm == 'ortho': - X_v[:, 0] *= np.sqrt(N) * 2 - X_v[:, 1:] *= np.sqrt(N / 2) * 2 - - k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N) - W_r = torch.cos(k) - W_i = torch.sin(k) - - V_t_r = X_v - V_t_r = V_t_r.to(device) - V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1) - V_t_i = V_t_i.to(device) - - V_r = V_t_r * W_r - V_t_i * W_i - V_i = V_t_r * W_i + V_t_i * W_r - - V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2) - - v = torch.irfft(V, 1, onesided=False) - x = v.new_zeros(v.shape) - x[:, ::2] += v[:, :N - (N // 2)] - x[:, 1::2] += v.flip([1])[:, :N // 2] - - return x.view(*x_shape) - -def dct_2d(x, norm=None): - """ - 2-dimensional Discrete Cosine Transform, Type II (a.k.a. 
the DCT) - For the meaning of the parameter `norm`, see: - https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html - :param x: the input signal - :param norm: the normalization, None or 'ortho' - :return: the DCT-II of the signal over the last 2 dimensions - """ - X1 = dct(x, norm=norm) - X2 = dct(X1.transpose(-1, -2), norm=norm) - return X2.transpose(-1, -2) - -def idct_2d(X, norm=None): - """ - The inverse to 2D DCT-II, which is a scaled Discrete Cosine Transform, Type III - Our definition of idct is that idct_2d(dct_2d(x)) == x - For the meaning of the parameter `norm`, see: - https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html - :param X: the input signal - :param norm: the normalization, None or 'ortho' - :return: the DCT-II of the signal over the last 2 dimensions - """ - x1 = idct(X, norm=norm) - x2 = idct(x1.transpose(-1, -2), norm=norm) - return x2.transpose(-1, -2) - -def extract_patches_2d(img,patch_shape,step=[1.0,1.0],batch_first=False): - """ - source: https://gist.github.com/dem123456789/23f18fd78ac8da9615c347905e64fc78 - """ - patch_H, patch_W = patch_shape[0], patch_shape[1] - if(img.size(2) < patch_H): - num_padded_H_Top = (patch_H - img.size(2))//2 - num_padded_H_Bottom = patch_H - img.size(2) - num_padded_H_Top - padding_H = nn.ConstantPad2d((0, 0, num_padded_H_Top, num_padded_H_Bottom), 0) - img = padding_H(img) - if(img.size(3) < patch_W): - num_padded_W_Left = (patch_W - img.size(3))//2 - num_padded_W_Right = patch_W - img.size(3) - num_padded_W_Left - padding_W = nn.ConstantPad2d((num_padded_W_Left,num_padded_W_Right, 0, 0), 0) - img = padding_W(img) - step_int = [0, 0] - step_int[0] = int(patch_H*step[0]) if(isinstance(step[0], float)) else step[0] - step_int[1] = int(patch_W*step[1]) if(isinstance(step[1], float)) else step[1] - patches_fold_H = img.unfold(2, patch_H, step_int[0]) - if((img.size(2) - patch_H) % step_int[0] != 0): - patches_fold_H = torch.cat((patches_fold_H, - img[:, :, -patch_H:, :].permute(0,1,3,2).unsqueeze(2)),dim=2) - patches_fold_HW = patches_fold_H.unfold(3, patch_W, step_int[1]) - if((img.size(3) - patch_W) % step_int[1] != 0): - patches_fold_HW = torch.cat((patches_fold_HW, - patches_fold_H[:, :, :, -patch_W:, :].permute(0, 1, 2, 4, 3).unsqueeze(3)), dim=3) - patches = patches_fold_HW.permute(2, 3, 0, 1, 4, 5) - patches = patches.reshape(-1, img.size(0), img.size(1), patch_H, patch_W) - if(batch_first): - patches = patches.permute(1, 0, 2, 3, 4) - return patches
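
Note on the deleted fdpl_util.py helpers: the dct/idct functions above rely on torch.rfft/torch.irfft, which were removed in PyTorch 1.8+, and idct calls V_t_r.to(device) with `device` never defined. A minimal sketch of the forward DCT-II ported to the torch.fft API, following the same zh217/torch-dct derivation the deleted code credits, might look like the following; the function name dct_fft is ours, the norm handling mirrors the deleted helper, and this is an untested sketch rather than the repository's implementation.

import numpy as np
import torch

def dct_fft(x, norm=None):
    # DCT-II over the last dimension, same derivation as the deleted dct(),
    # but using torch.fft.fft in place of the removed torch.rfft.
    x_shape = x.shape
    N = x_shape[-1]
    x = x.contiguous().view(-1, N)
    # interleave even-index samples with reversed odd-index samples, as in the original
    v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
    Vc = torch.fft.fft(v, dim=1)  # complex tensor replaces the old two-channel rfft output
    k = -torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
    V = Vc.real * torch.cos(k) - Vc.imag * torch.sin(k)
    if norm == 'ortho':
        V[:, 0] /= np.sqrt(N) * 2
        V[:, 1:] /= np.sqrt(N / 2) * 2
    return (2 * V).view(*x_shape)

# Hypothetical usage in the spirit of compute_fdpl_perceptual_weights.py:
# patches = extract_patches_2d(img, (128, 128), batch_first=True)
# freq = dct_fft(patches, norm='ortho')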