APOLLO tweaks to make it work with deepspeed

This commit is contained in:
mrq 2024-12-13 23:03:52 -06:00
parent 64c67160a3
commit 09804ecc16
4 changed files with 67 additions and 92 deletions

View File

@ -520,7 +520,7 @@ class DeepSpeed:
use_compression_training: bool = False # cope use_compression_training: bool = False # cope
compression_bits: int = 8 # cope compression_bits: int = 8 # cope
inferencing: bool = False # for using DeepSpeed's inferencing wrapper instead inferencing: bool = False # for using DeepSpeed's inferencing wrapper instead
optimizer: bool = True # use DeepSpeed optimizer wrapper
amp: bool = False # use DeepSpeed's AMP (requires some other package installed apparently) amp: bool = False # use DeepSpeed's AMP (requires some other package installed apparently)
loss_scale_window: int = 100 loss_scale_window: int = 100

View File

@ -132,41 +132,17 @@ def load_engines(training=True, **model_kwargs):
params['d_coef'] = params['lr'] params['d_coef'] = params['lr']
params['lr'] = 1.0 params['lr'] = 1.0
elif cfg.hyperparameters.optimizer.lower() in ["apollo","apollo-mini"]: elif cfg.hyperparameters.optimizer.lower() in ["apollo","apollo-mini"]:
"""
if backend == "deepspeed":
raise Exception("APOLLO currently does not play nicely with DeepSpeed.")
"""
optimizer_class = ml.Apollo optimizer_class = ml.Apollo
is_mini = cfg.hyperparameters.optimizer.lower() == "apollo-mini" is_mini = cfg.hyperparameters.optimizer.lower() == "apollo-mini"
param_kwargs = {
params.update({
"rank": 1 if is_mini else 256, "rank": 1 if is_mini else 256,
"proj": "random", "proj": "random",
"scale_type": "tensor" if is_mini else "channel", "scale_type": "tensor" if is_mini else "channel",
"scale": 128 if is_mini else 1, "scale": 128 if is_mini else 1,
"update_proj_gap": 200, "update_proj_gap": 200,
"proj_type": "std", "proj_type": "std",
} })
# grab any extra configs from the YAML
param_kwargs.update(cfg.hyperparameters.optimizer_params)
# and blank it so it doesn't update the main optimizer kwargs
cfg.hyperparameters.optimizer_params = {}
"""
params["params"] = [{'params': params["params"]} | param_kwargs]
"""
target_params = []
target_modules_list = ["attn", "mlp"]
for module_name, module in model.named_modules():
if not (isinstance(module, torch.nn.Linear)):
continue
if not any(target_key in module_name for target_key in target_modules_list):
continue
target_params.append(module.weight)
param_ids = [id(p) for p in target_params]
regular_params = [p for p in model.parameters() if id(p) not in param_ids]
params["params"] = [{'params': regular_params}, {'params': target_params} | param_kwargs]
elif cfg.hyperparameters.optimizer.lower() == "adagrad": elif cfg.hyperparameters.optimizer.lower() == "adagrad":
optimizer_class = ml.Adagrad optimizer_class = ml.Adagrad
else: else:

View File

