APOLLO tweaks to make it work with deepspeed
commit 09804ecc16 (parent 64c67160a3)
@@ -520,7 +520,7 @@ class DeepSpeed:
     use_compression_training: bool = False # cope
     compression_bits: int = 8 # cope
     inferencing: bool = False # for using DeepSpeed's inferencing wrapper instead
+    optimizer: bool = True # use DeepSpeed optimizer wrapper
     amp: bool = False # use DeepSpeed's AMP (requires some other package installed apparently)

     loss_scale_window: int = 100
@@ -132,41 +132,17 @@ def load_engines(training=True, **model_kwargs):
             params['d_coef'] = params['lr']
             params['lr'] = 1.0
         elif cfg.hyperparameters.optimizer.lower() in ["apollo","apollo-mini"]:
-            """
-            if backend == "deepspeed":
-                raise Exception("APOLLO currently does not play nicely with DeepSpeed.")
-            """

             optimizer_class = ml.Apollo
             is_mini = cfg.hyperparameters.optimizer.lower() == "apollo-mini"
-            param_kwargs = {
+            params.update({
                 "rank": 1 if is_mini else 256,
                 "proj": "random",
                 "scale_type": "tensor" if is_mini else "channel",
                 "scale": 128 if is_mini else 1,
                 "update_proj_gap": 200,
                 "proj_type": "std",
-            }
+            })
-            # grab any extra configs from the YAML
-            param_kwargs.update(cfg.hyperparameters.optimizer_params)
-            # and blank it so it doesn't update the main optimizer kwargs
-            cfg.hyperparameters.optimizer_params = {}
-
-            """
-            params["params"] = [{'params': params["params"]} | param_kwargs]
-            """
-            target_params = []
-            target_modules_list = ["attn", "mlp"]
-            for module_name, module in model.named_modules():
-                if not (isinstance(module, torch.nn.Linear)):
-                    continue
-                if not any(target_key in module_name for target_key in target_modules_list):
-                    continue
-                target_params.append(module.weight)
-
-            param_ids = [id(p) for p in target_params]
-            regular_params = [p for p in model.parameters() if id(p) not in param_ids]
-            params["params"] = [{'params': regular_params}, {'params': target_params} | param_kwargs]
         elif cfg.hyperparameters.optimizer.lower() == "adagrad":
             optimizer_class = ml.Adagrad
         else:
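Note on the hunk above: the per-module param groups (attn/mlp Linear weights carrying their own APOLLO kwargs) are dropped, and the APOLLO settings are folded into the flat optimizer kwargs instead, which is the shape DeepSpeed's optimizer wrapper copes with. A minimal sketch of what the construction amounts to after this change; `build_apollo_kwargs`, `model`, and `learning_rate` are illustrative names, not the repo's exact variables:

    # Sketch only: flat kwargs, every parameter gets the same APOLLO settings.
    import torch
    from torch import nn

    def build_apollo_kwargs(optimizer_name: str, learning_rate: float) -> dict:
        # "apollo-mini" trades rank for a stronger tensor-wise scale, per the values in the hunk above
        is_mini = optimizer_name.lower() == "apollo-mini"
        return {
            "lr": learning_rate,
            "rank": 1 if is_mini else 256,
            "proj": "random",
            "scale_type": "tensor" if is_mini else "channel",
            "scale": 128 if is_mini else 1,
            "update_proj_gap": 200,
            "proj_type": "std",
        }

    model = nn.Linear(16, 16)                      # stand-in for the real model
    params = { "params": model.parameters() }
    params.update(build_apollo_kwargs("apollo", 1.0e-4))
    # optimizer = ml.Apollo(**params)              # the kwargs land in the optimizer defaults (see the Apollo hunks below)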
@@ -406,10 +406,14 @@ class Engines(dict[str, Engine]):
             if cfg.lora is not None:
                 save_dir = cfg.ckpt_dir / cfg.lora.full_name

+            engine.save_checkpoint(save_dir, tag=tag)
+
+            """
             try:
                 engine.save_checkpoint(save_dir, tag=tag)
             except Exception as e:
                 _logger.warning(f'Failed to save checkpoint for engine {name}: {str(e)}')
+            """

         # might be better to prune before saving for safety, but [:0] returns an empty list, but I could do [:-cfg.trainer.keep_last_checkpoints - 1 if cfg.trainer.keep_last_checkpoints > 1 else None]
         if cfg.trainer.keep_last_checkpoints > 0 and is_global_leader():
@@ -515,11 +519,11 @@ class Engines(dict[str, Engine]):
             start_time = time.time()

             batch = to_device(batch, device)
-            n_ooms = torch.zeros([], device=device)

             if not cfg.trainer.check_for_oom:
                 res = feeder( engine=engine, batch=batch, teacher=teacher )
             else:
+                forward_ooms = torch.zeros([], device=device)
                 try:
                     res = feeder( engine=engine, batch=batch, teacher=teacher )
                 except RuntimeError as e:
@@ -529,12 +533,12 @@ class Engines(dict[str, Engine]):
                         self.save_checkpoint()
                         raise e

                    forward_ooms += 1
-                   n_ooms += 1

                 if world_size() > 1:
-                    all_reduce(n_ooms)
+                    all_reduce(forward_ooms)

-                if n_ooms.item() > 0:
+                if forward_ooms.item() > 0:
                     continue
                     """
                     self.save_checkpoint()
@@ -554,7 +558,7 @@ class Engines(dict[str, Engine]):
             if not cfg.trainer.check_for_oom:
                 engine.backward(loss)
             else:
-                # to-do: properly handle when one GPU throws an OOM because it just halts despite doing a gather/reduce
+                backward_ooms = torch.zeros([], device=device)
                 try:
                     engine.backward(loss)
                 except RuntimeError as e:
@@ -564,12 +568,12 @@ class Engines(dict[str, Engine]):
                         self.save_checkpoint()
                         raise e

-                    n_ooms += 1
+                    backward_ooms += 1

                 if world_size() > 1:
-                    all_reduce(n_ooms)
+                    all_reduce(backward_ooms)

-                if n_ooms.item() > 0:
+                if backward_ooms.item() > 0:
                     self.save_checkpoint()
                     raise RuntimeError("Out of memory during backwards pass!")

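Note on the two hunks above: the forward and backward passes now keep separate OOM counters, so the count reduced after the forward check no longer leaks into the backward check, and every rank participates in the all_reduce so all ranks agree on whether any of them ran out of memory. A minimal sketch of that pattern, assuming torch.distributed is initialized when more than one rank is running; `ran_out_of_memory` and `step_fn` are illustrative names:

    import torch
    import torch.distributed as dist

    def ran_out_of_memory(step_fn, device="cpu") -> bool:
        oom_counter = torch.zeros([], device=device)   # fresh counter per phase (forward or backward)
        try:
            step_fn()
        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise                                   # unrelated errors still propagate
            oom_counter += 1
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(oom_counter)                # every rank learns if any rank OOM'd
        return oom_counter.item() > 0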
@@ -28,7 +28,7 @@ class GaLoreProjector:
            self.ortho_matrix = self.ortho_matrix.to(full_rank_grad.device)

        if self.proj_type == 'std':
-           if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
+           if full_rank_grad.dim() > 1 and full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
                if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                    self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
                    self.svd_count += 1
@@ -39,7 +39,7 @@ class GaLoreProjector:
                    self.svd_count += 1
                low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
        elif self.proj_type == 'reverse_std':
-           if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
+           if full_rank_grad.dim() > 1 and full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
                if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                    self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
                    self.svd_count += 1
@@ -70,12 +70,12 @@ class GaLoreProjector:
    def project_back(self, low_rank_grad):

        if self.proj_type == 'std':
-           if low_rank_grad.shape[0] >= low_rank_grad.shape[1]:
+           if low_rank_grad.dim() > 1 and low_rank_grad.shape[0] >= low_rank_grad.shape[1]:
                full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
            else:
                full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
        elif self.proj_type == 'reverse_std':
-           if low_rank_grad.shape[0] <= low_rank_grad.shape[1]: # note this is different from std
+           if low_rank_grad.dim() > 1 and low_rank_grad.shape[0] <= low_rank_grad.shape[1]: # note this is different from std
                full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
            else:
                full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
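Note on the projector hunks: 1-D parameters (biases, norm weights) have no second dimension, so `grad.shape[0] >= grad.shape[1]` raises an IndexError, and with APOLLO now applied to the whole flat parameter set such tensors do reach the projectors. A minimal illustration of the guarded comparison; the helper name is made up for the example:

    import torch

    bias_grad = torch.randn(1024)          # 1-D gradient, e.g. from a bias
    weight_grad = torch.randn(1024, 4096)  # 2-D gradient from a Linear weight

    def wants_right_projection(grad: torch.Tensor) -> bool:
        # mirrors the guarded comparison added in the hunks above
        return grad.dim() > 1 and grad.shape[0] >= grad.shape[1]

    print(wants_right_projection(weight_grad))  # False: 1024 < 4096, so this one projects from the left
    print(wants_right_projection(bias_grad))    # False, and no IndexError on shape[1]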
@@ -170,9 +170,8 @@ class GradientProjector:
        self.seed = seed

    def project(self, full_rank_grad, iter):
-
        if self.proj_type == "std":
-           if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
+           if full_rank_grad.dim() > 1 and full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
                if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                    self.ortho_matrix = self.get_orthogonal_matrix(
                        full_rank_grad, self.rank, type="right", seed=self.seed
@@ -188,7 +187,7 @@ class GradientProjector:

            low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
        elif self.proj_type == "reverse_std":
-           if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
+           if full_rank_grad.dim() > 1 and full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
                if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                    self.ortho_matrix = self.get_orthogonal_matrix(
                        full_rank_grad, self.rank, type="left", seed=self.seed
@@ -263,27 +262,6 @@ class GradientProjector:
            raise ValueError("type should be left, right or full")

class Apollo(Optimizer):
-   """
-   Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
-   Regularization](https://arxiv.org/abs/1711.05101).
-
-   Parameters:
-       params (`Iterable[nn.parameter.Parameter]`):
-           Iterable of parameters to optimize or dictionaries defining parameter groups.
-       lr (`float`, *optional*, defaults to 0.001):
-           The learning rate to use.
-       betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`):
-           Adam's betas parameters (b1, b2).
-       eps (`float`, *optional*, defaults to 1e-06):
-           Adam's epsilon for numerical stability.
-       weight_decay (`float`, *optional*, defaults to 0.0):
-           Decoupled weight decay to apply.
-       correct_bias (`bool`, *optional*, defaults to `True`):
-           Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
-       no_deprecation_warning (`bool`, *optional*, defaults to `False`):
-           A flag used to disable the deprecation warning (set to `True` to disable the warning).
-   """
-
    def __init__(
        self,
        params: Iterable[nn.parameter.Parameter],
@@ -292,7 +270,13 @@ class Apollo(Optimizer):
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        correct_bias: bool = True,
-       scale_front: bool = False,
+       rank: int = 256,
+       proj: str = "random",
+       scale_type: str = "channel",
+       scale: int = 1,
+       update_proj_gap: int = 200,
+       proj_type: str = "std",
    ):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
@@ -302,17 +286,30 @@ class Apollo(Optimizer):
            raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
-       defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias}
+       defaults = {
+           "lr": lr,
+           "betas": betas,
+           "eps": eps,
+           "weight_decay": weight_decay,
+           "correct_bias": correct_bias,
+
+           "rank": rank,
+           "proj": proj,
+           "scale_type": scale_type,
+           "scale": scale,
+           "update_proj_gap": update_proj_gap,
+           "proj_type": proj_type,
+       }
        super().__init__(params, defaults)

-       self.scale_front = scale_front
+       """
        params_idx = 0
        for group in self.param_groups:
            for p in group["params"]:
                params_idx += 1
                if p.requires_grad:
                    self.state[p]["seed"] = params_idx
+       """

    @torch.no_grad()
    def step(self, closure: Callable = None):
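Note on the hunk above: with the projection settings moved into the constructor and into `defaults`, every param group created by `torch.optim.Optimizer.__init__` carries "rank", "proj", and friends, which is why `step()` can later read `group["rank"]` unconditionally. A small sketch of that defaults behavior under an assumed toy subclass (`TinyOptimizer` is hypothetical):

    import torch

    class TinyOptimizer(torch.optim.Optimizer):
        def __init__(self, params, lr=1e-3, rank=256, proj="random"):
            # any key missing from a param group is filled in from `defaults`
            defaults = {"lr": lr, "rank": rank, "proj": proj}
            super().__init__(params, defaults)

    model = torch.nn.Linear(8, 8)
    opt = TinyOptimizer(model.parameters(), rank=1)
    print(opt.param_groups[0]["rank"], opt.param_groups[0]["proj"])  # 1 random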
@@ -332,9 +329,9 @@ class Apollo(Optimizer):
                params_idx += 1
                if p.grad is None:
                    continue
-               grad = p.grad
+               grad = p.grad.data
                if grad.is_sparse:
-                   raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
+                   raise RuntimeError("APOLLO does not support sparse gradients")

                state = self.state[p]

@@ -345,29 +342,32 @@ class Apollo(Optimizer):
                    state["seed"] = params_idx

                # GaLore Projection
-               if "rank" in group:
+               if group["rank"] > 0:
                    if "projector" not in state:
                        if group["proj"] == "random":
-                           state["projector"] = GradientProjector(group["rank"],
+                           state["projector"] = GradientProjector(
+                               group["rank"],
                                update_proj_gap=group["update_proj_gap"],
                                alpha=group["scale"],
                                proj_type=group["proj_type"],
-                               seed=state["seed"])
+                               seed=state["seed"]
+                           )

                        elif group["proj"] == "svd":
-                           state["projector"] = GaLoreProjector(group["rank"],
+                           state["projector"] = GaLoreProjector(
+                               group["rank"],
                                update_proj_gap=group["update_proj_gap"],
                                scale=group["scale"],
-                               proj_type=group["proj_type"])
+                               proj_type=group["proj_type"]
+                           )

                    grad = state["projector"].project(grad, state["step"])

-               # State initialization
                if "exp_avg" not in state:
                    # Exponential moving average of gradient values
-                   state["exp_avg"] = torch.zeros_like(grad)
+                   state["exp_avg"] = torch.zeros_like(grad).detach()
                    # Exponential moving average of squared gradient values
-                   state["exp_avg_sq"] = torch.zeros_like(grad)
+                   state["exp_avg_sq"] = torch.zeros_like(grad).detach()

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]
@@ -389,9 +389,12 @@ class Apollo(Optimizer):
                # compute norm gradient
                norm_grad = exp_avg / denom

-               if "rank" in group:
+               if group["rank"] > 0:
                    if group['scale_type'] == 'channel':
+                       if norm_grad.dim() > 1:
                            norm_dim = 0 if norm_grad.shape[0] < norm_grad.shape[1] else 1
+                       else:
+                           norm_dim = 0
                        scaling_factor = (
                            torch.norm(norm_grad, dim=norm_dim) /
                            (torch.norm(grad, dim=norm_dim) + 1e-8)
@@ -405,7 +408,7 @@ class Apollo(Optimizer):
                            (torch.norm(grad) + 1e-8)
                        )

-                   scaling_grad = p.grad * scaling_factor
+                   scaling_grad = p.grad.data * scaling_factor

                    # Use Norm-Growth Limiter in Fira
                    if "scaling_grad" in state:
|
||||||
|
|
||||||
norm_grad = scaling_grad * np.sqrt(group["scale"])
|
norm_grad = scaling_grad * np.sqrt(group["scale"])
|
||||||
|
|
||||||
p.add_(norm_grad, alpha=-step_size)
|
p.data.add_(norm_grad, alpha=-step_size)
|
||||||
|
|
||||||
# Just adding the square of the weights to the loss function is *not*
|
|
||||||
# the correct way of using L2 regularization/weight decay with Adam,
|
|
||||||
# since that will interact with the m and v parameters in strange ways.
|
|
||||||
#
|
|
||||||
# Instead we want to decay the weights in a manner that doesn't interact
|
|
||||||
# with the m/v parameters. This is equivalent to adding the square
|
|
||||||
# of the weights to the loss with plain (non-momentum) SGD.
|
|
||||||
# Add weight decay at the end (fixed version)
|
|
||||||
if group["weight_decay"] > 0.0:
|
if group["weight_decay"] > 0.0:
|
||||||
p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
|
p.data.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
|
||||||
|
|
||||||
return loss
|
return loss
|