Add LARS optimizer & support for BYOL idiosyncrasies
- Added LARS and SGD optimizer variants that support turning off certain features for BN and bias layers - Added a variant of pytorch's resnet model that supports gradient checkpointing. - Modify the trainer infrastructure to support above - Fix bug with BYOL (should have been nonfunctional)
This commit is contained in:
parent
1bbcb96ee8
commit
036684893e
|
@ -207,6 +207,8 @@ class BYOL(nn.Module):
|
||||||
def _get_target_encoder(self):
|
def _get_target_encoder(self):
|
||||||
target_encoder = copy.deepcopy(self.online_encoder)
|
target_encoder = copy.deepcopy(self.online_encoder)
|
||||||
set_requires_grad(target_encoder, False)
|
set_requires_grad(target_encoder, False)
|
||||||
|
for p in target_encoder.parameters():
|
||||||
|
p.DO_NOT_TRAIN = True
|
||||||
return target_encoder
|
return target_encoder
|
||||||
|
|
||||||
def reset_moving_average(self):
|
def reset_moving_average(self):
|
||||||
|
@ -218,6 +220,9 @@ class BYOL(nn.Module):
|
||||||
assert self.target_encoder is not None, 'target encoder has not been created yet'
|
assert self.target_encoder is not None, 'target encoder has not been created yet'
|
||||||
update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)
|
update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)
|
||||||
|
|
||||||
|
def get_debug_values(self, step, __):
|
||||||
|
return {'target_ema_beta': self.target_ema_updater.beta}
|
||||||
|
|
||||||
def forward(self, image_one, image_two):
|
def forward(self, image_one, image_two):
|
||||||
online_proj_one = self.online_encoder(image_one)
|
online_proj_one = self.online_encoder(image_one)
|
||||||
online_proj_two = self.online_encoder(image_two)
|
online_proj_two = self.online_encoder(image_two)
|
||||||
|
|
190
codes/models/resnet_with_checkpointing.py
Normal file
190
codes/models/resnet_with_checkpointing.py
Normal file
|
@ -0,0 +1,190 @@
|
||||||
|
# A direct copy of torchvision's resnet.py modified to support gradient checkpointing.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torchvision.models.resnet import BasicBlock, Bottleneck
|
||||||
|
from torchvision.models.utils import load_state_dict_from_url
|
||||||
|
import torchvision
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
|
||||||
|
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
|
||||||
|
'wide_resnet50_2', 'wide_resnet101_2']
|
||||||
|
|
||||||
|
from utils.util import checkpoint
|
||||||
|
|
||||||
|
model_urls = {
|
||||||
|
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
||||||
|
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
||||||
|
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
||||||
|
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
||||||
|
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
||||||
|
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
|
||||||
|
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
|
||||||
|
'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
|
||||||
|
'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ResNet(torchvision.models.resnet.ResNet):
|
||||||
|
|
||||||
|
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
|
||||||
|
groups=1, width_per_group=64, replace_stride_with_dilation=None,
|
||||||
|
norm_layer=None):
|
||||||
|
super().__init__(block, layers, num_classes, zero_init_residual, groups, width_per_group,
|
||||||
|
replace_stride_with_dilation, norm_layer)
|
||||||
|
|
||||||
|
def _forward_impl(self, x):
|
||||||
|
# Should be the exact same implementation of torchvision.models.resnet.ResNet.forward_impl,
|
||||||
|
# except using checkpoints on the body conv layers.
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.bn1(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
x = self.maxpool(x)
|
||||||
|
|
||||||
|
x = checkpoint(self.layer1, x)
|
||||||
|
x = checkpoint(self.layer2, x)
|
||||||
|
x = checkpoint(self.layer3, x)
|
||||||
|
x = checkpoint(self.layer4, x)
|
||||||
|
|
||||||
|
x = self.avgpool(x)
|
||||||
|
x = torch.flatten(x, 1)
|
||||||
|
x = self.fc(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self._forward_impl(x)
|
||||||
|
|
||||||
|
|
||||||
|
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
|
||||||
|
model = ResNet(block, layers, **kwargs)
|
||||||
|
if pretrained:
|
||||||
|
state_dict = load_state_dict_from_url(model_urls[arch],
|
||||||
|
progress=progress)
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def resnet18(pretrained=False, progress=True, **kwargs):
|
||||||
|
r"""ResNet-18 model from
|
||||||
|
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnet34(pretrained=False, progress=True, **kwargs):
|
||||||
|
r"""ResNet-34 model from
|
||||||
|
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnet50(pretrained=False, progress=True, **kwargs):
|
||||||
|
r"""ResNet-50 model from
|
||||||
|
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnet101(pretrained=False, progress=True, **kwargs):
|
||||||
|
r"""ResNet-101 model from
|
||||||
|
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnet152(pretrained=False, progress=True, **kwargs):
|
||||||
|
r"""ResNet-152 model from
|
||||||
|
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
|
||||||
|
r"""ResNeXt-50 32x4d model from
|
||||||
|
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
kwargs['groups'] = 32
|
||||||
|
kwargs['width_per_group'] = 4
|
||||||
|
return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
|
||||||
|
pretrained, progress, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
|
||||||
|
r"""ResNeXt-101 32x8d model from
|
||||||
|
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
kwargs['groups'] = 32
|
||||||
|
kwargs['width_per_group'] = 8
|
||||||
|
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
|
||||||
|
pretrained, progress, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
|
||||||
|
r"""Wide ResNet-50-2 model from
|
||||||
|
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
|
||||||
|
|
||||||
|
The model is the same as ResNet except for the bottleneck number of channels
|
||||||
|
which is twice larger in every block. The number of channels in outer 1x1
|
||||||
|
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||||
|
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
kwargs['width_per_group'] = 64 * 2
|
||||||
|
return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
|
||||||
|
pretrained, progress, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
|
||||||
|
r"""Wide ResNet-101-2 model from
|
||||||
|
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
|
||||||
|
|
||||||
|
The model is the same as ResNet except for the bottleneck number of channels
|
||||||
|
which is twice larger in every block. The number of channels in outer 1x1
|
||||||
|
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||||
|
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
kwargs['width_per_group'] = 64 * 2
|
||||||
|
return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
|
||||||
|
pretrained, progress, **kwargs)
|
|
@ -99,7 +99,6 @@ class ExtensibleTrainer(BaseModel):
|
||||||
else:
|
else:
|
||||||
self.schedulers = []
|
self.schedulers = []
|
||||||
|
|
||||||
|
|
||||||
# Wrap networks in distributed shells.
|
# Wrap networks in distributed shells.
|
||||||
dnets = []
|
dnets = []
|
||||||
all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
|
all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
|
||||||
|
@ -318,6 +317,10 @@ class ExtensibleTrainer(BaseModel):
|
||||||
for net_name, net in self.networks.items():
|
for net_name, net in self.networks.items():
|
||||||
if hasattr(net.module, "get_debug_values"):
|
if hasattr(net.module, "get_debug_values"):
|
||||||
log.update(net.module.get_debug_values(step, net_name))
|
log.update(net.module.get_debug_values(step, net_name))
|
||||||
|
|
||||||
|
# Log learning rate (from first param group) too.
|
||||||
|
for o in self.optimizers:
|
||||||
|
log['learning_rate_%s' % (o._config['network'],)] = o.param_groups[0]['lr']
|
||||||
return log
|
return log
|
||||||
|
|
||||||
def get_current_visuals(self, need_GT=True):
|
def get_current_visuals(self, need_GT=True):
|
||||||
|
|
|
@ -8,6 +8,10 @@ from torch.optim.lr_scheduler import _LRScheduler
|
||||||
def get_scheduler_for_name(name, optimizers, scheduler_opt):
|
def get_scheduler_for_name(name, optimizers, scheduler_opt):
|
||||||
schedulers = []
|
schedulers = []
|
||||||
for o in optimizers:
|
for o in optimizers:
|
||||||
|
# Hack to support LARC, which wraps an underlying optimizer.
|
||||||
|
if hasattr(o, 'optim'):
|
||||||
|
o = o.optim
|
||||||
|
|
||||||
if name == 'MultiStepLR':
|
if name == 'MultiStepLR':
|
||||||
sched = MultiStepLR_Restart(o, scheduler_opt['gen_lr_steps'],
|
sched = MultiStepLR_Restart(o, scheduler_opt['gen_lr_steps'],
|
||||||
restarts=scheduler_opt['restarts'],
|
restarts=scheduler_opt['restarts'],
|
||||||
|
@ -21,7 +25,7 @@ def get_scheduler_for_name(name, optimizers, scheduler_opt):
|
||||||
scheduler_opt['lr_gamma'])
|
scheduler_opt['lr_gamma'])
|
||||||
elif name == 'CosineAnnealingLR_Restart':
|
elif name == 'CosineAnnealingLR_Restart':
|
||||||
sched = CosineAnnealingLR_Restart(
|
sched = CosineAnnealingLR_Restart(
|
||||||
o, scheduler_opt['T_period'], eta_min=scheduler_opt['eta_min'],
|
o, scheduler_opt['T_period'], scheduler_opt['warmup'], eta_min=scheduler_opt['eta_min'],
|
||||||
restarts=scheduler_opt['restarts'], weights=scheduler_opt['restart_weights'])
|
restarts=scheduler_opt['restarts'], weights=scheduler_opt['restart_weights'])
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Scheduler not available')
|
raise NotImplementedError('Scheduler not available')
|
||||||
|
@ -86,7 +90,8 @@ class MultiStepLR_Restart(_LRScheduler):
|
||||||
|
|
||||||
|
|
||||||
class CosineAnnealingLR_Restart(_LRScheduler):
|
class CosineAnnealingLR_Restart(_LRScheduler):
|
||||||
def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1):
|
def __init__(self, optimizer, T_period, warmup=0, restarts=None, weights=None, eta_min=0, last_epoch=-1):
|
||||||
|
self.warmup = warmup
|
||||||
self.T_period = T_period
|
self.T_period = T_period
|
||||||
self.T_max = self.T_period[0] # current T period
|
self.T_max = self.T_period[0] # current T period
|
||||||
self.eta_min = eta_min
|
self.eta_min = eta_min
|
||||||
|
@ -99,26 +104,27 @@ class CosineAnnealingLR_Restart(_LRScheduler):
|
||||||
super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch)
|
super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch)
|
||||||
|
|
||||||
def get_lr(self):
|
def get_lr(self):
|
||||||
if self.last_epoch == 0:
|
step = self.last_epoch - self.warmup
|
||||||
|
if step <= 0:
|
||||||
return self.base_lrs
|
return self.base_lrs
|
||||||
elif self.last_epoch in self.restarts:
|
elif step in self.restarts:
|
||||||
self.last_restart = self.last_epoch
|
self.last_restart = step
|
||||||
self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1]
|
self.T_max = self.T_period[self.restarts.index(step) + 1]
|
||||||
weight = self.restart_weights[self.restarts.index(self.last_epoch)]
|
weight = self.restart_weights[self.restarts.index(step)]
|
||||||
return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
|
return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
|
||||||
elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0:
|
elif (step - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0:
|
||||||
return [
|
return [
|
||||||
group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
|
group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
|
||||||
for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
|
for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
|
||||||
]
|
]
|
||||||
return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) /
|
return [(1 + math.cos(math.pi * (step - self.last_restart) / self.T_max)) /
|
||||||
(1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) *
|
(1 + math.cos(math.pi * ((step - self.last_restart) - 1) / self.T_max)) *
|
||||||
(group['lr'] - self.eta_min) + self.eta_min
|
(group['lr'] - self.eta_min) + self.eta_min
|
||||||
for group in self.optimizer.param_groups]
|
for group in self.optimizer.param_groups]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0,
|
optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=.2, weight_decay=0,
|
||||||
betas=(0.9, 0.99))
|
betas=(0.9, 0.99))
|
||||||
##############################
|
##############################
|
||||||
# MultiStepLR_Restart
|
# MultiStepLR_Restart
|
||||||
|
@ -153,11 +159,11 @@ if __name__ == "__main__":
|
||||||
restart_weights = [1]
|
restart_weights = [1]
|
||||||
|
|
||||||
## four
|
## four
|
||||||
T_period = [250000, 250000, 250000, 250000]
|
T_period = [80000, 80000, 80000, 80000]
|
||||||
restarts = [250000, 500000, 750000]
|
restarts = [100000, 200000]
|
||||||
restart_weights = [1, 1, 1]
|
restart_weights = [.5, .25]
|
||||||
|
|
||||||
scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts,
|
scheduler = CosineAnnealingLR_Restart(optimizer, T_period, warmup=100000, eta_min=.01, restarts=restarts,
|
||||||
weights=restart_weights)
|
weights=restart_weights)
|
||||||
|
|
||||||
##############################
|
##############################
|
||||||
|
|
|
@ -126,7 +126,8 @@ def define_G(opt, opt_net, scale=None):
|
||||||
netG = SpinenetWithLogits(str(opt_net['arch']), opt_net['output_to_attach'], opt_net['num_labels'],
|
netG = SpinenetWithLogits(str(opt_net['arch']), opt_net['output_to_attach'], opt_net['num_labels'],
|
||||||
in_channels=3, use_input_norm=opt_net['use_input_norm'])
|
in_channels=3, use_input_norm=opt_net['use_input_norm'])
|
||||||
elif which_model == 'resnet52':
|
elif which_model == 'resnet52':
|
||||||
netG = torchvision.models.resnet50(pretrained=opt_net['pretrained'])
|
from models.resnet_with_checkpointing import resnet50
|
||||||
|
netG = resnet50(pretrained=opt_net['pretrained'])
|
||||||
elif which_model == 'glean':
|
elif which_model == 'glean':
|
||||||
from models.glean.glean import GleanGenerator
|
from models.glean.glean import GleanGenerator
|
||||||
netG = GleanGenerator(opt_net['nf'], opt_net['pretrained_stylegan'])
|
netG = GleanGenerator(opt_net['nf'], opt_net['pretrained_stylegan'])
|
||||||
|
|
110
codes/trainer/optimizers/larc.py
Normal file
110
codes/trainer/optimizers/larc.py
Normal file
|
@ -0,0 +1,110 @@
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
|
||||||
|
class LARC(object):
|
||||||
|
"""
|
||||||
|
:class:`LARC` is a pytorch implementation of both the scaling and clipping variants of LARC,
|
||||||
|
in which the ratio between gradient and parameter magnitudes is used to calculate an adaptive
|
||||||
|
local learning rate for each individual parameter. The algorithm is designed to improve
|
||||||
|
convergence of large batch training.
|
||||||
|
|
||||||
|
See https://arxiv.org/abs/1708.03888 for calculation of the local learning rate.
|
||||||
|
|
||||||
|
In practice it modifies the gradients of parameters as a proxy for modifying the learning rate
|
||||||
|
of the parameters. This design allows it to be used as a wrapper around any torch.optim Optimizer.
|
||||||
|
|
||||||
|
```
|
||||||
|
model = ...
|
||||||
|
optim = torch.optim.Adam(model.parameters(), lr=...)
|
||||||
|
optim = LARC(optim)
|
||||||
|
```
|
||||||
|
|
||||||
|
It can even be used in conjunction with apex.fp16_utils.FP16_optimizer.
|
||||||
|
|
||||||
|
```
|
||||||
|
model = ...
|
||||||
|
optim = torch.optim.Adam(model.parameters(), lr=...)
|
||||||
|
optim = LARC(optim)
|
||||||
|
optim = apex.fp16_utils.FP16_Optimizer(optim)
|
||||||
|
```
|
||||||
|
|
||||||
|
Args:
|
||||||
|
optimizer: Pytorch optimizer to wrap and modify learning rate for.
|
||||||
|
trust_coefficient: Trust coefficient for calculating the lr. See https://arxiv.org/abs/1708.03888
|
||||||
|
clip: Decides between clipping or scaling mode of LARC. If `clip=True` the learning rate is set to `min(optimizer_lr, local_lr)` for each parameter. If `clip=False` the learning rate is set to `local_lr*optimizer_lr`.
|
||||||
|
eps: epsilon kludge to help with numerical stability while calculating adaptive_lr
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8):
|
||||||
|
self.optim = optimizer
|
||||||
|
self.trust_coefficient = trust_coefficient
|
||||||
|
self.eps = eps
|
||||||
|
self.clip = clip
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
return self.optim.__getstate__()
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
self.optim.__setstate__(state)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state(self):
|
||||||
|
return self.optim.state
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return self.optim.__repr__()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def param_groups(self):
|
||||||
|
return self.optim.param_groups
|
||||||
|
|
||||||
|
@param_groups.setter
|
||||||
|
def param_groups(self, value):
|
||||||
|
self.optim.param_groups = value
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
return self.optim.state_dict()
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict):
|
||||||
|
self.optim.load_state_dict(state_dict)
|
||||||
|
|
||||||
|
def zero_grad(self):
|
||||||
|
self.optim.zero_grad()
|
||||||
|
|
||||||
|
def add_param_group(self, param_group):
|
||||||
|
self.optim.add_param_group(param_group)
|
||||||
|
|
||||||
|
def step(self):
|
||||||
|
with torch.no_grad():
|
||||||
|
weight_decays = []
|
||||||
|
for group in self.optim.param_groups:
|
||||||
|
# absorb weight decay control from optimizer
|
||||||
|
weight_decay = group['weight_decay'] if 'weight_decay' in group else 0
|
||||||
|
weight_decays.append(weight_decay)
|
||||||
|
group['weight_decay'] = 0
|
||||||
|
for p in group['params']:
|
||||||
|
is_bn_or_bias = (hasattr(p, 'is_bn') and p.is_bn) or (hasattr(p, 'is_bias') and p.is_bias)
|
||||||
|
if p.grad is None or is_bn_or_bias:
|
||||||
|
continue
|
||||||
|
param_norm = torch.norm(p.data)
|
||||||
|
grad_norm = torch.norm(p.grad.data)
|
||||||
|
|
||||||
|
if param_norm != 0 and grad_norm != 0:
|
||||||
|
# calculate adaptive lr + weight decay
|
||||||
|
adaptive_lr = self.trust_coefficient * (param_norm) / (
|
||||||
|
grad_norm + param_norm * weight_decay + self.eps)
|
||||||
|
|
||||||
|
# clip learning rate for LARC
|
||||||
|
if self.clip:
|
||||||
|
# calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)`
|
||||||
|
adaptive_lr = min(adaptive_lr / group['lr'], 1)
|
||||||
|
|
||||||
|
p.grad.data += weight_decay * p.data
|
||||||
|
p.grad.data *= adaptive_lr
|
||||||
|
|
||||||
|
self.optim.step()
|
||||||
|
# return weight decay control to optimizer
|
||||||
|
for i, group in enumerate(self.optim.param_groups):
|
||||||
|
group['weight_decay'] = weight_decays[i]
|
72
codes/trainer/optimizers/sgd.py
Normal file
72
codes/trainer/optimizers/sgd.py
Normal file
|
@ -0,0 +1,72 @@
|
||||||
|
import torch
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
|
||||||
|
|
||||||
|
class SGDNoBiasMomentum(Optimizer):
|
||||||
|
r"""
|
||||||
|
Copy of pytorch implementation of SGD with a modification which turns off momentum for params marked
|
||||||
|
with `is_bn` or `is_bias`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, params, lr, momentum=0, dampening=0,
|
||||||
|
weight_decay=0, nesterov=False):
|
||||||
|
if lr < 0.0:
|
||||||
|
raise ValueError("Invalid learning rate: {}".format(lr))
|
||||||
|
if momentum < 0.0:
|
||||||
|
raise ValueError("Invalid momentum value: {}".format(momentum))
|
||||||
|
if weight_decay < 0.0:
|
||||||
|
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
|
||||||
|
|
||||||
|
defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
|
||||||
|
weight_decay=weight_decay, nesterov=nesterov)
|
||||||
|
if nesterov and (momentum <= 0 or dampening != 0):
|
||||||
|
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
|
||||||
|
super().__init__(params, defaults)
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
super().__setstate__(state)
|
||||||
|
for group in self.param_groups:
|
||||||
|
group.setdefault('nesterov', False)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def step(self, closure=None):
|
||||||
|
"""Performs a single optimization step.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
closure (callable, optional): A closure that reevaluates the model
|
||||||
|
and returns the loss.
|
||||||
|
"""
|
||||||
|
loss = None
|
||||||
|
if closure is not None:
|
||||||
|
with torch.enable_grad():
|
||||||
|
loss = closure()
|
||||||
|
|
||||||
|
for group in self.param_groups:
|
||||||
|
weight_decay = group['weight_decay']
|
||||||
|
momentum = group['momentum']
|
||||||
|
dampening = group['dampening']
|
||||||
|
nesterov = group['nesterov']
|
||||||
|
|
||||||
|
for p in group['params']:
|
||||||
|
if p.grad is None:
|
||||||
|
continue
|
||||||
|
d_p = p.grad
|
||||||
|
if weight_decay != 0:
|
||||||
|
d_p = d_p.add(p, alpha=weight_decay)
|
||||||
|
# **this is the only modification over standard torch.optim.SGD:
|
||||||
|
is_bn_or_bias = (hasattr(p, 'is_bn') and p.is_bn) or (hasattr(p, 'is_bias') and p.is_bias)
|
||||||
|
if not is_bn_or_bias and momentum != 0:
|
||||||
|
param_state = self.state[p]
|
||||||
|
if 'momentum_buffer' not in param_state:
|
||||||
|
buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
|
||||||
|
else:
|
||||||
|
buf = param_state['momentum_buffer']
|
||||||
|
buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
|
||||||
|
if nesterov:
|
||||||
|
d_p = d_p.add(buf, alpha=momentum)
|
||||||
|
else:
|
||||||
|
d_p = buf
|
||||||
|
|
||||||
|
p.add_(d_p, alpha=-group['lr'])
|
||||||
|
|
||||||
|
return loss
|
|
@ -66,6 +66,16 @@ class ConfigurableStep(Module):
|
||||||
for net_name, net, opt_config in zip(training, nets, opt_configs):
|
for net_name, net, opt_config in zip(training, nets, opt_configs):
|
||||||
optim_params = []
|
optim_params = []
|
||||||
for k, v in net.named_parameters(): # can optimize for a part of the model
|
for k, v in net.named_parameters(): # can optimize for a part of the model
|
||||||
|
# Make some inference about these parameters, which can be used by some optimizers to treat certain
|
||||||
|
# parameters differently. For example, it is considered good practice to not do weight decay on
|
||||||
|
# BN & bias parameters. TODO: process the module tree instead of the parameter tree to accomplish the
|
||||||
|
# same thing, but in a more effective way.
|
||||||
|
if k.endswith(".bias"):
|
||||||
|
v.is_bias = True
|
||||||
|
if k.endswith(".weight"):
|
||||||
|
v.is_weight = True
|
||||||
|
if ".bn" in k or '.batchnorm' in k or '.bnorm' in k:
|
||||||
|
v.is_bn = True
|
||||||
if v.requires_grad:
|
if v.requires_grad:
|
||||||
optim_params.append(v)
|
optim_params.append(v)
|
||||||
else:
|
else:
|
||||||
|
@ -76,9 +86,12 @@ class ConfigurableStep(Module):
|
||||||
opt = torch.optim.Adam(optim_params, lr=opt_config['lr'],
|
opt = torch.optim.Adam(optim_params, lr=opt_config['lr'],
|
||||||
weight_decay=opt_config['weight_decay'],
|
weight_decay=opt_config['weight_decay'],
|
||||||
betas=(opt_config['beta1'], opt_config['beta2']))
|
betas=(opt_config['beta1'], opt_config['beta2']))
|
||||||
elif self.step_opt['optimizer'] == 'novograd':
|
elif self.step_opt['optimizer'] == 'lars':
|
||||||
opt = NovoGrad(optim_params, lr=opt_config['lr'], weight_decay=opt_config['weight_decay'],
|
from trainer.optimizers.larc import LARC
|
||||||
betas=(opt_config['beta1'], opt_config['beta2']))
|
from trainer.optimizers.sgd import SGDNoBiasMomentum
|
||||||
|
optSGD = SGDNoBiasMomentum(optim_params, lr=opt_config['lr'], momentum=opt_config['momentum'],
|
||||||
|
weight_decay=opt_config['weight_decay'])
|
||||||
|
opt = LARC(optSGD, trust_coefficient=opt_config['lars_coefficient'])
|
||||||
opt._config = opt_config # This is a bit seedy, but we will need these configs later.
|
opt._config = opt_config # This is a bit seedy, but we will need these configs later.
|
||||||
opt._config['network'] = net_name
|
opt._config['network'] = net_name
|
||||||
self.optimizers.append(opt)
|
self.optimizers.append(opt)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user