Man, is there anything ExtensibleTrainer can't train? :)
This commit is contained in:
James Betker 2020-12-08 13:07:53 -07:00
parent 5369cba8ed
commit 97ff25a086
6 changed files with 415 additions and 3 deletions

View File

@ -47,11 +47,10 @@ def create_dataset(dataset_opt):
from data.image_folder_dataset import ImageFolderDataset as D from data.image_folder_dataset import ImageFolderDataset as D
elif mode == 'torch_dataset': elif mode == 'torch_dataset':
from data.torch_dataset import TorchDataset as D from data.torch_dataset import TorchDataset as D
elif mode == 'byol_dataset':
from data.byol_attachment import ByolDatasetWrapper as D
else: else:
raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode)) raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
dataset = D(dataset_opt) dataset = D(dataset_opt)
logger = logging.getLogger('base')
logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__,
dataset_opt['name']))
return dataset return dataset

View File

@ -0,0 +1,47 @@
import random
import torch
from torch.utils.data import Dataset
from kornia import augmentation as augs
from kornia import filters
import torch.nn as nn
# Wrapper for a DLAS Dataset class that applies random augmentations from the BYOL paper to BOTH the 'lq' and 'hq'
# inputs. These are then outputted as 'aug1' and 'aug2'.
from data import create_dataset
class RandomApply(nn.Module):
def __init__(self, fn, p):
super().__init__()
self.fn = fn
self.p = p
def forward(self, x):
if random.random() > self.p:
return x
return self.fn(x)
class ByolDatasetWrapper(Dataset):
def __init__(self, opt):
super().__init__()
self.wrapped_dataset = create_dataset(opt['dataset'])
self.cropped_img_size = opt['crop_size']
augmentations = [ \
RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
augs.RandomGrayscale(p=0.2),
augs.RandomHorizontalFlip(),
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))]
if opt['normalize']:
# The paper calls for normalization. Recommend setting true if you want exactly like the paper.
augmentations.append(augs.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])))
self.aug = nn.Sequential(*augmentations)
def __getitem__(self, item):
item = self.wrapped_dataset[item]
item.update({'aug1': self.aug(item['hq']).squeeze(dim=0), 'aug2': self.aug(item['lq']).squeeze(dim=0)})
return item
def __len__(self):
return len(self.wrapped_dataset)

View File

