APOLLO tweaks to make it work with deepspeed

This commit is contained in:
mrq 2024-12-13 23:03:52 -06:00
parent 64c67160a3
commit 09804ecc16
4 changed files with 67 additions and 92 deletions

View File

@ -520,7 +520,7 @@ class DeepSpeed:
use_compression_training: bool = False # cope
compression_bits: int = 8 # cope
inferencing: bool = False # for using DeepSpeed's inferencing wrapper instead
optimizer: bool = True # use DeepSpeed optimizer wrapper
amp: bool = False # use DeepSpeed's AMP (requires some other package installed apparently)
loss_scale_window: int = 100

View File

@ -132,41 +132,17 @@ def load_engines(training=True, **model_kwargs):
params['d_coef'] = params['lr']
params['lr'] = 1.0
elif cfg.hyperparameters.optimizer.lower() in ["apollo","apollo-mini"]:
"""
if backend == "deepspeed":
raise Exception("APOLLO currently does not play nicely with DeepSpeed.")
"""
optimizer_class = ml.Apollo
is_mini = cfg.hyperparameters.optimizer.lower() == "apollo-mini"
param_kwargs = {
params.update({
"rank": 1 if is_mini else 256,
"proj": "random",
"scale_type": "tensor" if is_mini else "channel",
"scale": 128 if is_mini else 1,
"update_proj_gap": 200,
"proj_type": "std",
}
# grab any extra configs from the YAML
param_kwargs.update(cfg.hyperparameters.optimizer_params)
# and blank it so it doesn't update the main optimizer kwargs
cfg.hyperparameters.optimizer_params = {}
"""
params["params"] = [{'params': params["params"]} | param_kwargs]
"""
target_params = []
target_modules_list = ["attn", "mlp"]
for module_name, module in model.named_modules():
if not (isinstance(module, torch.nn.Linear)):
continue
if not any(target_key in module_name for target_key in target_modules_list):
continue
target_params.append(module.weight)
param_ids = [id(p) for p in target_params]
regular_params = [p for p in model.parameters() if id(p) not in param_ids]
params["params"] = [{'params': regular_params}, {'params': target_params} | param_kwargs]
})
elif cfg.hyperparameters.optimizer.lower() == "adagrad":
optimizer_class = ml.Adagrad
else:

View File

