"""
|
||
|
Helpers to train with 16-bit precision.
|
||
|
"""
|
||
|
|
||
|
import numpy as np
|
||
|
import torch as th
|
||
|
import torch.nn as nn
|
||
|
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||
|
|
||
|
INITIAL_LOG_LOSS_SCALE = 20.0
|
||
|
|
||
|
|
||
|
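
# Added note (interpretation, not from this file's own comments):
# INITIAL_LOG_LOSS_SCALE is the base-2 logarithm of the initial loss scale used
# for dynamic loss scaling, i.e. losses start out multiplied by 2 ** 20 (about
# 1.05e6) before back-propagation so that small fp16 gradients do not underflow.
# The training code that consumes it is expected to divide gradients back down
# by the same factor before the optimizer step.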

def convert_module_to_f16(l):
    """
    Convert primitive modules to float16.
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.half()
        if l.bias is not None:
            l.bias.data = l.bias.data.half()

def convert_module_to_f32(l):
    """
    Convert primitive modules to float32, undoing convert_module_to_f16().
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.float()
        if l.bias is not None:
            l.bias.data = l.bias.data.float()
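
# Illustrative usage sketch (added, not part of the original module): the two
# converters above operate on a single submodule, so they are meant to be
# applied recursively via nn.Module.apply; "model" here is an assumed name.
#
#     model.apply(convert_module_to_f16)  # cast conv weights/biases to fp16
#     ...  # run the fp16 forward/backward passes
#     model.apply(convert_module_to_f32)  # restore fp32 weights for eval/saving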

def make_master_params(param_groups_and_shapes):
    """
    Copy model parameters into a (differently-shaped) list of full-precision
    parameters.
    """
    master_params = []
    for param_group, shape in param_groups_and_shapes:
        master_param = nn.Parameter(
            _flatten_dense_tensors(
                [param.detach().float() for (_, param) in param_group]
            ).view(shape)
        )
        master_param.requires_grad = True
        master_params.append(master_param)
    return master_params
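
# Illustrative setup sketch (added; "model" and the optimizer choice are
# assumptions, not part of this file): the fp32 master copies are what the
# optimizer should own when the model itself runs in fp16.
#
#     param_groups_and_shapes = get_param_groups_and_shapes(model.named_parameters())
#     master_params = make_master_params(param_groups_and_shapes)
#     opt = th.optim.AdamW(master_params, lr=1e-4)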

def model_grads_to_master_grads(param_groups_and_shapes, master_params):
    """
    Copy the gradients from the model parameters into the master parameters
    from make_master_params().
    """
    for master_param, (param_group, shape) in zip(
        master_params, param_groups_and_shapes
    ):
        master_param.grad = _flatten_dense_tensors(
            [param_grad_or_zeros(param) for (_, param) in param_group]
        ).view(shape)

def master_params_to_model_params(param_groups_and_shapes, master_params):
    """
    Copy the master parameter data back into the model parameters.
    """
    # Without copying to a list, if a generator is passed, this will
    # silently not copy any parameters.
    for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
        for (_, param), unflat_master_param in zip(
            param_group, unflatten_master_params(param_group, master_param.view(-1))
        ):
            param.detach().copy_(unflat_master_param)
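
# Illustrative training-step sketch (added; the loss-scaling details and the
# names "model", "loss", "opt" and "lg_loss_scale" are assumptions, not part of
# this file). It shows how the grad-copy and param-copy helpers fit together:
#
#     loss = compute_loss(model)                   # fp16 forward pass
#     (loss * 2 ** lg_loss_scale).backward()       # scale up to avoid fp16 underflow
#     model_grads_to_master_grads(param_groups_and_shapes, master_params)
#     for p in master_params:
#         p.grad.mul_(1.0 / 2 ** lg_loss_scale)    # unscale in fp32
#     opt.step()                                   # update the fp32 master weights
#     master_params_to_model_params(param_groups_and_shapes, master_params)
#     zero_master_grads(master_params)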

def unflatten_master_params(param_group, master_param):
    """
    Split a flat master parameter back into tensors shaped like the params in
    param_group.
    """
    return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group])

def get_param_groups_and_shapes(named_model_params):
    """
    Split named parameters into two groups, each paired with the shape of its
    flattened master copy: scalars/vectors (ndim <= 1) are flattened to shape
    (-1), matrices (ndim > 1) to shape (1, -1).
    """
    named_model_params = list(named_model_params)
    scalar_vector_named_params = (
        [(n, p) for (n, p) in named_model_params if p.ndim <= 1],
        (-1),
    )
    matrix_named_params = (
        [(n, p) for (n, p) in named_model_params if p.ndim > 1],
        (1, -1),
    )
    return [scalar_vector_named_params, matrix_named_params]

def master_params_to_state_dict(
    model, param_groups_and_shapes, master_params, use_fp16
):
    """
    Build a state dict for model in which each parameter entry is replaced by
    its full-precision master copy, suitable for checkpointing.
    """
    if use_fp16:
        state_dict = model.state_dict()
        for master_param, (param_group, _) in zip(
            master_params, param_groups_and_shapes
        ):
            for (name, _), unflat_master_param in zip(
                param_group, unflatten_master_params(param_group, master_param.view(-1))
            ):
                assert name in state_dict
                state_dict[name] = unflat_master_param
    else:
        state_dict = model.state_dict()
        for i, (name, _value) in enumerate(model.named_parameters()):
            assert name in state_dict
            state_dict[name] = master_params[i]
    return state_dict
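
# Illustrative checkpointing sketch (added; the file name is hypothetical):
#
#     state_dict = master_params_to_state_dict(
#         model, param_groups_and_shapes, master_params, use_fp16=True
#     )
#     th.save(state_dict, "model_checkpoint.pt")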

def state_dict_to_master_params(model, state_dict, use_fp16):
    """
    Inverse of master_params_to_state_dict(): rebuild the master parameter
    list from a checkpointed state dict.
    """
    if use_fp16:
        named_model_params = [
            (name, state_dict[name]) for name, _ in model.named_parameters()
        ]
        param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
        master_params = make_master_params(param_groups_and_shapes)
    else:
        master_params = [state_dict[name] for name, _ in model.named_parameters()]
    return master_params
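
# Illustrative resume sketch (added; the file name is hypothetical):
#
#     state_dict = th.load("model_checkpoint.pt")
#     master_params = state_dict_to_master_params(model, state_dict, use_fp16=True)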

def zero_master_grads(master_params):
    """
    Drop the gradients on the master parameters by setting them to None.
    """
    for param in master_params:
        param.grad = None

def zero_grad(model_params):
    for param in model_params:
        # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
        if param.grad is not None:
            param.grad.detach_()
|