diff --git a/codes/data/__init__.py b/codes/data/__init__.py index 8af03960..ab6797a4 100644 --- a/codes/data/__init__.py +++ b/codes/data/__init__.py @@ -47,11 +47,10 @@ def create_dataset(dataset_opt): from data.image_folder_dataset import ImageFolderDataset as D elif mode == 'torch_dataset': from data.torch_dataset import TorchDataset as D + elif mode == 'byol_dataset': + from data.byol_attachment import ByolDatasetWrapper as D else: raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode)) dataset = D(dataset_opt) - logger = logging.getLogger('base') - logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__, - dataset_opt['name'])) return dataset diff --git a/codes/data/byol_attachment.py b/codes/data/byol_attachment.py new file mode 100644 index 00000000..75be36a6 --- /dev/null +++ b/codes/data/byol_attachment.py @@ -0,0 +1,47 @@ +import random + +import torch +from torch.utils.data import Dataset +from kornia import augmentation as augs +from kornia import filters +import torch.nn as nn + +# Wrapper for a DLAS Dataset class that applies random augmentations from the BYOL paper to BOTH the 'lq' and 'hq' +# inputs. These are then outputted as 'aug1' and 'aug2'. +from data import create_dataset + + +class RandomApply(nn.Module): + def __init__(self, fn, p): + super().__init__() + self.fn = fn + self.p = p + def forward(self, x): + if random.random() > self.p: + return x + return self.fn(x) + + +class ByolDatasetWrapper(Dataset): + def __init__(self, opt): + super().__init__() + self.wrapped_dataset = create_dataset(opt['dataset']) + self.cropped_img_size = opt['crop_size'] + augmentations = [ \ + RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8), + augs.RandomGrayscale(p=0.2), + augs.RandomHorizontalFlip(), + RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1), + augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))] + if opt['normalize']: + # The paper calls for normalization. Recommend setting true if you want exactly like the paper. + augmentations.append(augs.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225]))) + self.aug = nn.Sequential(*augmentations) + + def __getitem__(self, item): + item = self.wrapped_dataset[item] + item.update({'aug1': self.aug(item['hq']).squeeze(dim=0), 'aug2': self.aug(item['lq']).squeeze(dim=0)}) + return item + + def __len__(self): + return len(self.wrapped_dataset) diff --git a/codes/models/byol/byol_model_wrapper.py b/codes/models/byol/byol_model_wrapper.py new file mode 100644 index 00000000..536faa85 --- /dev/null +++ b/codes/models/byol/byol_model_wrapper.py @@ -0,0 +1,237 @@ +import copy +import random +from functools import wraps + +import torch +import torch.nn.functional as F +from torch import nn + +from utils.util import checkpoint + + +def default(val, def_val): + return def_val if val is None else val + + +def flatten(t): + return t.reshape(t.shape[0], -1) + + +def singleton(cache_key): + def inner_fn(fn): + @wraps(fn) + def wrapper(self, *args, **kwargs): + instance = getattr(self, cache_key) + if instance is not None: + return instance + + instance = fn(self, *args, **kwargs) + setattr(self, cache_key, instance) + return instance + + return wrapper + + return inner_fn + + +def get_module_device(module): + return next(module.parameters()).device + + +def set_requires_grad(model, val): + for p in model.parameters(): + p.requires_grad = val + + +# loss fn +def loss_fn(x, y): + x = F.normalize(x, dim=-1, p=2) + y = F.normalize(y, dim=-1, p=2) + return 2 - 2 * (x * y).sum(dim=-1) + + +# exponential moving average +class EMA(): + def __init__(self, beta): + super().__init__() + self.beta = beta + + def update_average(self, old, new): + if old is None: + return new + return old * self.beta + (1 - self.beta) * new + + +def update_moving_average(ema_updater, ma_model, current_model): + for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): + old_weight, up_weight = ma_params.data, current_params.data + ma_params.data = ema_updater.update_average(old_weight, up_weight) + + +# MLP class for projector and predictor +class MLP(nn.Module): + def __init__(self, dim, projection_size, hidden_size=4096): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_size), + nn.BatchNorm1d(hidden_size), + nn.ReLU(inplace=True), + nn.Linear(hidden_size, projection_size) + ) + + def forward(self, x): + x = flatten(x) + return self.net(x) + + +# A wrapper class for training against networks that do not collapse into a small-dimensioned latent. +class StructuralMLP(nn.Module): + def __init__(self, dim, projection_size, hidden_size=4096): + super().__init__() + b, c, h, w = dim + flattened_dim = c * h // 4 * w // 4 + self.net = nn.Sequential( + nn.Conv2d(c, c, kernel_size=3, padding=1, stride=2), + nn.BatchNorm2d(c), + nn.ReLU(inplace=True), + nn.Conv2d(c, c, kernel_size=3, padding=1, stride=2), + nn.BatchNorm2d(c), + nn.ReLU(inplace=True), + nn.Flatten(), + nn.Linear(flattened_dim, hidden_size), + nn.BatchNorm1d(hidden_size), + nn.ReLU(inplace=True), + nn.Linear(hidden_size, projection_size) + ) + + def forward(self, x): + return self.net(x) + + +# a wrapper class for the base neural network +# will manage the interception of the hidden layer output +# and pipe it into the projecter and predictor nets +class NetWrapper(nn.Module): + def __init__(self, net, projection_size, projection_hidden_size, layer=-2, use_structural_mlp=False): + super().__init__() + self.net = net + self.layer = layer + + self.projector = None + self.projection_size = projection_size + self.projection_hidden_size = projection_hidden_size + self.structural_mlp = use_structural_mlp + + self.hidden = None + self.hook_registered = False + + def _find_layer(self): + if type(self.layer) == str: + modules = dict([*self.net.named_modules()]) + return modules.get(self.layer, None) + elif type(self.layer) == int: + children = [*self.net.children()] + return children[self.layer] + return None + + def _hook(self, _, __, output): + self.hidden = output + + def _register_hook(self): + layer = self._find_layer() + assert layer is not None, f'hidden layer ({self.layer}) not found' + handle = layer.register_forward_hook(self._hook) + self.hook_registered = True + + @singleton('projector') + def _get_projector(self, hidden): + if self.structural_mlp: + projector = StructuralMLP(hidden.shape, self.projection_size, self.projection_hidden_size) + else: + _, dim = hidden.shape + projector = MLP(dim, self.projection_size, self.projection_hidden_size) + return projector.to(hidden) + + def get_representation(self, x): + if self.layer == -1: + return self.net(x) + + if not self.hook_registered: + self._register_hook() + + unused = self.net(x) + hidden = self.hidden + self.hidden = None + assert hidden is not None, f'hidden layer {self.layer} never emitted an output' + return hidden + + def forward(self, x): + representation = self.get_representation(x) + projector = self._get_projector(representation) + projection = checkpoint(projector, representation) + return projection + + +class BYOL(nn.Module): + def __init__( + self, + net, + image_size, + hidden_layer=-2, + projection_size=256, + projection_hidden_size=4096, + moving_average_decay=0.99, + use_momentum=True, + structural_mlp=False + ): + super().__init__() + + self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer, + use_structural_mlp=structural_mlp) + + self.use_momentum = use_momentum + self.target_encoder = None + self.target_ema_updater = EMA(moving_average_decay) + + self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size) + + # get device of network and make wrapper same device + device = get_module_device(net) + self.to(device) + + # send a mock image tensor to instantiate singleton parameters + self.forward(torch.randn(2, 3, image_size, image_size, device=device), + torch.randn(2, 3, image_size, image_size, device=device)) + + @singleton('target_encoder') + def _get_target_encoder(self): + target_encoder = copy.deepcopy(self.online_encoder) + set_requires_grad(target_encoder, False) + return target_encoder + + def reset_moving_average(self): + del self.target_encoder + self.target_encoder = None + + def update_for_step(self, step, __): + assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder' + assert self.target_encoder is not None, 'target encoder has not been created yet' + update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder) + + def forward(self, image_one, image_two): + online_proj_one = self.online_encoder(image_one) + online_proj_two = self.online_encoder(image_two) + + online_pred_one = self.online_predictor(online_proj_one) + online_pred_two = self.online_predictor(online_proj_two) + + with torch.no_grad(): + target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder + target_proj_one = target_encoder(image_one).detach() + target_proj_two = target_encoder(image_two).detach() + + loss_one = loss_fn(online_pred_one, target_proj_two.detach()) + loss_two = loss_fn(online_pred_two, target_proj_one.detach()) + + loss = loss_one + loss_two + return loss.mean() diff --git a/codes/models/networks.py b/codes/models/networks.py index c77d435e..da129007 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -21,6 +21,7 @@ from models.archs import srg2_classic from models.archs.biggan.biggan_discriminator import BigGanDiscriminator from models.archs.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator from models.archs.teco_resgen import TecoGen +from utils.util import opt_get logger = logging.getLogger('base') @@ -147,6 +148,14 @@ def define_G(opt, opt_net, scale=None): elif which_model == 'igpt2': from models.archs.transformers.igpt.gpt2 import iGPT2 netG = iGPT2(opt_net['embed_dim'], opt_net['num_heads'], opt_net['num_layers'], opt_net['num_pixels'] ** 2, opt_net['num_vocab'], centroids_file=opt_net['centroids_file']) + elif which_model == 'byol': + from models.byol.byol_model_wrapper import BYOL + subnet = define_G(opt, opt_net['subnet']) + netG = BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'], + structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False)) + elif which_model == 'spinenet': + from models.archs.spinenet_arch import SpineNet + netG = SpineNet(str(opt_net['arch']), in_channels=3, use_input_norm=opt_net['use_input_norm']) else: raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model)) return netG diff --git a/recipes/byol/README.md b/recipes/byol/README.md new file mode 100644 index 00000000..f8ca9eb7 --- /dev/null +++ b/recipes/byol/README.md @@ -0,0 +1,34 @@ +# Working with BYOL in DLAS + +[BYOL](https://arxiv.org/abs/2006.07733) is a technique for pretraining an arbitrary image processing +neural network. It is built upon previous self-supervised architectures like SimCLR. + +BYOL in DLAS is adapted from an implementation written by [lucidrains](https://github.com/lucidrains/byol-pytorch). +It is implemented via two wrappers: + +1. A Dataset wrapper that augments the LQ and HQ inputs from a typical DLAS dataset. Since differentiable + augmentations don't actually matter for BYOL, it makes more sense (to me) to do this on the CPU at the + dataset layer, so your GPU can focus on processing gradients. +1. A model wrapper that attaches a small MLP to the end of your input network to produce a fixed + size latent. This latent is used to produce the BYOL loss which trains the master weights from + your network. + +Thanks to the excellent implementation from lucidrains, this wrapping process makes training your +network on unsupervised datasets extremely easy. + +Note: My intent is to adapt BYOL for use on structured models - e.g. models that do *not* collapse +the latent into a flat map. Stay tuned for that.. + +# Training BYOL + +In this directory, you will find a sample training config for training BYOL on DIV2K. You will +likely want to insert your own model architecture first. Exchange out spinenet for your +model architecture and change the `hidden_layer` parameter to a layer from your network +that you want the BYOL model wrapper to hook into. + +*hint: Your network architecture (including layer names) is printed out when running train.py +against your network.* + +Run the trainer by: + +`python train.py -opt train_div2k_byol.yml` \ No newline at end of file diff --git a/recipes/byol/train_div2k_byol.yml b/recipes/byol/train_div2k_byol.yml new file mode 100644 index 00000000..407b3d39 --- /dev/null +++ b/recipes/byol/train_div2k_byol.yml @@ -0,0 +1,86 @@ +#### general settings +name: train_div2k_byol +use_tb_logger: true +model: extensibletrainer +distortion: sr +scale: 1 +gpu_ids: [0] +fp16: false +start_step: 0 +checkpointing_enabled: true # <-- Highly recommended for single-GPU training. Will not work with DDP. +wandb: false + +datasets: + train: + n_workers: 4 + batch_size: 32 + mode: byol_dataset + crop_size: 256 + normalize: true + dataset: + mode: imagefolder + paths: /content/div2k # <-- Put your path here. Note: full images. + target_size: 256 + scale: 1 + +networks: + generator: + type: generator + which_model_G: byol + image_size: 256 + subnet: # <-- Specify your own network to pretrain here. + which_model_G: spinenet + arch: 49 + use_input_norm: true + + hidden_layer: endpoint_convs.4.conv # <-- Specify a hidden layer from your network here. + +#### path +path: + #pretrain_model_generator: + strict_load: true + #resume_state: ../experiments/train_div2k_byol/training_state/0.state # <-- Set this to resume from a previous training state. + +steps: + generator: + training: generator + + optimizer_params: + # Optimizer params + lr: !!float 3e-4 + weight_decay: 0 + beta1: 0.9 + beta2: 0.99 + + injectors: + gen_inj: + type: generator + generator: generator + in: [aug1, aug2] + out: loss + + losses: + byol_loss: + type: direct + key: loss + weight: 1 + +train: + niter: 500000 + warmup_iter: -1 + mega_batch_factor: 1 # <-- Gradient accumulation factor. If you are running OOM, increase this to [2,4,8]. + val_freq: 2000 + + # Default LR scheduler options + default_lr_scheme: MultiStepLR + gen_lr_steps: [50000, 100000, 150000, 200000] + lr_gamma: 0.5 + +eval: + output_state: loss + +logger: + print_freq: 30 + save_checkpoint_freq: 1000 + visuals: [hq, lq, aug1, aug2] + visual_debug_rate: 100 \ No newline at end of file