# "borrowed" with love from https://github.com/MadsToftrup/Apollo-dev/blob/main/galore_torch/apollo.py
# to be replaced with the official implementation (https://github.com/zhuhanqing/APOLLO) maybe
import torch
import math
import numpy as np
from torch import nn
from torch.optim import Optimizer
from typing import Any, Callable, Dict, Generator, Iterable, Optional, Sequence, Union, Tuple
from transformers.utils.versions import require_version
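

# SVD-based low-rank projector, as used by GaLore: the orthogonal basis is
# recomputed from the current gradient via torch.linalg.svd every
# `update_proj_gap` steps, and gradients are projected onto it in between.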
class GaLoreProjector:
    def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0, proj_type='std'):
        self.rank = rank
        self.verbose = verbose
        self.update_proj_gap = update_proj_gap
        self.scale = scale
        self.ortho_matrix = None
        self.proj_type = proj_type
        self.svd_count = 0

    def project(self, full_rank_grad, iter):
        # Move the cached projection to the gradient's device if needed; for
        # proj_type 'full' the cache is a [left, right] pair of matrices.
        if self.ortho_matrix is not None:
            if isinstance(self.ortho_matrix, list):
                self.ortho_matrix = [m.to(full_rank_grad.device) for m in self.ortho_matrix]
            elif self.ortho_matrix.device != full_rank_grad.device:
                self.ortho_matrix = self.ortho_matrix.to(full_rank_grad.device)
        if self.proj_type == 'std':
            if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
                if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                    self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
                    self.svd_count += 1
                low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
            else:
                if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                    self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
                    self.svd_count += 1
                low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
        elif self.proj_type == 'reverse_std':
            if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
                if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                    self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
                    self.svd_count += 1
                low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
            else:
                if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                    self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
                    self.svd_count += 1
                low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
        elif self.proj_type == 'right':
            if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
                self.svd_count += 1
            low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
        elif self.proj_type == 'left':
            if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
                self.svd_count += 1
            low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
        elif self.proj_type == 'full':
            if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='full')
                self.svd_count += 1
            low_rank_grad = torch.matmul(self.ortho_matrix[0].t(), full_rank_grad) @ self.ortho_matrix[1].t()

        return low_rank_grad
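
    # project_back maps a low-rank update back to the full parameter shape and
    # applies the GaLore scale. It is not called by Apollo.step below, which only
    # rescales the raw gradient; it appears to be kept for parity with GaLore.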
    def project_back(self, low_rank_grad):
        if self.proj_type == 'std':
            if low_rank_grad.shape[0] >= low_rank_grad.shape[1]:
                full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
            else:
                full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
        elif self.proj_type == 'reverse_std':
            if low_rank_grad.shape[0] <= low_rank_grad.shape[1]:  # note this is different from std
                full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
            else:
                full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
        elif self.proj_type == 'right':
            full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
        elif self.proj_type == 'left':
            full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
        elif self.proj_type == 'full':
            full_rank_grad = torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1]

        return full_rank_grad * self.scale

    # svd decomposition
    def get_orthogonal_matrix(self, weights, rank, type):
        module_params = weights
        if module_params.data.dtype != torch.float:
            float_data = False
            original_type = module_params.data.dtype
            original_device = module_params.data.device
            matrix = module_params.data.float()
        else:
            float_data = True
            matrix = module_params.data

        U, s, Vh = torch.linalg.svd(matrix, full_matrices=False)

        # make the smaller factor the orthogonal matrix; the factor carrying the
        # singular values is never used, so it is not computed
        if type == 'right':
            B = Vh[:rank, :]
            if not float_data:
                B = B.to(original_device).type(original_type)
            return B
        elif type == 'left':
            A = U[:, :rank]
            if not float_data:
                A = A.to(original_device).type(original_type)
            return A
        elif type == 'full':
            A = U[:, :rank]
            B = Vh[:rank, :]
            if not float_data:
                A = A.to(original_device).type(original_type)
                B = B.to(original_device).type(original_type)
            return [A, B]
        else:
            raise ValueError('type should be left, right or full')
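

# Seeded RNG helpers: these make the random projections below reproducible, so a
# projection matrix can be re-materialized from a stored integer seed instead of
# being kept in a checkpoint.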
def stable_randn(
    shape: Union[int, Sequence[int]],
    seed: int,
    device: Optional[Union[str, torch.device]] = None,
    dtype: Optional[torch.dtype] = torch.float32,
):
    if device is None:
        device = torch.device("cpu")
    generator = torch.Generator(device=device).manual_seed(seed)
    rn = torch.randn(shape, generator=generator, device=generator.device, dtype=dtype)
    return rn


def next_seed(seed: int, adv: int = 0xF):
    """
    This is a naive helper function to generate a new seed from the given seed.
    """
    generator = torch.Generator().manual_seed(seed)
    return torch.randint(
        0, torch.iinfo(torch.int64).max, (adv,), generator=generator, device=generator.device
    ).tolist()[-1]


def split_seed(seed: int):
    generator = torch.Generator().manual_seed(seed)
    return tuple(
        torch.randint(0, torch.iinfo(torch.int64).max, (2,), generator=generator, device=generator.device).tolist()
    )
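

# Random-projection counterpart of GaLoreProjector: instead of an SVD, the basis
# is a Gaussian sketch drawn via `stable_randn`, and the seed is advanced on
# every refresh so successive projections are independent.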
class GradientProjector:
    def __init__(self, rank, update_proj_gap=200, alpha=1.0, proj_type="std", seed=0):
        # This is a lazy implementation: we store the projection matrix instead of re-generating it every iteration
        self.rank = rank
        self.update_proj_gap = update_proj_gap
        self.alpha = alpha
        self.proj_type = proj_type
        self.ortho_matrix = None
        self.seed = seed

    def project(self, full_rank_grad, iter):
        if self.proj_type == "std":
            if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
                if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                    self.ortho_matrix = self.get_orthogonal_matrix(
                        full_rank_grad, self.rank, type="right", seed=self.seed
                    )
                    self.seed = next_seed(self.seed)
                low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
            else:
                if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                    self.ortho_matrix = self.get_orthogonal_matrix(
                        full_rank_grad, self.rank, type="left", seed=self.seed
                    )
                    self.seed = next_seed(self.seed)
                low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
        elif self.proj_type == "reverse_std":
            if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
                if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                    self.ortho_matrix = self.get_orthogonal_matrix(
                        full_rank_grad, self.rank, type="left", seed=self.seed
                    )
                    self.seed = next_seed(self.seed)
                low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
            else:
                if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                    self.ortho_matrix = self.get_orthogonal_matrix(
                        full_rank_grad, self.rank, type="right", seed=self.seed
                    )
                    self.seed = next_seed(self.seed)
                low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
        elif self.proj_type == "right":
            if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                self.ortho_matrix = self.get_orthogonal_matrix(
                    full_rank_grad, self.rank, type="right", seed=self.seed
                )
                self.seed = next_seed(self.seed)
            low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
        elif self.proj_type == "left":
            if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                self.ortho_matrix = self.get_orthogonal_matrix(
                    full_rank_grad, self.rank, type="left", seed=self.seed
                )
                self.seed = next_seed(self.seed)
            low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
        elif self.proj_type == "full":
            if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                self.ortho_matrix = self.get_orthogonal_matrix(
                    full_rank_grad, self.rank, type="full", seed=self.seed
                )
                self.seed = next_seed(self.seed)
            low_rank_grad = (
                torch.matmul(self.ortho_matrix[0].t(), full_rank_grad)
                @ self.ortho_matrix[1].t()
            )

        return low_rank_grad

    # random low rank projection
    def get_orthogonal_matrix(self, weights, rank, type, seed):
        module_params = weights
        if module_params.data.dtype != torch.float:
            float_data = False
            original_type = module_params.data.dtype
            original_device = module_params.data.device
            matrix = module_params.data.float()
        else:
            float_data = True
            matrix = module_params.data

        if type == "left":
            proj = stable_randn(
                (matrix.shape[0], rank), seed=seed, device=matrix.device, dtype=matrix.dtype
            ) / math.sqrt(rank)
            if not float_data:
                proj = proj.to(original_device).type(original_type)
            return proj
        elif type == "right":
            proj = stable_randn(
                (rank, matrix.shape[1]), seed=seed, device=matrix.device, dtype=matrix.dtype
            ) / math.sqrt(rank)
            if not float_data:
                proj = proj.to(original_device).type(original_type)
            return proj
        elif type == "full":
            raise NotImplementedError("full rank projection is not implemented yet")
        else:
            raise ValueError("type should be left, right or full")


class Apollo(Optimizer):
    """
    Implements the APOLLO optimizer: an AdamW-style update (decoupled weight decay, as in
    [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101)) where, for
    parameter groups that define a `rank`, the Adam statistics are tracked in a low-rank
    projection of the gradient and used only to rescale the raw full-rank gradient.

    Parameters:
        params (`Iterable[nn.parameter.Parameter]`):
            Iterable of parameters to optimize or dictionaries defining parameter groups.
            Groups may additionally define `rank`, `proj` ("random" or "svd"), `proj_type`,
            `update_proj_gap`, `scale` and `scale_type` to enable the low-rank path.
        lr (`float`, *optional*, defaults to 0.001):
            The learning rate to use.
        betas (`Tuple[float, float]`, *optional*, defaults to `(0.9, 0.999)`):
            Adam's betas parameters (b1, b2).
        eps (`float`, *optional*, defaults to 1e-06):
            Adam's epsilon for numerical stability.
        weight_decay (`float`, *optional*, defaults to 0.0):
            Decoupled weight decay to apply.
        correct_bias (`bool`, *optional*, defaults to `True`):
            Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
        scale_front (`bool`, *optional*, defaults to `False`):
            Stored on the optimizer; not used by `step` in this implementation.
    """

    def __init__(
        self,
        params: Iterable[nn.parameter.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        correct_bias: bool = True,
        scale_front: bool = False,
    ):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
        defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias}
        super().__init__(params, defaults)

        self.scale_front = scale_front

        # give each trainable parameter a deterministic seed for the random projector
        params_idx = 0
        for group in self.param_groups:
            for p in group["params"]:
                params_idx += 1
                if p.requires_grad:
                    self.state[p]["seed"] = params_idx

    @torch.no_grad()
    def step(self, closure: Callable = None):
        """
        Performs a single optimization step.

        Arguments:
            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        params_idx = 0
        for group in self.param_groups:
            for p in group["params"]:
                params_idx += 1
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

                state = self.state[p]

                if "step" not in state:
                    state["step"] = 0
                if "seed" not in state:
                    state["seed"] = params_idx

                # GaLore projection: project the gradient into the low-rank space
                if "rank" in group:
                    if "projector" not in state:
                        if group["proj"] == "random":
                            state["projector"] = GradientProjector(
                                group["rank"],
                                update_proj_gap=group["update_proj_gap"],
                                alpha=group["scale"],
                                proj_type=group["proj_type"],
                                seed=state["seed"],
                            )
                        elif group["proj"] == "svd":
                            state["projector"] = GaLoreProjector(
                                group["rank"],
                                update_proj_gap=group["update_proj_gap"],
                                scale=group["scale"],
                                proj_type=group["proj_type"],
                            )
                    grad = state["projector"].project(grad, state["step"])

                # State initialization
                if "exp_avg" not in state:
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(grad)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(grad)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]
                state["step"] += 1

                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                denom = exp_avg_sq.sqrt().add_(group["eps"])

                step_size = group["lr"]
                if group["correct_bias"]:  # No bias correction for Bert
                    bias_correction1 = 1.0 - beta1 ** state["step"]
                    bias_correction2 = 1.0 - beta2 ** state["step"]
                    step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

                # compute norm gradient
                norm_grad = exp_avg / denom
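
                # APOLLO's core step: rather than projecting the Adam update back to
                # full rank, measure how much Adam rescales the projected gradient
                # (per channel or per tensor) and apply that scaling factor directly
                # to the raw full-rank gradient p.grad.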
if "rank" in group:
if group['scale_type'] == 'channel':
norm_dim = 0 if norm_grad.shape[0] < norm_grad.shape[1] else 1
scaling_factor = (
torch.norm(norm_grad, dim=norm_dim) /
(torch.norm(grad, dim=norm_dim) + 1e-8)
)
if norm_dim == 1:
scaling_factor = scaling_factor.unsqueeze(1)
elif group['scale_type'] == 'tensor':
scaling_factor = (
torch.norm(norm_grad) /
(torch.norm(grad) + 1e-8)
)
scaling_grad = p.grad * scaling_factor
# Use Norm-Growth Limiter in Fira
if "scaling_grad" in state:
scaling_grad_norm = torch.norm(scaling_grad)
limiter = max(
scaling_grad_norm /
(state["scaling_grad"] + 1e-8),
1.01,
) / 1.01
scaling_grad = scaling_grad / limiter
state["scaling_grad"] = scaling_grad_norm / limiter
else:
state["scaling_grad"] = torch.norm(scaling_grad)
norm_grad = scaling_grad * np.sqrt(group["scale"])
p.add_(norm_grad, alpha=-step_size)

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                # Add weight decay at the end (fixed version)
                if group["weight_decay"] > 0.0:
                    p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))

        return loss
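

# A minimal usage sketch (not part of the upstream file): the extra per-group keys
# below ("rank", "proj", "scale_type", ...) are exactly the ones `step` reads above.
# Shapes are illustrative; the low-rank path assumes 2D weight matrices, which is
# why the example layer has no 1D bias parameter.
if __name__ == "__main__":
    layer = nn.Linear(128, 64, bias=False)
    optimizer = Apollo(
        [{
            "params": layer.parameters(),
            "rank": 16,              # low-rank dimension of the projected gradient
            "proj": "random",        # "random" -> GradientProjector, "svd" -> GaLoreProjector
            "proj_type": "std",
            "update_proj_gap": 200,  # refresh the projection every N steps
            "scale": 1.0,
            "scale_type": "channel", # or "tensor"
        }],
        lr=1e-3,
    )
    layer(torch.randn(4, 128)).sum().backward()
    optimizer.step()
    optimizer.zero_grad()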