Add LARS optimizer & support for BYOL idiosyncrasies

- Added LARS and SGD optimizer variants that support exempting BN and bias
  parameters from momentum and LARC adaptation
- Added a variant of PyTorch's ResNet model that supports gradient checkpointing
- Modified the trainer infrastructure to support the above
- Fixed a bug in BYOL (it should have been nonfunctional before this fix)
James Betker 2020-12-23 20:33:43 -07:00
parent 1bbcb96ee8
commit 036684893e
8 changed files with 420 additions and 20 deletions
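As a reading aid, here is a minimal sketch of the optimizer config the new `lars` branch in `ConfigurableStep` (last file below) consumes. Only the keys are taken from the diff; the surrounding config layout is an assumption:

```python
# Hypothetical step config for the new LARS path; only the keys below are
# actually read by ConfigurableStep in this commit.
opt_config = {
    'optimizer': 'lars',         # selects SGDNoBiasMomentum wrapped in LARC
    'lr': 0.2,
    'momentum': 0.9,
    'weight_decay': 1.5e-6,
    'lars_coefficient': 0.02,    # forwarded to LARC as trust_coefficient
}
```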

View File

@@ -207,6 +207,8 @@ class BYOL(nn.Module):
def _get_target_encoder(self):
target_encoder = copy.deepcopy(self.online_encoder)
set_requires_grad(target_encoder, False)
+ for p in target_encoder.parameters():
+ p.DO_NOT_TRAIN = True
return target_encoder
def reset_moving_average(self):
@@ -218,6 +220,9 @@ class BYOL(nn.Module):
assert self.target_encoder is not None, 'target encoder has not been created yet'
update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)
+ def get_debug_values(self, step, __):
+ return {'target_ema_beta': self.target_ema_updater.beta}
def forward(self, image_one, image_two):
online_proj_one = self.online_encoder(image_one)
online_proj_two = self.online_encoder(image_two)
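The `DO_NOT_TRAIN` flag set above keeps the target encoder's parameters out of the optimizer; the consumer of the flag is not part of this diff, but the pattern is presumably something like this sketch:

```python
# Sketch (assumption): how a trainer might honor the DO_NOT_TRAIN flag when
# collecting parameters for an optimizer.
optim_params = [p for p in net.parameters()
                if p.requires_grad and not getattr(p, 'DO_NOT_TRAIN', False)]
```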

View File: models/resnet_with_checkpointing.py

@@ -0,0 +1,190 @@
# A direct copy of torchvision's resnet.py modified to support gradient checkpointing.
import torch
import torch.nn as nn
from torchvision.models.resnet import BasicBlock, Bottleneck
from torchvision.models.utils import load_state_dict_from_url
import torchvision
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
'wide_resnet50_2', 'wide_resnet101_2']
from utils.util import checkpoint
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}
class ResNet(torchvision.models.resnet.ResNet):
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None):
super().__init__(block, layers, num_classes, zero_init_residual, groups, width_per_group,
replace_stride_with_dilation, norm_layer)
def _forward_impl(self, x):
# The exact same implementation as torchvision.models.resnet.ResNet._forward_impl,
# except that the body conv layers (layer1-layer4) are wrapped in checkpoints.
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = checkpoint(self.layer1, x)
x = checkpoint(self.layer2, x)
x = checkpoint(self.layer3, x)
x = checkpoint(self.layer4, x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
def forward(self, x):
return self._forward_impl(x)
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
model = ResNet(block, layers, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch],
progress=progress)
model.load_state_dict(state_dict)
return model
def resnet18(pretrained=False, progress=True, **kwargs):
r"""ResNet-18 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
**kwargs)
def resnet34(pretrained=False, progress=True, **kwargs):
r"""ResNet-34 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
**kwargs)
def resnet50(pretrained=False, progress=True, **kwargs):
r"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
**kwargs)
def resnet101(pretrained=False, progress=True, **kwargs):
r"""ResNet-101 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
**kwargs)
def resnet152(pretrained=False, progress=True, **kwargs):
r"""ResNet-152 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
**kwargs)
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
r"""ResNeXt-50 32x4d model from
`"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['groups'] = 32
kwargs['width_per_group'] = 4
return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
pretrained, progress, **kwargs)
def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
r"""ResNeXt-101 32x8d model from
`"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['groups'] = 32
kwargs['width_per_group'] = 8
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
pretrained, progress, **kwargs)
def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
r"""Wide ResNet-50-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
The model is the same as ResNet except that the number of bottleneck channels
is twice as large in every block. The number of channels in the outer 1x1
convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['width_per_group'] = 64 * 2
return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
pretrained, progress, **kwargs)
def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
r"""Wide ResNet-101-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
The model is the same as ResNet except that the number of bottleneck channels
is twice as large in every block. The number of channels in the outer 1x1
convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['width_per_group'] = 64 * 2
return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
pretrained, progress, **kwargs)
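Usage is unchanged from torchvision. Assuming `utils.util.checkpoint` delegates to `torch.utils.checkpoint.checkpoint` (only its call sites appear in this diff), the model trades a second forward pass through layer1-layer4 for not storing their activations:

```python
# Sketch: drop-in replacement for torchvision.models.resnet50.
import torch
from models.resnet_with_checkpointing import resnet50

model = resnet50(pretrained=True)
x = torch.randn(8, 3, 224, 224, requires_grad=True)  # checkpointing needs grad-tracking inputs
logits = model(x)        # layer1..layer4 activations are recomputed during backward
logits.sum().backward()  # saves activation memory at the cost of extra compute
```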

View File

@@ -99,7 +99,6 @@ class ExtensibleTrainer(BaseModel):
else:
self.schedulers = []
# Wrap networks in distributed shells.
dnets = []
all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
@@ -318,6 +317,10 @@ class ExtensibleTrainer(BaseModel):
for net_name, net in self.networks.items():
if hasattr(net.module, "get_debug_values"):
log.update(net.module.get_debug_values(step, net_name))
+ # Log learning rate (from first param group) too.
+ for o in self.optimizers:
+ log['learning_rate_%s' % (o._config['network'],)] = o.param_groups[0]['lr']
return log
def get_current_visuals(self, need_GT=True):

View File

@@ -8,6 +8,10 @@ from torch.optim.lr_scheduler import _LRScheduler
def get_scheduler_for_name(name, optimizers, scheduler_opt):
schedulers = []
for o in optimizers:
+ # Hack to support LARC, which wraps an underlying optimizer.
+ if hasattr(o, 'optim'):
+ o = o.optim
if name == 'MultiStepLR':
sched = MultiStepLR_Restart(o, scheduler_opt['gen_lr_steps'],
restarts=scheduler_opt['restarts'],
@@ -21,7 +25,7 @@ def get_scheduler_for_name(name, optimizers, scheduler_opt):
scheduler_opt['lr_gamma'])
elif name == 'CosineAnnealingLR_Restart':
sched = CosineAnnealingLR_Restart(
- o, scheduler_opt['T_period'], eta_min=scheduler_opt['eta_min'],
+ o, scheduler_opt['T_period'], scheduler_opt['warmup'], eta_min=scheduler_opt['eta_min'],
restarts=scheduler_opt['restarts'], weights=scheduler_opt['restart_weights'])
else:
raise NotImplementedError('Scheduler not available')
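The unwrap above is safe because `LARC` (added below) proxies `param_groups` straight through to the wrapped optimizer, so the scheduler mutates the same groups either way. A quick illustrative check:

```python
import torch
from trainer.optimizers.larc import LARC

base = torch.optim.SGD([torch.nn.Parameter(torch.zeros(4))], lr=0.1)
opt = LARC(base)
inner = opt.optim if hasattr(opt, 'optim') else opt  # mirrors the hack above
assert inner.param_groups is opt.param_groups        # identical group objects
```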
@@ -86,7 +90,8 @@ class MultiStepLR_Restart(_LRScheduler):
class CosineAnnealingLR_Restart(_LRScheduler):
- def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1):
+ def __init__(self, optimizer, T_period, warmup=0, restarts=None, weights=None, eta_min=0, last_epoch=-1):
+ self.warmup = warmup
self.T_period = T_period
self.T_max = self.T_period[0] # current T period
self.eta_min = eta_min
@@ -99,26 +104,27 @@ class CosineAnnealingLR_Restart(_LRScheduler):
super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch)
def get_lr(self):
- if self.last_epoch == 0:
+ step = self.last_epoch - self.warmup
+ if step <= 0:
return self.base_lrs
- elif self.last_epoch in self.restarts:
- self.last_restart = self.last_epoch
- self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1]
- weight = self.restart_weights[self.restarts.index(self.last_epoch)]
+ elif step in self.restarts:
+ self.last_restart = step
+ self.T_max = self.T_period[self.restarts.index(step) + 1]
+ weight = self.restart_weights[self.restarts.index(step)]
return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
- elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0:
+ elif (step - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0:
return [
group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
]
- return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) /
- (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) *
+ return [(1 + math.cos(math.pi * (step - self.last_restart) / self.T_max)) /
+ (1 + math.cos(math.pi * ((step - self.last_restart) - 1) / self.T_max)) *
(group['lr'] - self.eta_min) + self.eta_min
for group in self.optimizer.param_groups]
if __name__ == "__main__":
- optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0,
+ optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=.2, weight_decay=0,
betas=(0.9, 0.99))
##############################
# MultiStepLR_Restart
@@ -153,11 +159,11 @@ if __name__ == "__main__":
restart_weights = [1]
## four
- T_period = [250000, 250000, 250000, 250000]
- restarts = [250000, 500000, 750000]
- restart_weights = [1, 1, 1]
+ T_period = [80000, 80000, 80000, 80000]
+ restarts = [100000, 200000]
+ restart_weights = [.5, .25]
- scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts,
+ scheduler = CosineAnnealingLR_Restart(optimizer, T_period, warmup=100000, eta_min=.01, restarts=restarts,
weights=restart_weights)
##############################
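For intuition on the new `warmup` argument: `get_lr` shifts the step counter by `warmup`, so the scheduler holds the base LR until `warmup` steps have elapsed, and restart points are expressed in post-warmup steps. A small inspection sketch in the same vein as the file's own `__main__` demo (printed values are illustrative):

```python
# Sketch: base LR is held flat for the first `warmup` steps, then cosine-annealed,
# with restarts landing at warmup+100000 and warmup+200000 total steps.
optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=.2)
scheduler = CosineAnnealingLR_Restart(optimizer, [80000] * 4, warmup=100000,
                                      eta_min=.01, restarts=[100000, 200000],
                                      weights=[.5, .25])
for step in range(300000):
    scheduler.step()
    if step % 50000 == 0:
        print(step, scheduler.get_last_lr())  # stays at .2 until step 100000
```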

View File

@@ -126,7 +126,8 @@ def define_G(opt, opt_net, scale=None):
netG = SpinenetWithLogits(str(opt_net['arch']), opt_net['output_to_attach'], opt_net['num_labels'],
in_channels=3, use_input_norm=opt_net['use_input_norm'])
elif which_model == 'resnet52':
- netG = torchvision.models.resnet50(pretrained=opt_net['pretrained'])
+ from models.resnet_with_checkpointing import resnet50
+ netG = resnet50(pretrained=opt_net['pretrained'])
elif which_model == 'glean':
from models.glean.glean import GleanGenerator
netG = GleanGenerator(opt_net['nf'], opt_net['pretrained_stylegan'])

View File: trainer/optimizers/larc.py

@@ -0,0 +1,110 @@
import torch
from torch import nn
from torch.nn.parameter import Parameter
class LARC(object):
"""
:class:`LARC` is a PyTorch implementation of both the scaling and clipping variants of LARC,
in which the ratio between gradient and parameter magnitudes is used to calculate an adaptive
local learning rate for each individual parameter. The algorithm is designed to improve
convergence of large batch training.
See https://arxiv.org/abs/1708.03888 for calculation of the local learning rate.
In practice it modifies the gradients of parameters as a proxy for modifying the learning rate
of the parameters. This design allows it to be used as a wrapper around any torch.optim Optimizer.
```
model = ...
optim = torch.optim.Adam(model.parameters(), lr=...)
optim = LARC(optim)
```
It can even be used in conjunction with apex.fp16_utils.FP16_Optimizer.
```
model = ...
optim = torch.optim.Adam(model.parameters(), lr=...)
optim = LARC(optim)
optim = apex.fp16_utils.FP16_Optimizer(optim)
```
Args:
optimizer: Pytorch optimizer to wrap and modify learning rate for.
trust_coefficient: Trust coefficient for calculating the lr. See https://arxiv.org/abs/1708.03888
clip: Decides between clipping or scaling mode of LARC. If `clip=True` the learning rate is set to `min(optimizer_lr, local_lr)` for each parameter. If `clip=False` the learning rate is set to `local_lr*optimizer_lr`.
eps: epsilon kludge to help with numerical stability while calculating adaptive_lr
"""
def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8):
self.optim = optimizer
self.trust_coefficient = trust_coefficient
self.eps = eps
self.clip = clip
def __getstate__(self):
return self.optim.__getstate__()
def __setstate__(self, state):
self.optim.__setstate__(state)
@property
def state(self):
return self.optim.state
def __repr__(self):
return self.optim.__repr__()
@property
def param_groups(self):
return self.optim.param_groups
@param_groups.setter
def param_groups(self, value):
self.optim.param_groups = value
def state_dict(self):
return self.optim.state_dict()
def load_state_dict(self, state_dict):
self.optim.load_state_dict(state_dict)
def zero_grad(self):
self.optim.zero_grad()
def add_param_group(self, param_group):
self.optim.add_param_group(param_group)
def step(self):
with torch.no_grad():
weight_decays = []
for group in self.optim.param_groups:
# absorb weight decay control from optimizer
weight_decay = group['weight_decay'] if 'weight_decay' in group else 0
weight_decays.append(weight_decay)
group['weight_decay'] = 0
for p in group['params']:
is_bn_or_bias = (hasattr(p, 'is_bn') and p.is_bn) or (hasattr(p, 'is_bias') and p.is_bias)
if p.grad is None or is_bn_or_bias:
continue
param_norm = torch.norm(p.data)
grad_norm = torch.norm(p.grad.data)
if param_norm != 0 and grad_norm != 0:
# calculate adaptive lr + weight decay
adaptive_lr = self.trust_coefficient * (param_norm) / (
grad_norm + param_norm * weight_decay + self.eps)
# clip learning rate for LARC
if self.clip:
# calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)`
adaptive_lr = min(adaptive_lr / group['lr'], 1)
p.grad.data += weight_decay * p.data
p.grad.data *= adaptive_lr
self.optim.step()
# return weight decay control to optimizer
for i, group in enumerate(self.optim.param_groups):
group['weight_decay'] = weight_decays[i]
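To make the adaptive rate concrete, a worked example with made-up numbers: a weight tensor with a large norm but a tiny gradient gets a large local rate, which clipping then caps at the optimizer's own LR:

```python
# Illustrative arithmetic for LARC in clipping mode (numbers are made up).
trust_coefficient, eps = 0.02, 1e-8
param_norm, grad_norm, weight_decay, lr = 1.0, 0.01, 0.0, 0.2

adaptive_lr = trust_coefficient * param_norm / (grad_norm + param_norm * weight_decay + eps)
# adaptive_lr ~= 2.0, i.e. the raw local rate exceeds the optimizer lr of 0.2,
scale = min(adaptive_lr / lr, 1)  # so clipping leaves the gradient unscaled (scale = 1.0)
# and the effective rate is min(adaptive_lr, lr) = 0.2, exactly as the docstring states.
```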

View File: trainer/optimizers/sgd.py

@@ -0,0 +1,72 @@
import torch
from torch.optim import Optimizer
class SGDNoBiasMomentum(Optimizer):
r"""
A copy of the PyTorch implementation of SGD, modified to turn off momentum for params marked
with `is_bn` or `is_bias`.
"""
def __init__(self, params, lr, momentum=0, dampening=0,
weight_decay=0, nesterov=False):
if lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
if weight_decay < 0.0:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
weight_decay=weight_decay, nesterov=nesterov)
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super().__init__(params, defaults)
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault('nesterov', False)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
weight_decay = group['weight_decay']
momentum = group['momentum']
dampening = group['dampening']
nesterov = group['nesterov']
for p in group['params']:
if p.grad is None:
continue
d_p = p.grad
if weight_decay != 0:
d_p = d_p.add(p, alpha=weight_decay)
# This is the only modification relative to standard torch.optim.SGD: momentum is
# skipped for params marked as BN or bias.
is_bn_or_bias = (hasattr(p, 'is_bn') and p.is_bn) or (hasattr(p, 'is_bias') and p.is_bias)
if not is_bn_or_bias and momentum != 0:
param_state = self.state[p]
if 'momentum_buffer' not in param_state:
buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
else:
buf = param_state['momentum_buffer']
buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
if nesterov:
d_p = d_p.add(buf, alpha=momentum)
else:
d_p = buf
p.add_(d_p, alpha=-group['lr'])
return loss
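A short sketch of how the exemption is driven: the optimizer only looks for `is_bn`/`is_bias` attributes on the parameter tensors, so callers tag them beforehand (this commit does the tagging by name in `ConfigurableStep`, shown in the last file below):

```python
# Sketch: tag BN and bias params; after backward(), conv.weight accumulates a
# momentum buffer while bn.weight/bn.bias/conv.bias are updated without one.
import torch
conv, bn = torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8)
for p in bn.parameters():
    p.is_bn = True
conv.bias.is_bias = True
opt = SGDNoBiasMomentum(list(conv.parameters()) + list(bn.parameters()),
                        lr=0.1, momentum=0.9)
```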

View File

@@ -66,6 +66,16 @@ class ConfigurableStep(Module):
for net_name, net, opt_config in zip(training, nets, opt_configs):
optim_params = []
for k, v in net.named_parameters(): # can optimize for a part of the model
+ # Infer some properties of these parameters, which certain optimizers can use to treat
+ # them differently. For example, it is considered good practice not to apply weight decay
+ # or momentum to BN & bias parameters. TODO: process the module tree instead of the
+ # parameter tree to accomplish the same thing in a more robust way.
+ if k.endswith(".bias"):
+ v.is_bias = True
+ if k.endswith(".weight"):
+ v.is_weight = True
+ if ".bn" in k or '.batchnorm' in k or '.bnorm' in k:
+ v.is_bn = True
if v.requires_grad:
optim_params.append(v)
else:
@@ -76,9 +86,12 @@ class ConfigurableStep(Module):
opt = torch.optim.Adam(optim_params, lr=opt_config['lr'],
weight_decay=opt_config['weight_decay'],
betas=(opt_config['beta1'], opt_config['beta2']))
elif self.step_opt['optimizer'] == 'novograd':
opt = NovoGrad(optim_params, lr=opt_config['lr'], weight_decay=opt_config['weight_decay'],
betas=(opt_config['beta1'], opt_config['beta2']))
+ elif self.step_opt['optimizer'] == 'lars':
+ from trainer.optimizers.larc import LARC
+ from trainer.optimizers.sgd import SGDNoBiasMomentum
+ optSGD = SGDNoBiasMomentum(optim_params, lr=opt_config['lr'], momentum=opt_config['momentum'],
+ weight_decay=opt_config['weight_decay'])
+ opt = LARC(optSGD, trust_coefficient=opt_config['lars_coefficient'])
opt._config = opt_config # This is a bit seedy, but we will need these configs later.
opt._config['network'] = net_name
self.optimizers.append(opt)