APOLLO tweaks to make it work with deepspeed
This commit is contained in:
parent
64c67160a3
commit
09804ecc16
|
@ -520,7 +520,7 @@ class DeepSpeed:
|
|||
use_compression_training: bool = False # cope
|
||||
compression_bits: int = 8 # cope
|
||||
inferencing: bool = False # for using DeepSpeed's inferencing wrapper instead
|
||||
|
||||
optimizer: bool = True # use DeepSpeed optimizer wrapper
|
||||
amp: bool = False # use DeepSpeed's AMP (requires some other package installed apparently)
|
||||
|
||||
loss_scale_window: int = 100
|
||||
|
|
|
@ -132,41 +132,17 @@ def load_engines(training=True, **model_kwargs):
|
|||
params['d_coef'] = params['lr']
|
||||
params['lr'] = 1.0
|
||||
elif cfg.hyperparameters.optimizer.lower() in ["apollo","apollo-mini"]:
|
||||
"""
|
||||
if backend == "deepspeed":
|
||||
raise Exception("APOLLO currently does not play nicely with DeepSpeed.")
|
||||
"""
|
||||
|
||||
optimizer_class = ml.Apollo
|
||||
is_mini = cfg.hyperparameters.optimizer.lower() == "apollo-mini"
|
||||
param_kwargs = {
|
||||
|
||||
params.update({
|
||||
"rank": 1 if is_mini else 256,
|
||||
"proj": "random",
|
||||
"scale_type": "tensor" if is_mini else "channel",
|
||||
"scale": 128 if is_mini else 1,
|
||||
"update_proj_gap": 200,
|
||||
"proj_type": "std",
|
||||
}
|
||||
# grab any extra configs from the YAML
|
||||
param_kwargs.update(cfg.hyperparameters.optimizer_params)
|
||||
# and blank it so it doesn't update the main optimizer kwargs
|
||||
cfg.hyperparameters.optimizer_params = {}
|
||||
|
||||
"""
|
||||
params["params"] = [{'params': params["params"]} | param_kwargs]
|
||||
"""
|
||||
target_params = []
|
||||
target_modules_list = ["attn", "mlp"]
|
||||
for module_name, module in model.named_modules():
|
||||
if not (isinstance(module, torch.nn.Linear)):
|
||||
continue
|
||||
if not any(target_key in module_name for target_key in target_modules_list):
|
||||
continue
|
||||
target_params.append(module.weight)
|
||||
|
||||
param_ids = [id(p) for p in target_params]
|
||||
regular_params = [p for p in model.parameters() if id(p) not in param_ids]
|
||||
params["params"] = [{'params': regular_params}, {'params': target_params} | param_kwargs]
|
||||
})
|
||||
elif cfg.hyperparameters.optimizer.lower() == "adagrad":
|
||||
optimizer_class = ml.Adagrad
|
||||
else:
|
||||
|
|
|
@ -406,10 +406,14 @@ class Engines(dict[str, Engine]):
|
|||
if cfg.lora is not None:
|
||||
save_dir = cfg.ckpt_dir / cfg.lora.full_name
|
||||
|
||||
engine.save_checkpoint(save_dir, tag=tag)
|
||||
|
||||
"""
|
||||
try:
|
||||
engine.save_checkpoint(save_dir, tag=tag)
|
||||
except Exception as e:
|
||||
_logger.warning(f'Failed to save checkpoint for engine {name}: {str(e)}')
|
||||
"""
|
||||
|
||||
# might be better to prune before saving for safety, but [:0] returns an empty list, but I could do [:-cfg.trainer.keep_last_checkpoints - 1 if cfg.trainer.keep_last_checkpoints > 1 else None]
|
||||
if cfg.trainer.keep_last_checkpoints > 0 and is_global_leader():
|
||||
|
@ -515,11 +519,11 @@ class Engines(dict[str, Engine]):
|
|||
start_time = time.time()
|
||||
|
||||
batch = to_device(batch, device)
|
||||
n_ooms = torch.zeros([], device=device)
|
||||
|
||||
if not cfg.trainer.check_for_oom:
|
||||
res = feeder( engine=engine, batch=batch, teacher=teacher )
|
||||
else:
|
||||
forward_ooms = torch.zeros([], device=device)
|
||||
try:
|
||||
res = feeder( engine=engine, batch=batch, teacher=teacher )
|
||||
except RuntimeError as e:
|
||||
|
@ -529,12 +533,12 @@ class Engines(dict[str, Engine]):
|
|||
self.save_checkpoint()
|
||||
raise e
|
||||
|
||||
n_ooms += 1
|
||||
forward_ooms += 1
|
||||
|
||||
if world_size() > 1:
|
||||
all_reduce(n_ooms)
|
||||
all_reduce(forward_ooms)
|
||||
|
||||
if n_ooms.item() > 0:
|
||||
if forward_ooms.item() > 0:
|
||||
continue
|
||||
"""
|
||||
self.save_checkpoint()
|
||||
|
@ -554,7 +558,7 @@ class Engines(dict[str, Engine]):
|
|||
if not cfg.trainer.check_for_oom:
|
||||
engine.backward(loss)
|
||||
else:
|
||||
# to-do: properly handle when one GPU throws an OOM because it just halts despite doing a gather/reduce
|
||||
backward_ooms = torch.zeros([], device=device)
|
||||
try:
|
||||
engine.backward(loss)
|
||||
except RuntimeError as e:
|
||||
|
@ -564,12 +568,12 @@ class Engines(dict[str, Engine]):
|
|||
self.save_checkpoint()
|
||||
raise e
|
||||
|
||||
n_ooms += 1
|
||||
backward_ooms += 1
|
||||
|
||||
if world_size() > 1:
|
||||
all_reduce(n_ooms)
|
||||
all_reduce(backward_ooms)
|
||||
|
||||
if n_ooms.item() > 0:
|
||||
if backward_ooms.item() > 0:
|
||||
self.save_checkpoint()
|
||||
raise RuntimeError("Out of memory during backwards pass!")
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ class GaLoreProjector:
|
|||
self.ortho_matrix = self.ortho_matrix.to(full_rank_grad.device)
|
||||
|
||||
if self.proj_type == 'std':
|
||||
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
|
||||
if full_rank_grad.dim() > 1 and full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
|
||||
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
|
||||
self.svd_count += 1
|
||||
|
@ -39,7 +39,7 @@ class GaLoreProjector:
|
|||
self.svd_count += 1
|
||||
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
|
||||
elif self.proj_type == 'reverse_std':
|
||||
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
|
||||
if full_rank_grad.dim() > 1 and full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
|
||||
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
|
||||
self.svd_count += 1
|
||||
|
@ -70,12 +70,12 @@ class GaLoreProjector:
|
|||
def project_back(self, low_rank_grad):
|
||||
|
||||
if self.proj_type == 'std':
|
||||
if low_rank_grad.shape[0] >= low_rank_grad.shape[1]:
|
||||
if low_rank_grad.dim() > 1 and low_rank_grad.shape[0] >= low_rank_grad.shape[1]:
|
||||
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
|
||||
else:
|
||||
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
|
||||
elif self.proj_type == 'reverse_std':
|
||||
if low_rank_grad.shape[0] <= low_rank_grad.shape[1]: # note this is different from std
|
||||
if low_rank_grad.dim() > 1 and low_rank_grad.shape[0] <= low_rank_grad.shape[1]: # note this is different from std
|
||||
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
|
||||
else:
|
||||
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
|
||||
|
@ -170,9 +170,8 @@ class GradientProjector:
|
|||
self.seed = seed
|
||||
|
||||
def project(self, full_rank_grad, iter):
|
||||
|
||||
if self.proj_type == "std":
|
||||
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
|
||||
if full_rank_grad.dim() > 1 and full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
|
||||
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||
self.ortho_matrix = self.get_orthogonal_matrix(
|
||||
full_rank_grad, self.rank, type="right", seed=self.seed
|
||||
|
@ -188,7 +187,7 @@ class GradientProjector:
|
|||
|
||||
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
|
||||
elif self.proj_type == "reverse_std":
|
||||
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
|
||||
if full_rank_grad.dim() > 1 and full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
|
||||
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||
self.ortho_matrix = self.get_orthogonal_matrix(
|
||||
full_rank_grad, self.rank, type="left", seed=self.seed
|
||||
|
@ -263,27 +262,6 @@ class GradientProjector:
|
|||
raise ValueError("type should be left, right or full")
|
||||
|
||||
class Apollo(Optimizer):
|
||||
"""
|
||||
Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
|
||||
Regularization](https://arxiv.org/abs/1711.05101).
|
||||
|
||||
Parameters:
|
||||
params (`Iterable[nn.parameter.Parameter]`):
|
||||
Iterable of parameters to optimize or dictionaries defining parameter groups.
|
||||
lr (`float`, *optional*, defaults to 0.001):
|
||||
The learning rate to use.
|
||||
betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`):
|
||||
Adam's betas parameters (b1, b2).
|
||||
eps (`float`, *optional*, defaults to 1e-06):
|
||||
Adam's epsilon for numerical stability.
|
||||
weight_decay (`float`, *optional*, defaults to 0.0):
|
||||
Decoupled weight decay to apply.
|
||||
correct_bias (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
|
||||
no_deprecation_warning (`bool`, *optional*, defaults to `False`):
|
||||
A flag used to disable the deprecation warning (set to `True` to disable the warning).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params: Iterable[nn.parameter.Parameter],
|
||||
|
@ -292,7 +270,13 @@ class Apollo(Optimizer):
|
|||
eps: float = 1e-6,
|
||||
weight_decay: float = 0.0,
|
||||
correct_bias: bool = True,
|
||||
scale_front: bool = False,
|
||||
|
||||
rank: int = 256,
|
||||
proj: str = "random",
|
||||
scale_type: str = "channel",
|
||||
scale: int = 1,
|
||||
update_proj_gap: int = 200,
|
||||
proj_type: str = "std",
|
||||
):
|
||||
if lr < 0.0:
|
||||
raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
|
||||
|
@ -302,17 +286,30 @@ class Apollo(Optimizer):
|
|||
raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
|
||||
if not 0.0 <= eps:
|
||||
raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
|
||||
defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias}
|
||||
defaults = {
|
||||
"lr": lr,
|
||||
"betas": betas,
|
||||
"eps": eps,
|
||||
"weight_decay": weight_decay,
|
||||
"correct_bias": correct_bias,
|
||||
|
||||
"rank": rank,
|
||||
"proj": proj,
|
||||
"scale_type": scale_type,
|
||||
"scale": scale,
|
||||
"update_proj_gap": update_proj_gap,
|
||||
"proj_type": proj_type,
|
||||
}
|
||||
super().__init__(params, defaults)
|
||||
|
||||
self.scale_front = scale_front
|
||||
|
||||
"""
|
||||
params_idx = 0
|
||||
for group in self.param_groups:
|
||||
for p in group["params"]:
|
||||
params_idx += 1
|
||||
if p.requires_grad:
|
||||
self.state[p]["seed"] = params_idx
|
||||
"""
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure: Callable = None):
|
||||
|
@ -332,42 +329,45 @@ class Apollo(Optimizer):
|
|||
params_idx += 1
|
||||
if p.grad is None:
|
||||
continue
|
||||
grad = p.grad
|
||||
grad = p.grad.data
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
|
||||
raise RuntimeError("APOLLO does not support sparse gradients")
|
||||
|
||||
state = self.state[p]
|
||||
|
||||
if "step" not in state:
|
||||
state["step"] = 0
|
||||
|
||||
|
||||
if "seed" not in state:
|
||||
state["seed"] = params_idx
|
||||
|
||||
# GaLore Projection
|
||||
if "rank" in group:
|
||||
if group["rank"] > 0:
|
||||
if "projector" not in state:
|
||||
if group["proj"] == "random":
|
||||
state["projector"] = GradientProjector(group["rank"],
|
||||
state["projector"] = GradientProjector(
|
||||
group["rank"],
|
||||
update_proj_gap=group["update_proj_gap"],
|
||||
alpha=group["scale"],
|
||||
proj_type=group["proj_type"],
|
||||
seed=state["seed"])
|
||||
seed=state["seed"]
|
||||
)
|
||||
|
||||
elif group["proj"] == "svd":
|
||||
state["projector"] = GaLoreProjector(group["rank"],
|
||||
state["projector"] = GaLoreProjector(
|
||||
group["rank"],
|
||||
update_proj_gap=group["update_proj_gap"],
|
||||
scale=group["scale"],
|
||||
proj_type=group["proj_type"])
|
||||
proj_type=group["proj_type"]
|
||||
)
|
||||
|
||||
grad = state["projector"].project(grad, state["step"])
|
||||
|
||||
# State initialization
|
||||
if "exp_avg" not in state:
|
||||
# Exponential moving average of gradient values
|
||||
state["exp_avg"] = torch.zeros_like(grad)
|
||||
state["exp_avg"] = torch.zeros_like(grad).detach()
|
||||
# Exponential moving average of squared gradient values
|
||||
state["exp_avg_sq"] = torch.zeros_like(grad)
|
||||
state["exp_avg_sq"] = torch.zeros_like(grad).detach()
|
||||
|
||||
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
|
||||
beta1, beta2 = group["betas"]
|
||||
|
@ -389,9 +389,12 @@ class Apollo(Optimizer):
|
|||
# compute norm gradient
|
||||
norm_grad = exp_avg / denom
|
||||
|
||||
if "rank" in group:
|
||||
if group["rank"] > 0:
|
||||
if group['scale_type'] == 'channel':
|
||||
norm_dim = 0 if norm_grad.shape[0] < norm_grad.shape[1] else 1
|
||||
if norm_grad.dim() > 1:
|
||||
norm_dim = 0 if norm_grad.shape[0] < norm_grad.shape[1] else 1
|
||||
else:
|
||||
norm_dim = 0
|
||||
scaling_factor = (
|
||||
torch.norm(norm_grad, dim=norm_dim) /
|
||||
(torch.norm(grad, dim=norm_dim) + 1e-8)
|
||||
|
@ -405,7 +408,7 @@ class Apollo(Optimizer):
|
|||
(torch.norm(grad) + 1e-8)
|
||||
)
|
||||
|
||||
scaling_grad = p.grad * scaling_factor
|
||||
scaling_grad = p.grad.data * scaling_factor
|
||||
|
||||
# Use Norm-Growth Limiter in Fira
|
||||
if "scaling_grad" in state:
|
||||
|
@ -422,17 +425,9 @@ class Apollo(Optimizer):
|
|||
|
||||
norm_grad = scaling_grad * np.sqrt(group["scale"])
|
||||
|
||||
p.add_(norm_grad, alpha=-step_size)
|
||||
p.data.add_(norm_grad, alpha=-step_size)
|
||||
|
||||
# Just adding the square of the weights to the loss function is *not*
|
||||
# the correct way of using L2 regularization/weight decay with Adam,
|
||||
# since that will interact with the m and v parameters in strange ways.
|
||||
#
|
||||
# Instead we want to decay the weights in a manner that doesn't interact
|
||||
# with the m/v parameters. This is equivalent to adding the square
|
||||
# of the weights to the loss with plain (non-momentum) SGD.
|
||||
# Add weight decay at the end (fixed version)
|
||||
if group["weight_decay"] > 0.0:
|
||||
p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
|
||||
p.data.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
|
||||
|
||||
return loss
|
Loading…
Reference in New Issue
Block a user