Further mods to BYOL

This commit is contained in:
James Betker 2020-12-24 09:28:41 -07:00
parent 036684893e
commit 29db7c7a02
4 changed files with 55 additions and 13 deletions

View File

@ -1,11 +1,15 @@
import copy import copy
import random import os
from functools import wraps from functools import wraps
import kornia.augmentation as augs
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import torchvision
from kornia import filters
from torch import nn from torch import nn
from data.byol_attachment import RandomApply
from utils.util import checkpoint from utils.util import checkpoint
@ -182,13 +186,25 @@ class BYOL(nn.Module):
projection_hidden_size=4096, projection_hidden_size=4096,
moving_average_decay=0.99, moving_average_decay=0.99,
use_momentum=True, use_momentum=True,
structural_mlp=False structural_mlp=False,
do_augmentation=False # In DLAS this was intended to be done at the dataset level. For massive batch sizes
# this can overwhelm the CPU though, and it becomes desirable to do the augmentations
# on the GPU again.
): ):
super().__init__() super().__init__()
self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer, self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer,
use_structural_mlp=structural_mlp) use_structural_mlp=structural_mlp)
self.do_aug = do_augmentation
if self.do_aug:
augmentations = [ \
RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
augs.RandomGrayscale(p=0.2),
augs.RandomHorizontalFlip(),
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))]
self.aug = nn.Sequential(*augmentations)
self.use_momentum = use_momentum self.use_momentum = use_momentum
self.target_encoder = None self.target_encoder = None
self.target_ema_updater = EMA(moving_average_decay) self.target_ema_updater = EMA(moving_average_decay)
@ -221,9 +237,22 @@ class BYOL(nn.Module):
update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder) update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)
def get_debug_values(self, step, __): def get_debug_values(self, step, __):
# In the BYOL paper, this is made to increase over time. Not yet implemented, but still logging the value.
return {'target_ema_beta': self.target_ema_updater.beta} return {'target_ema_beta': self.target_ema_updater.beta}
def visual_dbg(self, step, path):
if self.do_aug:
torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,)))
torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))
def forward(self, image_one, image_two): def forward(self, image_one, image_two):
if self.do_aug:
image_one = self.aug(image_one)
image_two = self.aug(image_two)
# Keep copies on hand for visual_dbg.
self.im1 = image_one.detach().copy()
self.im2 = image_two.detach().copy()
online_proj_one = self.online_encoder(image_one) online_proj_one = self.online_encoder(image_one)
online_proj_two = self.online_encoder(image_two) online_proj_two = self.online_encoder(image_two)

View File

@ -293,7 +293,7 @@ class Trainer:
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_sameimage.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_diffimage.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args() args = parser.parse_args()

View File

@ -111,7 +111,8 @@ def define_G(opt, opt_net, scale=None):
from models.byol.byol_model_wrapper import BYOL from models.byol.byol_model_wrapper import BYOL
subnet = define_G(opt, opt_net['subnet']) subnet = define_G(opt, opt_net['subnet'])
netG = BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'], netG = BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False)) structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False),
do_augmentation=opt_get(opt_net, ['gpu_augmentation'], False))
elif which_model == 'structural_byol': elif which_model == 'structural_byol':
from models.byol.byol_structural import StructuralBYOL from models.byol.byol_structural import StructuralBYOL
subnet = define_G(opt, opt_net['subnet']) subnet = define_G(opt, opt_net['subnet'])

View File

@ -16,19 +16,31 @@ It is implemented via two wrappers:
Thanks to the excellent implementation from lucidrains, this wrapping process makes training your Thanks to the excellent implementation from lucidrains, this wrapping process makes training your
network on unsupervised datasets extremely easy. network on unsupervised datasets extremely easy.
Note: My intent is to adapt BYOL for use on structured models - e.g. models that do *not* collapse The DLAS version improves on lucidrains implementation adding some important training details, such as
the latent into a flat map. Stay tuned for that.. a custom LARS optimizer implementation that aligns with the recommendations from the paper. By moving augmentation
to the dataset level, additional augmentation options are unlocked - like being able to take two similar video frames
as the image pair.
# Training BYOL # Training BYOL
In this directory, you will find a sample training config for training BYOL on DIV2K. You will In this directory, you will find a sample training config for training BYOL on DIV2K. You will
likely want to insert your own model architecture first. Exchange out spinenet for your likely want to insert your own model architecture first.
model architecture and change the `hidden_layer` parameter to a layer from your network
that you want the BYOL model wrapper to hook into.
*hint: Your network architecture (including layer names) is printed out when running train.py
against your network.*
Run the trainer by: Run the trainer by:
`python train.py -opt train_div2k_byol.yml` `python train.py -opt train_div2k_byol.yml`
BYOL is data hungry, as most unsupervised training methods are. You'll definitely want to provide
your own dataset - DIV2K is here as an example only.
## Using your own model
Training your own model on this BYOL implementation is trivial:
1. Add your nn.Module model implementation to the models/ directory.
2. Register your model with `trainer/networks.py` as a generator. This file tells DLAS how to build your model from
a set of configuration options.
3. Copy the sample training config. Change the `subnet` and `hidden_layer` params.
4. Run your config with `python train.py -opt <your_config>`.
*hint: Your network architecture (including layer names) is printed out when running train.py
against your network.*