diff --git a/codes/models/byol/byol_model_wrapper.py b/codes/models/byol/byol_model_wrapper.py index 3727e296..5dc82ce2 100644 --- a/codes/models/byol/byol_model_wrapper.py +++ b/codes/models/byol/byol_model_wrapper.py @@ -207,6 +207,8 @@ class BYOL(nn.Module): def _get_target_encoder(self): target_encoder = copy.deepcopy(self.online_encoder) set_requires_grad(target_encoder, False) + for p in target_encoder.parameters(): + p.DO_NOT_TRAIN = True return target_encoder def reset_moving_average(self): @@ -218,6 +220,9 @@ class BYOL(nn.Module): assert self.target_encoder is not None, 'target encoder has not been created yet' update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder) + def get_debug_values(self, step, __): + return {'target_ema_beta': self.target_ema_updater.beta} + def forward(self, image_one, image_two): online_proj_one = self.online_encoder(image_one) online_proj_two = self.online_encoder(image_two) diff --git a/codes/models/resnet_with_checkpointing.py b/codes/models/resnet_with_checkpointing.py new file mode 100644 index 00000000..39a5523f --- /dev/null +++ b/codes/models/resnet_with_checkpointing.py @@ -0,0 +1,190 @@ +# A direct copy of torchvision's resnet.py modified to support gradient checkpointing. + +import torch +import torch.nn as nn +from torchvision.models.resnet import BasicBlock, Bottleneck +from torchvision.models.utils import load_state_dict_from_url +import torchvision + + +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', + 'wide_resnet50_2', 'wide_resnet101_2'] + +from utils.util import checkpoint + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', + 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', + 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', +} + + +class ResNet(torchvision.models.resnet.ResNet): + + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, + norm_layer=None): + super().__init__(block, layers, num_classes, zero_init_residual, groups, width_per_group, + replace_stride_with_dilation, norm_layer) + + def _forward_impl(self, x): + # Should be the exact same implementation of torchvision.models.resnet.ResNet.forward_impl, + # except using checkpoints on the body conv layers. 
+ x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = checkpoint(self.layer1, x) + x = checkpoint(self.layer2, x) + x = checkpoint(self.layer3, x) + x = checkpoint(self.layer4, x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.fc(x) + + return x + + def forward(self, x): + return self._forward_impl(x) + + +def _resnet(arch, block, layers, pretrained, progress, **kwargs): + model = ResNet(block, layers, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], + progress=progress) + model.load_state_dict(state_dict) + return model + + +def resnet18(pretrained=False, progress=True, **kwargs): + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, + **kwargs) + + +def resnet34(pretrained=False, progress=True, **kwargs): + r"""ResNet-34 model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet50(pretrained=False, progress=True, **kwargs): + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet101(pretrained=False, progress=True, **kwargs): + r"""ResNet-101 model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, + **kwargs) + + +def resnet152(pretrained=False, progress=True, **kwargs): + r"""ResNet-152 model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, + **kwargs) + + +def resnext50_32x4d(pretrained=False, progress=True, **kwargs): + r"""ResNeXt-50 32x4d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 4 + return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def resnext101_32x8d(pretrained=False, progress=True, **kwargs): + r"""ResNeXt-101 32x8d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 8 + return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], + pretrained, 
progress, **kwargs) + + +def wide_resnet50_2(pretrained=False, progress=True, **kwargs): + r"""Wide ResNet-50-2 model from + `"Wide Residual Networks" `_ + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def wide_resnet101_2(pretrained=False, progress=True, **kwargs): + r"""Wide ResNet-101-2 model from + `"Wide Residual Networks" `_ + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py index 3a9ea335..ffec11b5 100644 --- a/codes/trainer/ExtensibleTrainer.py +++ b/codes/trainer/ExtensibleTrainer.py @@ -99,7 +99,6 @@ class ExtensibleTrainer(BaseModel): else: self.schedulers = [] - # Wrap networks in distributed shells. dnets = [] all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()] @@ -318,6 +317,10 @@ class ExtensibleTrainer(BaseModel): for net_name, net in self.networks.items(): if hasattr(net.module, "get_debug_values"): log.update(net.module.get_debug_values(step, net_name)) + + # Log learning rate (from first param group) too. + for o in self.optimizers: + log['learning_rate_%s' % (o._config['network'],)] = o.param_groups[0]['lr'] return log def get_current_visuals(self, need_GT=True): diff --git a/codes/trainer/lr_scheduler.py b/codes/trainer/lr_scheduler.py index f299b87f..307cec27 100644 --- a/codes/trainer/lr_scheduler.py +++ b/codes/trainer/lr_scheduler.py @@ -8,6 +8,10 @@ from torch.optim.lr_scheduler import _LRScheduler def get_scheduler_for_name(name, optimizers, scheduler_opt): schedulers = [] for o in optimizers: + # Hack to support LARC, which wraps an underlying optimizer. 
+ if hasattr(o, 'optim'): + o = o.optim + if name == 'MultiStepLR': sched = MultiStepLR_Restart(o, scheduler_opt['gen_lr_steps'], restarts=scheduler_opt['restarts'], @@ -21,7 +25,7 @@ def get_scheduler_for_name(name, optimizers, scheduler_opt): scheduler_opt['lr_gamma']) elif name == 'CosineAnnealingLR_Restart': sched = CosineAnnealingLR_Restart( - o, scheduler_opt['T_period'], eta_min=scheduler_opt['eta_min'], + o, scheduler_opt['T_period'], scheduler_opt['warmup'], eta_min=scheduler_opt['eta_min'], restarts=scheduler_opt['restarts'], weights=scheduler_opt['restart_weights']) else: raise NotImplementedError('Scheduler not available') @@ -86,7 +90,8 @@ class MultiStepLR_Restart(_LRScheduler): class CosineAnnealingLR_Restart(_LRScheduler): - def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1): + def __init__(self, optimizer, T_period, warmup=0, restarts=None, weights=None, eta_min=0, last_epoch=-1): + self.warmup = warmup self.T_period = T_period self.T_max = self.T_period[0] # current T period self.eta_min = eta_min @@ -99,26 +104,27 @@ class CosineAnnealingLR_Restart(_LRScheduler): super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch) def get_lr(self): - if self.last_epoch == 0: + step = self.last_epoch - self.warmup + if step <= 0: return self.base_lrs - elif self.last_epoch in self.restarts: - self.last_restart = self.last_epoch - self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1] - weight = self.restart_weights[self.restarts.index(self.last_epoch)] + elif step in self.restarts: + self.last_restart = step + self.T_max = self.T_period[self.restarts.index(step) + 1] + weight = self.restart_weights[self.restarts.index(step)] return [group['initial_lr'] * weight for group in self.optimizer.param_groups] - elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0: + elif (step - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0: return [ group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) ] - return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) / - (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) * + return [(1 + math.cos(math.pi * (step - self.last_restart) / self.T_max)) / + (1 + math.cos(math.pi * ((step - self.last_restart) - 1) / self.T_max)) * (group['lr'] - self.eta_min) + self.eta_min for group in self.optimizer.param_groups] if __name__ == "__main__": - optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0, + optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=.2, weight_decay=0, betas=(0.9, 0.99)) ############################## # MultiStepLR_Restart @@ -153,11 +159,11 @@ if __name__ == "__main__": restart_weights = [1] ## four - T_period = [250000, 250000, 250000, 250000] - restarts = [250000, 500000, 750000] - restart_weights = [1, 1, 1] + T_period = [80000, 80000, 80000, 80000] + restarts = [100000, 200000] + restart_weights = [.5, .25] - scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts, + scheduler = CosineAnnealingLR_Restart(optimizer, T_period, warmup=100000, eta_min=.01, restarts=restarts, weights=restart_weights) ############################## diff --git a/codes/trainer/networks.py b/codes/trainer/networks.py index 422965fe..44695c44 100644 --- a/codes/trainer/networks.py +++ b/codes/trainer/networks.py @@ -126,7 +126,8 @@ 
def define_G(opt, opt_net, scale=None): netG = SpinenetWithLogits(str(opt_net['arch']), opt_net['output_to_attach'], opt_net['num_labels'], in_channels=3, use_input_norm=opt_net['use_input_norm']) elif which_model == 'resnet52': - netG = torchvision.models.resnet50(pretrained=opt_net['pretrained']) + from models.resnet_with_checkpointing import resnet50 + netG = resnet50(pretrained=opt_net['pretrained']) elif which_model == 'glean': from models.glean.glean import GleanGenerator netG = GleanGenerator(opt_net['nf'], opt_net['pretrained_stylegan']) diff --git a/codes/trainer/optimizers/larc.py b/codes/trainer/optimizers/larc.py new file mode 100644 index 00000000..a2e07e95 --- /dev/null +++ b/codes/trainer/optimizers/larc.py @@ -0,0 +1,110 @@ +import torch +from torch import nn +from torch.nn.parameter import Parameter + + +class LARC(object): + """ + :class:`LARC` is a pytorch implementation of both the scaling and clipping variants of LARC, + in which the ratio between gradient and parameter magnitudes is used to calculate an adaptive + local learning rate for each individual parameter. The algorithm is designed to improve + convergence of large batch training. + + See https://arxiv.org/abs/1708.03888 for calculation of the local learning rate. + + In practice it modifies the gradients of parameters as a proxy for modifying the learning rate + of the parameters. This design allows it to be used as a wrapper around any torch.optim Optimizer. + + ``` + model = ... + optim = torch.optim.Adam(model.parameters(), lr=...) + optim = LARC(optim) + ``` + + It can even be used in conjunction with apex.fp16_utils.FP16_optimizer. + + ``` + model = ... + optim = torch.optim.Adam(model.parameters(), lr=...) + optim = LARC(optim) + optim = apex.fp16_utils.FP16_Optimizer(optim) + ``` + + Args: + optimizer: Pytorch optimizer to wrap and modify learning rate for. + trust_coefficient: Trust coefficient for calculating the lr. See https://arxiv.org/abs/1708.03888 + clip: Decides between clipping or scaling mode of LARC. If `clip=True` the learning rate is set to `min(optimizer_lr, local_lr)` for each parameter. If `clip=False` the learning rate is set to `local_lr*optimizer_lr`. 
+ eps: epsilon kludge to help with numerical stability while calculating adaptive_lr + """ + + def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8): + self.optim = optimizer + self.trust_coefficient = trust_coefficient + self.eps = eps + self.clip = clip + + def __getstate__(self): + return self.optim.__getstate__() + + def __setstate__(self, state): + self.optim.__setstate__(state) + + @property + def state(self): + return self.optim.state + + def __repr__(self): + return self.optim.__repr__() + + @property + def param_groups(self): + return self.optim.param_groups + + @param_groups.setter + def param_groups(self, value): + self.optim.param_groups = value + + def state_dict(self): + return self.optim.state_dict() + + def load_state_dict(self, state_dict): + self.optim.load_state_dict(state_dict) + + def zero_grad(self): + self.optim.zero_grad() + + def add_param_group(self, param_group): + self.optim.add_param_group(param_group) + + def step(self): + with torch.no_grad(): + weight_decays = [] + for group in self.optim.param_groups: + # absorb weight decay control from optimizer + weight_decay = group['weight_decay'] if 'weight_decay' in group else 0 + weight_decays.append(weight_decay) + group['weight_decay'] = 0 + for p in group['params']: + is_bn_or_bias = (hasattr(p, 'is_bn') and p.is_bn) or (hasattr(p, 'is_bias') and p.is_bias) + if p.grad is None or is_bn_or_bias: + continue + param_norm = torch.norm(p.data) + grad_norm = torch.norm(p.grad.data) + + if param_norm != 0 and grad_norm != 0: + # calculate adaptive lr + weight decay + adaptive_lr = self.trust_coefficient * (param_norm) / ( + grad_norm + param_norm * weight_decay + self.eps) + + # clip learning rate for LARC + if self.clip: + # calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)` + adaptive_lr = min(adaptive_lr / group['lr'], 1) + + p.grad.data += weight_decay * p.data + p.grad.data *= adaptive_lr + + self.optim.step() + # return weight decay control to optimizer + for i, group in enumerate(self.optim.param_groups): + group['weight_decay'] = weight_decays[i] \ No newline at end of file diff --git a/codes/trainer/optimizers/sgd.py b/codes/trainer/optimizers/sgd.py new file mode 100644 index 00000000..f82bf33c --- /dev/null +++ b/codes/trainer/optimizers/sgd.py @@ -0,0 +1,72 @@ +import torch +from torch.optim import Optimizer + + +class SGDNoBiasMomentum(Optimizer): + r""" + Copy of pytorch implementation of SGD with a modification which turns off momentum for params marked + with `is_bn` or `is_bias`. + """ + + def __init__(self, params, lr, momentum=0, dampening=0, + weight_decay=0, nesterov=False): + if lr < 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if momentum < 0.0: + raise ValueError("Invalid momentum value: {}".format(momentum)) + if weight_decay < 0.0: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + + defaults = dict(lr=lr, momentum=momentum, dampening=dampening, + weight_decay=weight_decay, nesterov=nesterov) + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault('nesterov', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. 
+ + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + weight_decay = group['weight_decay'] + momentum = group['momentum'] + dampening = group['dampening'] + nesterov = group['nesterov'] + + for p in group['params']: + if p.grad is None: + continue + d_p = p.grad + if weight_decay != 0: + d_p = d_p.add(p, alpha=weight_decay) + # **this is the only modification over standard torch.optim.SGD: + is_bn_or_bias = (hasattr(p, 'is_bn') and p.is_bn) or (hasattr(p, 'is_bias') and p.is_bias) + if not is_bn_or_bias and momentum != 0: + param_state = self.state[p] + if 'momentum_buffer' not in param_state: + buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() + else: + buf = param_state['momentum_buffer'] + buf.mul_(momentum).add_(d_p, alpha=1 - dampening) + if nesterov: + d_p = d_p.add(buf, alpha=momentum) + else: + d_p = buf + + p.add_(d_p, alpha=-group['lr']) + + return loss diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py index d2d39f65..5f09038b 100644 --- a/codes/trainer/steps.py +++ b/codes/trainer/steps.py @@ -66,6 +66,16 @@ class ConfigurableStep(Module): for net_name, net, opt_config in zip(training, nets, opt_configs): optim_params = [] for k, v in net.named_parameters(): # can optimize for a part of the model + # Make some inference about these parameters, which can be used by some optimizers to treat certain + # parameters differently. For example, it is considered good practice to not do weight decay on + # BN & bias parameters. TODO: process the module tree instead of the parameter tree to accomplish the + # same thing, but in a more effective way. + if k.endswith(".bias"): + v.is_bias = True + if k.endswith(".weight"): + v.is_weight = True + if ".bn" in k or '.batchnorm' in k or '.bnorm' in k: + v.is_bn = True if v.requires_grad: optim_params.append(v) else: @@ -76,9 +86,12 @@ class ConfigurableStep(Module): opt = torch.optim.Adam(optim_params, lr=opt_config['lr'], weight_decay=opt_config['weight_decay'], betas=(opt_config['beta1'], opt_config['beta2'])) - elif self.step_opt['optimizer'] == 'novograd': - opt = NovoGrad(optim_params, lr=opt_config['lr'], weight_decay=opt_config['weight_decay'], - betas=(opt_config['beta1'], opt_config['beta2'])) + elif self.step_opt['optimizer'] == 'lars': + from trainer.optimizers.larc import LARC + from trainer.optimizers.sgd import SGDNoBiasMomentum + optSGD = SGDNoBiasMomentum(optim_params, lr=opt_config['lr'], momentum=opt_config['momentum'], + weight_decay=opt_config['weight_decay']) + opt = LARC(optSGD, trust_coefficient=opt_config['lars_coefficient']) opt._config = opt_config # This is a bit seedy, but we will need these configs later. opt._config['network'] = net_name self.optimizers.append(opt)
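
A minimal usage sketch of how the pieces introduced in this patch are assumed to fit together: the parameter tagging now done in `ConfigurableStep`, `SGDNoBiasMomentum` wrapped in `LARC` (the new `lars` optimizer option), the `hasattr(o, 'optim')` unwrap that lets schedulers see the underlying optimizer, the warmup-aware `CosineAnnealingLR_Restart`, and the checkpointed `resnet50`. All class and module names come from this diff; the hyperparameter values and the wiring itself are illustrative placeholders, and it assumes the repo's `utils.util.checkpoint` wrapper delegates to `torch.utils.checkpoint` so the forward pass behaves like a normal ResNet forward.

```
import torch

from models.resnet_with_checkpointing import resnet50
from trainer.optimizers.sgd import SGDNoBiasMomentum
from trainer.optimizers.larc import LARC
from trainer.lr_scheduler import CosineAnnealingLR_Restart

model = resnet50(pretrained=False)

# Tag parameters the same way ConfigurableStep now does, so LARC and
# SGDNoBiasMomentum can special-case BN and bias parameters.
for name, p in model.named_parameters():
    if name.endswith('.bias'):
        p.is_bias = True
    if name.endswith('.weight'):
        p.is_weight = True
    if '.bn' in name or '.batchnorm' in name or '.bnorm' in name:
        p.is_bn = True

# The 'lars' option in steps.py builds momentum-SGD (momentum skipped on
# BN/bias params) and wraps it in LARC for per-parameter adaptive LR.
base_opt = SGDNoBiasMomentum(model.parameters(), lr=0.2, momentum=0.9,
                             weight_decay=1e-6)
opt = LARC(base_opt, trust_coefficient=0.02)

# Schedulers need the wrapped torch optimizer, which is what the
# `hasattr(o, 'optim')` unwrap in get_scheduler_for_name() provides.
# The first `warmup` steps return base_lrs unchanged before cosine
# annealing begins. Values below are placeholders, not recommendations.
sched = CosineAnnealingLR_Restart(opt.optim, T_period=[80000, 80000, 80000],
                                  warmup=10000, eta_min=1e-4,
                                  restarts=[80000, 160000], weights=[.5, .25])

x = torch.randn(2, 3, 224, 224)
opt.zero_grad()
loss = model(x).mean()      # layer1-4 run under gradient checkpointing
loss.backward()
opt.step()
sched.step()
```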