@ -406,10 +406,14 @@ class Engines(dict[str, Engine]):
if cfg.lora is not None: if cfg.lora is not None:
save_dir = cfg.ckpt_dir / cfg.lora.full_name save_dir = cfg.ckpt_dir / cfg.lora.full_name
engine.save_checkpoint(save_dir, tag=tag)
"""
try: try:
engine.save_checkpoint(save_dir, tag=tag) engine.save_checkpoint(save_dir, tag=tag)
except Exception as e: except Exception as e:
_logger.warning(f'Failed to save checkpoint for engine {name}: {str(e)}') _logger.warning(f'Failed to save checkpoint for engine {name}: {str(e)}')
"""
# might be better to prune before saving for safety, but [:0] returns an empty list, but I could do [:-cfg.trainer.keep_last_checkpoints - 1 if cfg.trainer.keep_last_checkpoints > 1 else None] # might be better to prune before saving for safety, but [:0] returns an empty list, but I could do [:-cfg.trainer.keep_last_checkpoints - 1 if cfg.trainer.keep_last_checkpoints > 1 else None]
if cfg.trainer.keep_last_checkpoints > 0 and is_global_leader(): if cfg.trainer.keep_last_checkpoints > 0 and is_global_leader():
@ -515,11 +519,11 @@ class Engines(dict[str, Engine]):
start_time = time.time() start_time = time.time()
batch = to_device(batch, device) batch = to_device(batch, device)
n_ooms = torch.zeros([], device=device)
if not cfg.trainer.check_for_oom: if not cfg.trainer.check_for_oom:
res = feeder( engine=engine, batch=batch, teacher=teacher ) res = feeder( engine=engine, batch=batch, teacher=teacher )
else: else:
forward_ooms = torch.zeros([], device=device)
try: try:
res = feeder( engine=engine, batch=batch, teacher=teacher ) res = feeder( engine=engine, batch=batch, teacher=teacher )
except RuntimeError as e: except RuntimeError as e:
@ -529,12 +533,12 @@ class Engines(dict[str, Engine]):
self.save_checkpoint() self.save_checkpoint()
raise e raise e
n_ooms += 1 forward_ooms += 1
if world_size() > 1: if world_size() > 1:
all_reduce(n_ooms) all_reduce(forward_ooms)
if n_ooms.item() > 0: if forward_ooms.item() > 0:
continue continue
""" """
self.save_checkpoint() self.save_checkpoint()
@ -554,7 +558,7 @@ class Engines(dict[str, Engine]):
if not cfg.trainer.check_for_oom: if not cfg.trainer.check_for_oom:
engine.backward(loss) engine.backward(loss)
else: else:
# to-do: properly handle when one GPU throws an OOM because it just halts despite doing a gather/reduce backward_ooms = torch.zeros([], device=device)
try: try:
engine.backward(loss) engine.backward(loss)
except RuntimeError as e: except RuntimeError as e:
@ -564,12 +568,12 @@ class Engines(dict[str, Engine]):
self.save_checkpoint() self.save_checkpoint()
raise e raise e
n_ooms += 1 backward_ooms += 1
if world_size() > 1: if world_size() > 1:
all_reduce(n_ooms) all_reduce(backward_ooms)
if n_ooms.item() > 0: if backward_ooms.item() > 0:
self.save_checkpoint() self.save_checkpoint()
raise RuntimeError("Out of memory during backwards pass!") raise RuntimeError("Out of memory during backwards pass!")

View File

