diff --git a/codes/models/byol/byol_model_wrapper.py b/codes/models/byol/byol_model_wrapper.py index 5dc82ce2..398db439 100644 --- a/codes/models/byol/byol_model_wrapper.py +++ b/codes/models/byol/byol_model_wrapper.py @@ -1,11 +1,15 @@ import copy -import random +import os from functools import wraps +import kornia.augmentation as augs import torch import torch.nn.functional as F +import torchvision +from kornia import filters from torch import nn +from data.byol_attachment import RandomApply from utils.util import checkpoint @@ -182,13 +186,25 @@ class BYOL(nn.Module): projection_hidden_size=4096, moving_average_decay=0.99, use_momentum=True, - structural_mlp=False + structural_mlp=False, + do_augmentation=False # In DLAS this was intended to be done at the dataset level. For massive batch sizes + # this can overwhelm the CPU though, and it becomes desirable to do the augmentations + # on the GPU again. ): super().__init__() self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer, use_structural_mlp=structural_mlp) + self.do_aug = do_augmentation + if self.do_aug: + augmentations = [ \ + RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8), + augs.RandomGrayscale(p=0.2), + augs.RandomHorizontalFlip(), + RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1), + augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))] + self.aug = nn.Sequential(*augmentations) self.use_momentum = use_momentum self.target_encoder = None self.target_ema_updater = EMA(moving_average_decay) @@ -221,9 +237,22 @@ class BYOL(nn.Module): update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder) def get_debug_values(self, step, __): + # In the BYOL paper, this is made to increase over time. Not yet implemented, but still logging the value. return {'target_ema_beta': self.target_ema_updater.beta} + def visual_dbg(self, step, path): + if self.do_aug: + torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,))) + torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,))) + def forward(self, image_one, image_two): + if self.do_aug: + image_one = self.aug(image_one) + image_two = self.aug(image_two) + # Keep copies on hand for visual_dbg. + self.im1 = image_one.detach().copy() + self.im2 = image_two.detach().copy() + online_proj_one = self.online_encoder(image_one) online_proj_two = self.online_encoder(image_two) diff --git a/codes/train.py b/codes/train.py index 5262d5b7..86d45886 100644 --- a/codes/train.py +++ b/codes/train.py @@ -293,7 +293,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_sameimage.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_diffimage.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() diff --git a/codes/trainer/networks.py b/codes/trainer/networks.py index 44695c44..7c67eb67 100644 --- a/codes/trainer/networks.py +++ b/codes/trainer/networks.py @@ -111,7 +111,8 @@ def define_G(opt, opt_net, scale=None): from models.byol.byol_model_wrapper import BYOL subnet = define_G(opt, opt_net['subnet']) netG = BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'], - structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False)) + structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False), + do_augmentation=opt_get(opt_net, ['gpu_augmentation'], False)) elif which_model == 'structural_byol': from models.byol.byol_structural import StructuralBYOL subnet = define_G(opt, opt_net['subnet']) diff --git a/recipes/byol/README.md b/recipes/byol/README.md index f8ca9eb7..4559b7fa 100644 --- a/recipes/byol/README.md +++ b/recipes/byol/README.md @@ -16,19 +16,31 @@ It is implemented via two wrappers: Thanks to the excellent implementation from lucidrains, this wrapping process makes training your network on unsupervised datasets extremely easy. -Note: My intent is to adapt BYOL for use on structured models - e.g. models that do *not* collapse -the latent into a flat map. Stay tuned for that.. +The DLAS version improves on lucidrains implementation adding some important training details, such as +a custom LARS optimizer implementation that aligns with the recommendations from the paper. By moving augmentation +to the dataset level, additional augmentation options are unlocked - like being able to take two similar video frames +as the image pair. # Training BYOL In this directory, you will find a sample training config for training BYOL on DIV2K. You will -likely want to insert your own model architecture first. Exchange out spinenet for your -model architecture and change the `hidden_layer` parameter to a layer from your network -that you want the BYOL model wrapper to hook into. - -*hint: Your network architecture (including layer names) is printed out when running train.py -against your network.* +likely want to insert your own model architecture first. Run the trainer by: -`python train.py -opt train_div2k_byol.yml` \ No newline at end of file +`python train.py -opt train_div2k_byol.yml` + +BYOL is data hungry, as most unsupervised training methods are. You'll definitely want to provide +your own dataset - DIV2K is here as an example only. + +## Using your own model + +Training your own model on this BYOL implementation is trivial: +1. Add your nn.Module model implementation to the models/ directory. +2. Register your model with `trainer/networks.py` as a generator. This file tells DLAS how to build your model from + a set of configuration options. +3. Copy the sample training config. Change the `subnet` and `hidden_layer` params. +4. Run your config with `python train.py -opt `. + +*hint: Your network architecture (including layer names) is printed out when running train.py +against your network.* \ No newline at end of file