@ -406,10 +406,14 @@ class Engines(dict[str, Engine]):
if cfg.lora is not None:
save_dir = cfg.ckpt_dir / cfg.lora.full_name
engine.save_checkpoint(save_dir, tag=tag)
"""
try:
engine.save_checkpoint(save_dir, tag=tag)
except Exception as e:
_logger.warning(f'Failed to save checkpoint for engine {name}: {str(e)}')
"""
# might be better to prune before saving for safety, but [:0] returns an empty list, but I could do [:-cfg.trainer.keep_last_checkpoints - 1 if cfg.trainer.keep_last_checkpoints > 1 else None]
if cfg.trainer.keep_last_checkpoints > 0 and is_global_leader():
@ -515,11 +519,11 @@ class Engines(dict[str, Engine]):
start_time = time.time()
batch = to_device(batch, device)
n_ooms = torch.zeros([], device=device)
if not cfg.trainer.check_for_oom:
res = feeder( engine=engine, batch=batch, teacher=teacher )
else:
forward_ooms = torch.zeros([], device=device)
try:
res = feeder( engine=engine, batch=batch, teacher=teacher )
except RuntimeError as e:
@ -529,12 +533,12 @@ class Engines(dict[str, Engine]):
self.save_checkpoint()
raise e
n_ooms += 1
forward_ooms += 1
if world_size() > 1:
all_reduce(n_ooms)
all_reduce(forward_ooms)
if n_ooms.item() > 0:
if forward_ooms.item() > 0:
continue
"""
self.save_checkpoint()
@ -554,7 +558,7 @@ class Engines(dict[str, Engine]):
if not cfg.trainer.check_for_oom:
engine.backward(loss)
else:
# to-do: properly handle when one GPU throws an OOM because it just halts despite doing a gather/reduce
backward_ooms = torch.zeros([], device=device)
try:
engine.backward(loss)
except RuntimeError as e:
@ -564,12 +568,12 @@ class Engines(dict[str, Engine]):
self.save_checkpoint()
raise e
n_ooms += 1
backward_ooms += 1
if world_size() > 1:
all_reduce(n_ooms)
all_reduce(backward_ooms)
if n_ooms.item() > 0:
if backward_ooms.item() > 0:
self.save_checkpoint()
raise RuntimeError("Out of memory during backwards pass!")

View File

@ -28,7 +28,7 @@ class GaLoreProjector:
self.ortho_matrix = self.ortho_matrix.to(full_rank_grad.device)
if self.proj_type == 'std':
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
if full_rank_grad.dim() > 1 and full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
self.svd_count += 1
@ -39,7 +39,7 @@ class GaLoreProjector:
self.svd_count += 1
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
elif self.proj_type == 'reverse_std':
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
if full_rank_grad.dim() > 1 and full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
self.svd_count += 1
@ -70,12 +70,12 @@ class GaLoreProjector:
def project_back(self, low_rank_grad):
if self.proj_type == 'std':
if low_rank_grad.shape[0] >= low_rank_grad.shape[1]:
if low_rank_grad.dim() > 1 and low_rank_grad.shape[0] >= low_rank_grad.shape[1]:
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
else:
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
elif self.proj_type == 'reverse_std':
if low_rank_grad.shape[0] <= low_rank_grad.shape[1]: # note this is different from std
if low_rank_grad.dim() > 1 and low_rank_grad.shape[0] <= low_rank_grad.shape[1]: # note this is different from std
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
else:
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
@ -170,9 +170,8 @@ class GradientProjector:
self.seed = seed
def project(self, full_rank_grad, iter):
if self.proj_type == "std":
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
if full_rank_grad.dim() > 1 and full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(
full_rank_grad, self.rank, type="right", seed=self.seed
@ -188,7 +187,7 @@ class GradientProjector:
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
elif self.proj_type == "reverse_std":
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
if full_rank_grad.dim() > 1 and full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(
full_rank_grad, self.rank, type="left", seed=self.seed
@ -263,27 +262,6 @@ class GradientProjector:
raise ValueError("type should be left, right or full")
class Apollo(Optimizer):
"""
Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
Regularization](https://arxiv.org/abs/1711.05101).
Parameters:
params (`Iterable[nn.parameter.Parameter]`):
Iterable of parameters to optimize or dictionaries defining parameter groups.
lr (`float`, *optional*, defaults to 0.001):
The learning rate to use.
betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`):
Adam's betas parameters (b1, b2).
eps (`float`, *optional*, defaults to 1e-06):
Adam's epsilon for numerical stability.
weight_decay (`float`, *optional*, defaults to 0.0):
Decoupled weight decay to apply.
correct_bias (`bool`, *optional*, defaults to `True`):
Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
no_deprecation_warning (`bool`, *optional*, defaults to `False`):
A flag used to disable the deprecation warning (set to `True` to disable the warning).
"""
def __init__(
self,
params: Iterable[nn.parameter.Parameter],
@ -292,7 +270,13 @@ class Apollo(Optimizer):
eps: float = 1e-6,
weight_decay: float = 0.0,
correct_bias: bool = True,
scale_front: bool = False,
rank: int = 256,
proj: str = "random",
scale_type: str = "channel",
scale: int = 1,
update_proj_gap: int = 200,
proj_type: str = "std",
):
if lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
@ -302,17 +286,30 @@ class Apollo(Optimizer):
raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias}
defaults = {
"lr": lr,
"betas": betas,
"eps": eps,
"weight_decay": weight_decay,
"correct_bias": correct_bias,
"rank": rank,
"proj": proj,
"scale_type": scale_type,
"scale": scale,
"update_proj_gap": update_proj_gap,
"proj_type": proj_type,
}
super().__init__(params, defaults)
self.scale_front = scale_front
"""
params_idx = 0
for group in self.param_groups:
for p in group["params"]:
params_idx += 1
if p.requires_grad:
self.state[p]["seed"] = params_idx
"""
@torch.no_grad()
def step(self, closure: Callable = None):
@ -332,9 +329,9 @@ class Apollo(Optimizer):
params_idx += 1
if p.grad is None:
continue
grad = p.grad
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
raise RuntimeError("APOLLO does not support sparse gradients")
state = self.state[p]
@ -345,29 +342,32 @@ class Apollo(Optimizer):
state["seed"] = params_idx
# GaLore Projection
if "rank" in group:
if group["rank"] > 0:
if "projector" not in state:
if group["proj"] == "random":
state["projector"] = GradientProjector(group["rank"],
state["projector"] = GradientProjector(
group["rank"],
update_proj_gap=group["update_proj_gap"],
alpha=group["scale"],
proj_type=group["proj_type"],
seed=state["seed"])
seed=state["seed"]
)
elif group["proj"] == "svd":
state["projector"] = GaLoreProjector(group["rank"],
state["projector"] = GaLoreProjector(
group["rank"],
update_proj_gap=group["update_proj_gap"],
scale=group["scale"],
proj_type=group["proj_type"])
proj_type=group["proj_type"]
)
grad = state["projector"].project(grad, state["step"])
# State initialization
if "exp_avg" not in state:
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(grad)
state["exp_avg"] = torch.zeros_like(grad).detach()
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(grad)
state["exp_avg_sq"] = torch.zeros_like(grad).detach()
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]
@ -389,9 +389,12 @@ class Apollo(Optimizer):
# compute norm gradient
norm_grad = exp_avg / denom
if "rank" in group:
if group["rank"] > 0:
if group['scale_type'] == 'channel':
if norm_grad.dim() > 1:
norm_dim = 0 if norm_grad.shape[0] < norm_grad.shape[1] else 1
else:
norm_dim = 0
scaling_factor = (
torch.norm(norm_grad, dim=norm_dim) /
(torch.norm(grad, dim=norm_dim) + 1e-8)
@ -405,7 +408,7 @@ class Apollo(Optimizer):
(torch.norm(grad) + 1e-8)
)
scaling_grad = p.grad * scaling_factor
scaling_grad = p.grad.data * scaling_factor
# Use Norm-Growth Limiter in Fira
if "scaling_grad" in state:
@ -422,17 +425,9 @@ class Apollo(Optimizer):
norm_grad = scaling_grad * np.sqrt(group["scale"])
p.add_(norm_grad, alpha=-step_size)
p.data.add_(norm_grad, alpha=-step_size)
# Just adding the square of the weights to the loss function is *not*
# the correct way of using L2 regularization/weight decay with Adam,
# since that will interact with the m and v parameters in strange ways.
#
# Instead we want to decay the weights in a manner that doesn't interact
# with the m/v parameters. This is equivalent to adding the square
# of the weights to the loss with plain (non-momentum) SGD.
# Add weight decay at the end (fixed version)
if group["weight_decay"] > 0.0:
p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
p.data.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
return loss