@ -28,7 +28,7 @@ class GaLoreProjector:
self.ortho_matrix = self.ortho_matrix.to(full_rank_grad.device) self.ortho_matrix = self.ortho_matrix.to(full_rank_grad.device)
if self.proj_type == 'std': if self.proj_type == 'std':
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]: if full_rank_grad.dim() > 1 and full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0: if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right') self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
self.svd_count += 1 self.svd_count += 1
@ -39,7 +39,7 @@ class GaLoreProjector:
self.svd_count += 1 self.svd_count += 1
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad) low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
elif self.proj_type == 'reverse_std': elif self.proj_type == 'reverse_std':
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]: if full_rank_grad.dim() > 1 and full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0: if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left') self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
self.svd_count += 1 self.svd_count += 1
@ -70,12 +70,12 @@ class GaLoreProjector:
def project_back(self, low_rank_grad): def project_back(self, low_rank_grad):
if self.proj_type == 'std': if self.proj_type == 'std':
if low_rank_grad.shape[0] >= low_rank_grad.shape[1]: if low_rank_grad.dim() > 1 and low_rank_grad.shape[0] >= low_rank_grad.shape[1]:
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix) full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
else: else:
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad) full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
elif self.proj_type == 'reverse_std': elif self.proj_type == 'reverse_std':
if low_rank_grad.shape[0] <= low_rank_grad.shape[1]: # note this is different from std if low_rank_grad.dim() > 1 and low_rank_grad.shape[0] <= low_rank_grad.shape[1]: # note this is different from std
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad) full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
else: else:
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix) full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
@ -170,9 +170,8 @@ class GradientProjector:
self.seed = seed self.seed = seed
def project(self, full_rank_grad, iter): def project(self, full_rank_grad, iter):
if self.proj_type == "std": if self.proj_type == "std":
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]: if full_rank_grad.dim() > 1 and full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0: if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix( self.ortho_matrix = self.get_orthogonal_matrix(
full_rank_grad, self.rank, type="right", seed=self.seed full_rank_grad, self.rank, type="right", seed=self.seed
@ -188,7 +187,7 @@ class GradientProjector:
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad) low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
elif self.proj_type == "reverse_std": elif self.proj_type == "reverse_std":
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]: if full_rank_grad.dim() > 1 and full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0: if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix( self.ortho_matrix = self.get_orthogonal_matrix(
full_rank_grad, self.rank, type="left", seed=self.seed full_rank_grad, self.rank, type="left", seed=self.seed
@ -263,27 +262,6 @@ class GradientProjector:
raise ValueError("type should be left, right or full") raise ValueError("type should be left, right or full")
class Apollo(Optimizer): class Apollo(Optimizer):
"""
Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
Regularization](https://arxiv.org/abs/1711.05101).
Parameters:
params (`Iterable[nn.parameter.Parameter]`):
Iterable of parameters to optimize or dictionaries defining parameter groups.
lr (`float`, *optional*, defaults to 0.001):
The learning rate to use.
betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`):
Adam's betas parameters (b1, b2).
eps (`float`, *optional*, defaults to 1e-06):
Adam's epsilon for numerical stability.
weight_decay (`float`, *optional*, defaults to 0.0):
Decoupled weight decay to apply.
correct_bias (`bool`, *optional*, defaults to `True`):
Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
no_deprecation_warning (`bool`, *optional*, defaults to `False`):
A flag used to disable the deprecation warning (set to `True` to disable the warning).
"""
def __init__( def __init__(
self, self,
params: Iterable[nn.parameter.Parameter], params: Iterable[nn.parameter.Parameter],
@ -292,7 +270,13 @@ class Apollo(Optimizer):
eps: float = 1e-6, eps: float = 1e-6,
weight_decay: float = 0.0, weight_decay: float = 0.0,
correct_bias: bool = True, correct_bias: bool = True,
scale_front: bool = False,
rank: int = 256,
proj: str = "random",
scale_type: str = "channel",
scale: int = 1,
update_proj_gap: int = 200,
proj_type: str = "std",
): ):
if lr < 0.0: if lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0") raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
@ -302,17 +286,30 @@ class Apollo(Optimizer):
raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)") raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
if not 0.0 <= eps: if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0") raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias} defaults = {
"lr": lr,
"betas": betas,
"eps": eps,
"weight_decay": weight_decay,
"correct_bias": correct_bias,
"rank": rank,
"proj": proj,
"scale_type": scale_type,
"scale": scale,
"update_proj_gap": update_proj_gap,
"proj_type": proj_type,
}
super().__init__(params, defaults) super().__init__(params, defaults)
self.scale_front = scale_front """
params_idx = 0 params_idx = 0
for group in self.param_groups: for group in self.param_groups:
for p in group["params"]: for p in group["params"]:
params_idx += 1 params_idx += 1
if p.requires_grad: if p.requires_grad:
self.state[p]["seed"] = params_idx self.state[p]["seed"] = params_idx
"""
@torch.no_grad() @torch.no_grad()
def step(self, closure: Callable = None): def step(self, closure: Callable = None):
@ -332,9 +329,9 @@ class Apollo(Optimizer):
params_idx += 1 params_idx += 1
if p.grad is None: if p.grad is None:
continue continue
grad = p.grad grad = p.grad.data
if grad.is_sparse: if grad.is_sparse:
raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead") raise RuntimeError("APOLLO does not support sparse gradients")
state = self.state[p] state = self.state[p]
@ -345,29 +342,32 @@ class Apollo(Optimizer):
state["seed"] = params_idx state["seed"] = params_idx
# GaLore Projection # GaLore Projection
if "rank" in group: if group["rank"] > 0:
if "projector" not in state: if "projector" not in state:
if group["proj"] == "random": if group["proj"] == "random":
state["projector"] = GradientProjector(group["rank"], state["projector"] = GradientProjector(
group["rank"],
update_proj_gap=group["update_proj_gap"], update_proj_gap=group["update_proj_gap"],
alpha=group["scale"], alpha=group["scale"],
proj_type=group["proj_type"], proj_type=group["proj_type"],
seed=state["seed"]) seed=state["seed"]
)
elif group["proj"] == "svd": elif group["proj"] == "svd":
state["projector"] = GaLoreProjector(group["rank"], state["projector"] = GaLoreProjector(
group["rank"],
update_proj_gap=group["update_proj_gap"], update_proj_gap=group["update_proj_gap"],
scale=group["scale"], scale=group["scale"],
proj_type=group["proj_type"]) proj_type=group["proj_type"]
)
grad = state["projector"].project(grad, state["step"]) grad = state["projector"].project(grad, state["step"])
# State initialization
if "exp_avg" not in state: if "exp_avg" not in state:
# Exponential moving average of gradient values # Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(grad) state["exp_avg"] = torch.zeros_like(grad).detach()
# Exponential moving average of squared gradient values # Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(grad) state["exp_avg_sq"] = torch.zeros_like(grad).detach()
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"] beta1, beta2 = group["betas"]
@ -389,9 +389,12 @@ class Apollo(Optimizer):
# compute norm gradient # compute norm gradient
norm_grad = exp_avg / denom norm_grad = exp_avg / denom
if "rank" in group: if group["rank"] > 0:
if group['scale_type'] == 'channel': if group['scale_type'] == 'channel':
if norm_grad.dim() > 1:
norm_dim = 0 if norm_grad.shape[0] < norm_grad.shape[1] else 1 norm_dim = 0 if norm_grad.shape[0] < norm_grad.shape[1] else 1
else:
norm_dim = 0
scaling_factor = ( scaling_factor = (
torch.norm(norm_grad, dim=norm_dim) / torch.norm(norm_grad, dim=norm_dim) /
(torch.norm(grad, dim=norm_dim) + 1e-8) (torch.norm(grad, dim=norm_dim) + 1e-8)
@ -405,7 +408,7 @@ class Apollo(Optimizer):
(torch.norm(grad) + 1e-8) (torch.norm(grad) + 1e-8)
) )
scaling_grad = p.grad * scaling_factor scaling_grad = p.grad.data * scaling_factor
# Use Norm-Growth Limiter in Fira # Use Norm-Growth Limiter in Fira
if "scaling_grad" in state: if "scaling_grad" in state:
@ -422,17 +425,9 @@ class Apollo(Optimizer):
norm_grad = scaling_grad * np.sqrt(group["scale"]) norm_grad = scaling_grad * np.sqrt(group["scale"])
p.add_(norm_grad, alpha=-step_size) p.data.add_(norm_grad, alpha=-step_size)
# Just adding the square of the weights to the loss function is *not*
# the correct way of using L2 regularization/weight decay with Adam,
# since that will interact with the m and v parameters in strange ways.
#
# Instead we want to decay the weights in a manner that doesn't interact
# with the m/v parameters. This is equivalent to adding the square
# of the weights to the loss with plain (non-momentum) SGD.
# Add weight decay at the end (fixed version)
if group["weight_decay"] > 0.0: if group["weight_decay"] > 0.0:
p.add_(p, alpha=(-group["lr"] * group["weight_decay"])) p.data.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
return loss return loss