@ -0,0 +1,237 @@
import copy
import random
from functools import wraps
import torch
import torch.nn.functional as F
from torch import nn
from utils.util import checkpoint
def default(val, def_val):
return def_val if val is None else val
def flatten(t):
return t.reshape(t.shape[0], -1)
def singleton(cache_key):
def inner_fn(fn):
@wraps(fn)
def wrapper(self, *args, **kwargs):
instance = getattr(self, cache_key)
if instance is not None:
return instance
instance = fn(self, *args, **kwargs)
setattr(self, cache_key, instance)
return instance
return wrapper
return inner_fn
def get_module_device(module):
return next(module.parameters()).device
def set_requires_grad(model, val):
for p in model.parameters():
p.requires_grad = val
# loss fn
def loss_fn(x, y):
x = F.normalize(x, dim=-1, p=2)
y = F.normalize(y, dim=-1, p=2)
return 2 - 2 * (x * y).sum(dim=-1)
# exponential moving average
class EMA():
def __init__(self, beta):
super().__init__()
self.beta = beta
def update_average(self, old, new):
if old is None:
return new
return old * self.beta + (1 - self.beta) * new
def update_moving_average(ema_updater, ma_model, current_model):
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = ema_updater.update_average(old_weight, up_weight)
# MLP class for projector and predictor
class MLP(nn.Module):
def __init__(self, dim, projection_size, hidden_size=4096):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_size),
nn.BatchNorm1d(hidden_size),
nn.ReLU(inplace=True),
nn.Linear(hidden_size, projection_size)
)
def forward(self, x):
x = flatten(x)
return self.net(x)
# A wrapper class for training against networks that do not collapse into a small-dimensioned latent.
class StructuralMLP(nn.Module):
def __init__(self, dim, projection_size, hidden_size=4096):
super().__init__()
b, c, h, w = dim
flattened_dim = c * h // 4 * w // 4
self.net = nn.Sequential(
nn.Conv2d(c, c, kernel_size=3, padding=1, stride=2),
nn.BatchNorm2d(c),
nn.ReLU(inplace=True),
nn.Conv2d(c, c, kernel_size=3, padding=1, stride=2),
nn.BatchNorm2d(c),
nn.ReLU(inplace=True),
nn.Flatten(),
nn.Linear(flattened_dim, hidden_size),
nn.BatchNorm1d(hidden_size),
nn.ReLU(inplace=True),
nn.Linear(hidden_size, projection_size)
)
def forward(self, x):
return self.net(x)
# a wrapper class for the base neural network
# will manage the interception of the hidden layer output
# and pipe it into the projecter and predictor nets
class NetWrapper(nn.Module):
def __init__(self, net, projection_size, projection_hidden_size, layer=-2, use_structural_mlp=False):
super().__init__()
self.net = net
self.layer = layer
self.projector = None
self.projection_size = projection_size
self.projection_hidden_size = projection_hidden_size
self.structural_mlp = use_structural_mlp
self.hidden = None
self.hook_registered = False
def _find_layer(self):
if type(self.layer) == str:
modules = dict([*self.net.named_modules()])
return modules.get(self.layer, None)
elif type(self.layer) == int:
children = [*self.net.children()]
return children[self.layer]
return None
def _hook(self, _, __, output):
self.hidden = output
def _register_hook(self):
layer = self._find_layer()
assert layer is not None, f'hidden layer ({self.layer}) not found'
handle = layer.register_forward_hook(self._hook)
self.hook_registered = True
@singleton('projector')
def _get_projector(self, hidden):
if self.structural_mlp:
projector = StructuralMLP(hidden.shape, self.projection_size, self.projection_hidden_size)
else:
_, dim = hidden.shape
projector = MLP(dim, self.projection_size, self.projection_hidden_size)
return projector.to(hidden)
def get_representation(self, x):
if self.layer == -1:
return self.net(x)
if not self.hook_registered:
self._register_hook()
unused = self.net(x)
hidden = self.hidden
self.hidden = None
assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
return hidden
def forward(self, x):
representation = self.get_representation(x)
projector = self._get_projector(representation)
projection = checkpoint(projector, representation)
return projection
class BYOL(nn.Module):
def __init__(
self,
net,
image_size,
hidden_layer=-2,
projection_size=256,
projection_hidden_size=4096,
moving_average_decay=0.99,
use_momentum=True,
structural_mlp=False
):
super().__init__()
self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer,
use_structural_mlp=structural_mlp)
self.use_momentum = use_momentum
self.target_encoder = None
self.target_ema_updater = EMA(moving_average_decay)
self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)
# get device of network and make wrapper same device
device = get_module_device(net)
self.to(device)
# send a mock image tensor to instantiate singleton parameters
self.forward(torch.randn(2, 3, image_size, image_size, device=device),
torch.randn(2, 3, image_size, image_size, device=device))
@singleton('target_encoder')
def _get_target_encoder(self):
target_encoder = copy.deepcopy(self.online_encoder)
set_requires_grad(target_encoder, False)
return target_encoder
def reset_moving_average(self):
del self.target_encoder
self.target_encoder = None
def update_for_step(self, step, __):
assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder'
assert self.target_encoder is not None, 'target encoder has not been created yet'
update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)
def forward(self, image_one, image_two):
online_proj_one = self.online_encoder(image_one)
online_proj_two = self.online_encoder(image_two)
online_pred_one = self.online_predictor(online_proj_one)
online_pred_two = self.online_predictor(online_proj_two)
with torch.no_grad():
target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder
target_proj_one = target_encoder(image_one).detach()
target_proj_two = target_encoder(image_two).detach()
loss_one = loss_fn(online_pred_one, target_proj_two.detach())
loss_two = loss_fn(online_pred_two, target_proj_one.detach())
loss = loss_one + loss_two
return loss.mean()

View File

