forked from mrq/DL-Art-School
Further mods to BYOL
This commit is contained in:
parent
036684893e
commit
29db7c7a02
|
@ -1,11 +1,15 @@
|
||||||
import copy
|
import copy
|
||||||
import random
|
import os
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
import kornia.augmentation as augs
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
import torchvision
|
||||||
|
from kornia import filters
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
from data.byol_attachment import RandomApply
|
||||||
from utils.util import checkpoint
|
from utils.util import checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
@ -182,13 +186,25 @@ class BYOL(nn.Module):
|
||||||
projection_hidden_size=4096,
|
projection_hidden_size=4096,
|
||||||
moving_average_decay=0.99,
|
moving_average_decay=0.99,
|
||||||
use_momentum=True,
|
use_momentum=True,
|
||||||
structural_mlp=False
|
structural_mlp=False,
|
||||||
|
do_augmentation=False # In DLAS this was intended to be done at the dataset level. For massive batch sizes
|
||||||
|
# this can overwhelm the CPU though, and it becomes desirable to do the augmentations
|
||||||
|
# on the GPU again.
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer,
|
self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer,
|
||||||
use_structural_mlp=structural_mlp)
|
use_structural_mlp=structural_mlp)
|
||||||
|
|
||||||
|
self.do_aug = do_augmentation
|
||||||
|
if self.do_aug:
|
||||||
|
augmentations = [ \
|
||||||
|
RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
|
||||||
|
augs.RandomGrayscale(p=0.2),
|
||||||
|
augs.RandomHorizontalFlip(),
|
||||||
|
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
|
||||||
|
augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))]
|
||||||
|
self.aug = nn.Sequential(*augmentations)
|
||||||
self.use_momentum = use_momentum
|
self.use_momentum = use_momentum
|
||||||
self.target_encoder = None
|
self.target_encoder = None
|
||||||
self.target_ema_updater = EMA(moving_average_decay)
|
self.target_ema_updater = EMA(moving_average_decay)
|
||||||
|
@ -221,9 +237,22 @@ class BYOL(nn.Module):
|
||||||
update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)
|
update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)
|
||||||
|
|
||||||
def get_debug_values(self, step, __):
|
def get_debug_values(self, step, __):
|
||||||
|
# In the BYOL paper, this is made to increase over time. Not yet implemented, but still logging the value.
|
||||||
return {'target_ema_beta': self.target_ema_updater.beta}
|
return {'target_ema_beta': self.target_ema_updater.beta}
|
||||||
|
|
||||||
|
def visual_dbg(self, step, path):
|
||||||
|
if self.do_aug:
|
||||||
|
torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,)))
|
||||||
|
torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))
|
||||||
|
|
||||||
def forward(self, image_one, image_two):
|
def forward(self, image_one, image_two):
|
||||||
|
if self.do_aug:
|
||||||
|
image_one = self.aug(image_one)
|
||||||
|
image_two = self.aug(image_two)
|
||||||
|
# Keep copies on hand for visual_dbg.
|
||||||
|
self.im1 = image_one.detach().copy()
|
||||||
|
self.im2 = image_two.detach().copy()
|
||||||
|
|
||||||
online_proj_one = self.online_encoder(image_one)
|
online_proj_one = self.online_encoder(image_one)
|
||||||
online_proj_two = self.online_encoder(image_two)
|
online_proj_two = self.online_encoder(image_two)
|
||||||
|
|
||||||
|
|
|
@ -293,7 +293,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_sameimage.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_diffimage.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
@ -111,7 +111,8 @@ def define_G(opt, opt_net, scale=None):
|
||||||
from models.byol.byol_model_wrapper import BYOL
|
from models.byol.byol_model_wrapper import BYOL
|
||||||
subnet = define_G(opt, opt_net['subnet'])
|
subnet = define_G(opt, opt_net['subnet'])
|
||||||
netG = BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
|
netG = BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
|
||||||
structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False))
|
structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False),
|
||||||
|
do_augmentation=opt_get(opt_net, ['gpu_augmentation'], False))
|
||||||
elif which_model == 'structural_byol':
|
elif which_model == 'structural_byol':
|
||||||
from models.byol.byol_structural import StructuralBYOL
|
from models.byol.byol_structural import StructuralBYOL
|
||||||
subnet = define_G(opt, opt_net['subnet'])
|
subnet = define_G(opt, opt_net['subnet'])
|
||||||
|
|
|
@ -16,19 +16,31 @@ It is implemented via two wrappers:
|
||||||
Thanks to the excellent implementation from lucidrains, this wrapping process makes training your
|
Thanks to the excellent implementation from lucidrains, this wrapping process makes training your
|
||||||
network on unsupervised datasets extremely easy.
|
network on unsupervised datasets extremely easy.
|
||||||
|
|
||||||
Note: My intent is to adapt BYOL for use on structured models - e.g. models that do *not* collapse
|
The DLAS version improves on lucidrains implementation adding some important training details, such as
|
||||||
the latent into a flat map. Stay tuned for that..
|
a custom LARS optimizer implementation that aligns with the recommendations from the paper. By moving augmentation
|
||||||
|
to the dataset level, additional augmentation options are unlocked - like being able to take two similar video frames
|
||||||
|
as the image pair.
|
||||||
|
|
||||||
# Training BYOL
|
# Training BYOL
|
||||||
|
|
||||||
In this directory, you will find a sample training config for training BYOL on DIV2K. You will
|
In this directory, you will find a sample training config for training BYOL on DIV2K. You will
|
||||||
likely want to insert your own model architecture first. Exchange out spinenet for your
|
likely want to insert your own model architecture first.
|
||||||
model architecture and change the `hidden_layer` parameter to a layer from your network
|
|
||||||
that you want the BYOL model wrapper to hook into.
|
|
||||||
|
|
||||||
*hint: Your network architecture (including layer names) is printed out when running train.py
|
|
||||||
against your network.*
|
|
||||||
|
|
||||||
Run the trainer by:
|
Run the trainer by:
|
||||||
|
|
||||||
`python train.py -opt train_div2k_byol.yml`
|
`python train.py -opt train_div2k_byol.yml`
|
||||||
|
|
||||||
|
BYOL is data hungry, as most unsupervised training methods are. You'll definitely want to provide
|
||||||
|
your own dataset - DIV2K is here as an example only.
|
||||||
|
|
||||||
|
## Using your own model
|
||||||
|
|
||||||
|
Training your own model on this BYOL implementation is trivial:
|
||||||
|
1. Add your nn.Module model implementation to the models/ directory.
|
||||||
|
2. Register your model with `trainer/networks.py` as a generator. This file tells DLAS how to build your model from
|
||||||
|
a set of configuration options.
|
||||||
|
3. Copy the sample training config. Change the `subnet` and `hidden_layer` params.
|
||||||
|
4. Run your config with `python train.py -opt <your_config>`.
|
||||||
|
|
||||||
|
*hint: Your network architecture (including layer names) is printed out when running train.py
|
||||||
|
against your network.*
|
Loading…
Reference in New Issue
Block a user