Further mods to BYOL
This commit is contained in:
parent
036684893e
commit
29db7c7a02
|
@ -1,11 +1,15 @@
|
|||
import copy
|
||||
import random
|
||||
import os
|
||||
from functools import wraps
|
||||
import kornia.augmentation as augs
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
from kornia import filters
|
||||
from torch import nn
|
||||
|
||||
from data.byol_attachment import RandomApply
|
||||
from utils.util import checkpoint
|
||||
|
||||
|
||||
|
@ -182,13 +186,25 @@ class BYOL(nn.Module):
|
|||
projection_hidden_size=4096,
|
||||
moving_average_decay=0.99,
|
||||
use_momentum=True,
|
||||
structural_mlp=False
|
||||
structural_mlp=False,
|
||||
do_augmentation=False # In DLAS this was intended to be done at the dataset level. For massive batch sizes
|
||||
# this can overwhelm the CPU though, and it becomes desirable to do the augmentations
|
||||
# on the GPU again.
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer,
|
||||
use_structural_mlp=structural_mlp)
|
||||
|
||||
self.do_aug = do_augmentation
|
||||
if self.do_aug:
|
||||
augmentations = [ \
|
||||
RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
|
||||
augs.RandomGrayscale(p=0.2),
|
||||
augs.RandomHorizontalFlip(),
|
||||
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
|
||||
augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))]
|
||||
self.aug = nn.Sequential(*augmentations)
|
||||
self.use_momentum = use_momentum
|
||||
self.target_encoder = None
|
||||
self.target_ema_updater = EMA(moving_average_decay)
|
||||
|
@ -221,9 +237,22 @@ class BYOL(nn.Module):
|
|||
update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)
|
||||
|
||||
def get_debug_values(self, step, __):
|
||||
# In the BYOL paper, this is made to increase over time. Not yet implemented, but still logging the value.
|
||||
return {'target_ema_beta': self.target_ema_updater.beta}
|
||||
|
||||
def visual_dbg(self, step, path):
|
||||
if self.do_aug:
|
||||
torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,)))
|
||||
torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))
|
||||
|
||||
def forward(self, image_one, image_two):
|
||||
if self.do_aug:
|
||||
image_one = self.aug(image_one)
|
||||
image_two = self.aug(image_two)
|
||||
# Keep copies on hand for visual_dbg.
|
||||
self.im1 = image_one.detach().copy()
|
||||
self.im2 = image_two.detach().copy()
|
||||
|
||||
online_proj_one = self.online_encoder(image_one)
|
||||
online_proj_two = self.online_encoder(image_two)
|
||||
|
||||
|
|
|
@ -293,7 +293,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_sameimage.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_diffimage.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -111,7 +111,8 @@ def define_G(opt, opt_net, scale=None):
|
|||
from models.byol.byol_model_wrapper import BYOL
|
||||
subnet = define_G(opt, opt_net['subnet'])
|
||||
netG = BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
|
||||
structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False))
|
||||
structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False),
|
||||
do_augmentation=opt_get(opt_net, ['gpu_augmentation'], False))
|
||||
elif which_model == 'structural_byol':
|
||||
from models.byol.byol_structural import StructuralBYOL
|
||||
subnet = define_G(opt, opt_net['subnet'])
|
||||
|
|
|
@ -16,19 +16,31 @@ It is implemented via two wrappers:
|
|||
Thanks to the excellent implementation from lucidrains, this wrapping process makes training your
|
||||
network on unsupervised datasets extremely easy.
|
||||
|
||||
Note: My intent is to adapt BYOL for use on structured models - e.g. models that do *not* collapse
|
||||
the latent into a flat map. Stay tuned for that..
|
||||
The DLAS version improves on lucidrains implementation adding some important training details, such as
|
||||
a custom LARS optimizer implementation that aligns with the recommendations from the paper. By moving augmentation
|
||||
to the dataset level, additional augmentation options are unlocked - like being able to take two similar video frames
|
||||
as the image pair.
|
||||
|
||||
# Training BYOL
|
||||
|
||||
In this directory, you will find a sample training config for training BYOL on DIV2K. You will
|
||||
likely want to insert your own model architecture first. Exchange out spinenet for your
|
||||
model architecture and change the `hidden_layer` parameter to a layer from your network
|
||||
that you want the BYOL model wrapper to hook into.
|
||||
|
||||
*hint: Your network architecture (including layer names) is printed out when running train.py
|
||||
against your network.*
|
||||
likely want to insert your own model architecture first.
|
||||
|
||||
Run the trainer by:
|
||||
|
||||
`python train.py -opt train_div2k_byol.yml`
|
||||
`python train.py -opt train_div2k_byol.yml`
|
||||
|
||||
BYOL is data hungry, as most unsupervised training methods are. You'll definitely want to provide
|
||||
your own dataset - DIV2K is here as an example only.
|
||||
|
||||
## Using your own model
|
||||
|
||||
Training your own model on this BYOL implementation is trivial:
|
||||
1. Add your nn.Module model implementation to the models/ directory.
|
||||
2. Register your model with `trainer/networks.py` as a generator. This file tells DLAS how to build your model from
|
||||
a set of configuration options.
|
||||
3. Copy the sample training config. Change the `subnet` and `hidden_layer` params.
|
||||
4. Run your config with `python train.py -opt <your_config>`.
|
||||
|
||||
*hint: Your network architecture (including layer names) is printed out when running train.py
|
||||
against your network.*
|
Loading…
Reference in New Issue
Block a user