@ -21,6 +21,7 @@ from models.archs import srg2_classic
from models.archs.biggan.biggan_discriminator import BigGanDiscriminator from models.archs.biggan.biggan_discriminator import BigGanDiscriminator
from models.archs.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator from models.archs.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator
from models.archs.teco_resgen import TecoGen from models.archs.teco_resgen import TecoGen
from utils.util import opt_get
logger = logging.getLogger('base') logger = logging.getLogger('base')
@ -147,6 +148,14 @@ def define_G(opt, opt_net, scale=None):
elif which_model == 'igpt2': elif which_model == 'igpt2':
from models.archs.transformers.igpt.gpt2 import iGPT2 from models.archs.transformers.igpt.gpt2 import iGPT2
netG = iGPT2(opt_net['embed_dim'], opt_net['num_heads'], opt_net['num_layers'], opt_net['num_pixels'] ** 2, opt_net['num_vocab'], centroids_file=opt_net['centroids_file']) netG = iGPT2(opt_net['embed_dim'], opt_net['num_heads'], opt_net['num_layers'], opt_net['num_pixels'] ** 2, opt_net['num_vocab'], centroids_file=opt_net['centroids_file'])
elif which_model == 'byol':
from models.byol.byol_model_wrapper import BYOL
subnet = define_G(opt, opt_net['subnet'])
netG = BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False))
elif which_model == 'spinenet':
from models.archs.spinenet_arch import SpineNet
netG = SpineNet(str(opt_net['arch']), in_channels=3, use_input_norm=opt_net['use_input_norm'])
else: else:
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model)) raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
return netG return netG

34
recipes/byol/README.md Normal file
View File

@ -0,0 +1,34 @@
# Working with BYOL in DLAS
[BYOL](https://arxiv.org/abs/2006.07733) is a technique for pretraining an arbitrary image processing
neural network. It is built upon previous self-supervised architectures like SimCLR.
BYOL in DLAS is adapted from an implementation written by [lucidrains](https://github.com/lucidrains/byol-pytorch).
It is implemented via two wrappers:
1. A Dataset wrapper that augments the LQ and HQ inputs from a typical DLAS dataset. Since differentiable
augmentations don't actually matter for BYOL, it makes more sense (to me) to do this on the CPU at the
dataset layer, so your GPU can focus on processing gradients.
1. A model wrapper that attaches a small MLP to the end of your input network to produce a fixed
size latent. This latent is used to produce the BYOL loss which trains the master weights from
your network.
Thanks to the excellent implementation from lucidrains, this wrapping process makes training your
network on unsupervised datasets extremely easy.
Note: My intent is to adapt BYOL for use on structured models - e.g. models that do *not* collapse
the latent into a flat map. Stay tuned for that..
# Training BYOL
In this directory, you will find a sample training config for training BYOL on DIV2K. You will
likely want to insert your own model architecture first. Exchange out spinenet for your
model architecture and change the `hidden_layer` parameter to a layer from your network
that you want the BYOL model wrapper to hook into.
*hint: Your network architecture (including layer names) is printed out when running train.py
against your network.*
Run the trainer by:
`python train.py -opt train_div2k_byol.yml`

View File

@ -0,0 +1,86 @@
#### general settings
name: train_div2k_byol
use_tb_logger: true
model: extensibletrainer
distortion: sr
scale: 1
gpu_ids: [0]
fp16: false
start_step: 0
checkpointing_enabled: true # <-- Highly recommended for single-GPU training. Will not work with DDP.
wandb: false
datasets:
train:
n_workers: 4
batch_size: 32
mode: byol_dataset
crop_size: 256
normalize: true
dataset:
mode: imagefolder
paths: /content/div2k # <-- Put your path here. Note: full images.
target_size: 256
scale: 1
networks:
generator:
type: generator
which_model_G: byol
image_size: 256
subnet: # <-- Specify your own network to pretrain here.
which_model_G: spinenet
arch: 49
use_input_norm: true
hidden_layer: endpoint_convs.4.conv # <-- Specify a hidden layer from your network here.
#### path
path:
#pretrain_model_generator: <insert pretrained model path if desired>
strict_load: true
#resume_state: ../experiments/train_div2k_byol/training_state/0.state # <-- Set this to resume from a previous training state.
steps:
generator:
training: generator
optimizer_params:
# Optimizer params
lr: !!float 3e-4
weight_decay: 0
beta1: 0.9
beta2: 0.99
injectors:
gen_inj:
type: generator
generator: generator
in: [aug1, aug2]
out: loss
losses:
byol_loss:
type: direct
key: loss
weight: 1
train:
niter: 500000
warmup_iter: -1
mega_batch_factor: 1 # <-- Gradient accumulation factor. If you are running OOM, increase this to [2,4,8].
val_freq: 2000
# Default LR scheduler options
default_lr_scheme: MultiStepLR
gen_lr_steps: [50000, 100000, 150000, 200000]
lr_gamma: 0.5
eval:
output_state: loss
logger:
print_freq: 30
save_checkpoint_freq: 1000
visuals: [hq, lq, aug1, aug2]
visual_debug_rate: 100