036684893e
- Added LARS and SGD optimizer variants that support turning off certain features for BN and bias layers - Added a variant of pytorch's resnet model that supports gradient checkpointing. - Modify the trainer infrastructure to support above - Fix bug with BYOL (should have been nonfunctional)
110 lines
4.1 KiB
Python
110 lines
4.1 KiB
Python
import torch
|
|
from torch import nn
|
|
from torch.nn.parameter import Parameter
|
|
|
|
|
|
class LARC(object):
    """Layer-wise adaptive rate wrapper around a PyTorch optimizer.

    :class:`LARC` is a pytorch implementation of both the scaling and clipping
    variants of LARC, in which the ratio between gradient and parameter
    magnitudes is used to calculate an adaptive local learning rate for each
    individual parameter. The algorithm is designed to improve convergence of
    large batch training.

    See https://arxiv.org/abs/1708.03888 for calculation of the local learning
    rate.

    In practice it modifies the gradients of parameters as a proxy for
    modifying the learning rate of the parameters. This design allows it to be
    used as a wrapper around any torch.optim Optimizer.

    ```
    model = ...
    optim = torch.optim.Adam(model.parameters(), lr=...)
    optim = LARC(optim)
    ```

    It can even be used in conjunction with apex.fp16_utils.FP16_optimizer.

    ```
    model = ...
    optim = torch.optim.Adam(model.parameters(), lr=...)
    optim = LARC(optim)
    optim = apex.fp16_utils.FP16_Optimizer(optim)
    ```

    Parameters carrying a truthy ``is_bn`` or ``is_bias`` attribute are left
    untouched by :meth:`step`: their gradients are neither rescaled nor given
    weight decay (the wrapped optimizer's own weight decay is disabled for the
    duration of the step).

    Args:
        optimizer: Pytorch optimizer to wrap and modify learning rate for.
        trust_coefficient: Trust coefficient for calculating the lr.
            See https://arxiv.org/abs/1708.03888
        clip: Decides between clipping or scaling mode of LARC. If
            `clip=True` the learning rate is set to
            `min(optimizer_lr, local_lr)` for each parameter. If `clip=False`
            the learning rate is set to `local_lr*optimizer_lr`.
        eps: epsilon kludge to help with numerical stability while
            calculating adaptive_lr
    """

    def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8):
        self.optim = optimizer
        self.trust_coefficient = trust_coefficient
        self.eps = eps
        self.clip = clip

    def __getstate__(self):
        """Return the wrapped optimizer's state (the wrapper's own fields are not included)."""
        return self.optim.__getstate__()

    def __setstate__(self, state):
        """Forward ``state`` to the wrapped optimizer."""
        # NOTE(review): unpickling does not run __init__, so `self.optim` would
        # not exist when pickle calls this on a fresh instance -- confirm
        # whether pickling the wrapper itself is meant to be supported.
        self.optim.__setstate__(state)

    @property
    def state(self):
        """Per-parameter state of the wrapped optimizer."""
        return self.optim.state

    def __repr__(self):
        return self.optim.__repr__()

    @property
    def param_groups(self):
        """Parameter groups of the wrapped optimizer."""
        return self.optim.param_groups

    @param_groups.setter
    def param_groups(self, value):
        self.optim.param_groups = value

    def state_dict(self):
        """Return the wrapped optimizer's ``state_dict()``."""
        return self.optim.state_dict()

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` into the wrapped optimizer."""
        self.optim.load_state_dict(state_dict)

    def zero_grad(self):
        """Clear gradients via the wrapped optimizer."""
        self.optim.zero_grad()

    def add_param_group(self, param_group):
        """Add ``param_group`` to the wrapped optimizer."""
        self.optim.add_param_group(param_group)

    def step(self):
        """Rescale gradients in place with the LARC local lr, then step the wrapped optimizer.

        Weight decay is temporarily taken over from the wrapped optimizer:
        each group's ``weight_decay`` is set to 0 while stepping and folded
        into the gradient here instead. The original values are restored
        afterwards, even if the rescaling loop or the wrapped optimizer's
        ``step()`` raises. Groups that had no ``weight_decay`` key end up with
        ``weight_decay = 0``.

        Parameters with no gradient, with a truthy ``is_bn``/``is_bias``
        attribute, or whose parameter norm or gradient norm is zero are passed
        to the wrapped optimizer unmodified (and therefore without weight
        decay).
        """
        weight_decays = []
        try:
            with torch.no_grad():
                for group in self.optim.param_groups:
                    # absorb weight decay control from optimizer
                    weight_decay = group.get('weight_decay', 0)
                    weight_decays.append(weight_decay)
                    group['weight_decay'] = 0
                    for p in group['params']:
                        is_bn_or_bias = getattr(p, 'is_bn', False) or getattr(p, 'is_bias', False)
                        if p.grad is None or is_bn_or_bias:
                            continue
                        param_norm = torch.norm(p.data)
                        grad_norm = torch.norm(p.grad.data)

                        if param_norm != 0 and grad_norm != 0:
                            # calculate adaptive lr + weight decay
                            adaptive_lr = self.trust_coefficient * param_norm / (
                                grad_norm + param_norm * weight_decay + self.eps)

                            # clip learning rate for LARC
                            if self.clip:
                                # calculation of adaptive_lr so that when multiplied
                                # by lr it equals `min(adaptive_lr, lr)`
                                adaptive_lr = min(adaptive_lr / group['lr'], 1)

                            p.grad.data += weight_decay * p.data
                            p.grad.data *= adaptive_lr

            self.optim.step()
        finally:
            # return weight decay control to optimizer; zip() restores only the
            # groups that were actually zeroed if the loop above failed midway
            for group, weight_decay in zip(self.optim.param_groups, weight_decays):
                group['weight_decay'] = weight_decay