Further mods to BYOL

This commit is contained in:
James Betker 2020-12-24 09:28:41 -07:00
parent 036684893e
commit 29db7c7a02
4 changed files with 55 additions and 13 deletions

View File

@ -1,11 +1,15 @@
import copy
import random
import os
from functools import wraps
import kornia.augmentation as augs
import torch
import torch.nn.functional as F
import torchvision
from kornia import filters
from torch import nn
from data.byol_attachment import RandomApply
from utils.util import checkpoint
@ -182,13 +186,25 @@ class BYOL(nn.Module):
projection_hidden_size=4096,
moving_average_decay=0.99,
use_momentum=True,
structural_mlp=False
structural_mlp=False,
do_augmentation=False # In DLAS this was intended to be done at the dataset level. For massive batch sizes
# this can overwhelm the CPU though, and it becomes desirable to do the augmentations
# on the GPU again.
):
super().__init__()
self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer,
use_structural_mlp=structural_mlp)
self.do_aug = do_augmentation
if self.do_aug:
augmentations = [ \
RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
augs.RandomGrayscale(p=0.2),
augs.RandomHorizontalFlip(),
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))]
self.aug = nn.Sequential(*augmentations)
self.use_momentum = use_momentum
self.target_encoder = None
self.target_ema_updater = EMA(moving_average_decay)
@ -221,9 +237,22 @@ class BYOL(nn.Module):
update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)
def get_debug_values(self, step, __):
# In the BYOL paper, this is made to increase over time. Not yet implemented, but still logging the value.
return {'target_ema_beta': self.target_ema_updater.beta}
def visual_dbg(self, step, path):
if self.do_aug:
torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,)))
torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))
def forward(self, image_one, image_two):
if self.do_aug:
image_one = self.aug(image_one)
image_two = self.aug(image_two)
# Keep copies on hand for visual_dbg.
self.im1 = image_one.detach().copy()
self.im2 = image_two.detach().copy()
online_proj_one = self.online_encoder(image_one)
online_proj_two = self.online_encoder(image_two)

View File

@ -293,7 +293,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_sameimage.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_diffimage.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

View File

@ -111,7 +111,8 @@ def define_G(opt, opt_net, scale=None):
from models.byol.byol_model_wrapper import BYOL
subnet = define_G(opt, opt_net['subnet'])
netG = BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False))
structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False),
do_augmentation=opt_get(opt_net, ['gpu_augmentation'], False))
elif which_model == 'structural_byol':
from models.byol.byol_structural import StructuralBYOL
subnet = define_G(opt, opt_net['subnet'])

View File

@ -16,19 +16,31 @@ It is implemented via two wrappers:
Thanks to the excellent implementation from lucidrains, this wrapping process makes training your
network on unsupervised datasets extremely easy.
Note: My intent is to adapt BYOL for use on structured models - e.g. models that do *not* collapse
the latent into a flat map. Stay tuned for that..
The DLAS version improves on lucidrains implementation adding some important training details, such as
a custom LARS optimizer implementation that aligns with the recommendations from the paper. By moving augmentation
to the dataset level, additional augmentation options are unlocked - like being able to take two similar video frames
as the image pair.
# Training BYOL
In this directory, you will find a sample training config for training BYOL on DIV2K. You will
likely want to insert your own model architecture first. Exchange out spinenet for your
model architecture and change the `hidden_layer` parameter to a layer from your network
that you want the BYOL model wrapper to hook into.
*hint: Your network architecture (including layer names) is printed out when running train.py
against your network.*
likely want to insert your own model architecture first.
Run the trainer by:
`python train.py -opt train_div2k_byol.yml`
BYOL is data hungry, as most unsupervised training methods are. You'll definitely want to provide
your own dataset - DIV2K is here as an example only.
## Using your own model
Training your own model on this BYOL implementation is trivial:
1. Add your nn.Module model implementation to the models/ directory.
2. Register your model with `trainer/networks.py` as a generator. This file tells DLAS how to build your model from
a set of configuration options.
3. Copy the sample training config. Change the `subnet` and `hidden_layer` params.
4. Run your config with `python train.py -opt <your_config>`.
*hint: Your network architecture (including layer names) is printed out when running train.py
against your network.*