diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py
index 19a6c84..69552f9 100755
--- a/vall_e/emb/qnt.py
+++ b/vall_e/emb/qnt.py
@@ -240,12 +240,12 @@ def decode(codes: Tensor, device="cuda", dtype=None, metadata=None, window_durat
     if model.backend == "vocos":
         x = model.codes_to_features(codes[0])
         wav = model.decode(x, bandwidth_id=model.bandwidth_id)
-        return wav, cfg.sample_rate
-
+
     if model.backend == "encodec":
         x = [(codes.to(device), None)]
-        wav = model.decode(x)
-        return wav, cfg.sample_rate
+        wav = model.decode(x)[0]
+
+    return wav, cfg.sample_rate
 
 @torch.inference_mode()
 def decode_batch(codes: list[Tensor], device="cuda", dtype=None):
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 7b30418..65672cc 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -145,9 +145,25 @@ def load_engines(training=True, **model_kwargs):
             })
         elif cfg.hyperparameters.optimizer.lower() == "adagrad":
             optimizer_class = ml.Adagrad
+        elif cfg.hyperparameters.optimizer.lower() == "muon":
+            del params["params"]
+            optimizer_class = ml.Muon
+
+
+            params["muon_params"] = [ param for name, param in model.model.named_parameters() if param.ndim >= 2 and f'model.{name}' not in model.config.frozen_params ]
+            params["adamw_params"] = [ param for name, param in model.model.named_parameters() if param.ndim < 2 and f'model.{name}' not in model.config.frozen_params ]
+            params["adamw_params"] += [ param for name, param in model.named_parameters() if not name.startswith('model.') and name not in model.config.frozen_params ]
+
+            if cfg.hyperparameters.optimizer_params is not None:
+                params["adamw_betas"] = cfg.hyperparameters.optimizer_params.pop("adamw_betas", (0.95, 0.95))
+                params["adamw_eps"] = cfg.hyperparameters.optimizer_params.pop("adamw_eps", 1e-8)
         else:
             raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')
+        params.update(cfg.hyperparameters.optimizer_params)
+        optimizer = optimizer_class(**params)
+
+        """
 
         if cfg.hyperparameters.optimizer_params is not None:
             muon_params = cfg.hyperparameters.optimizer_params.pop("muon", None)
             params.update(cfg.hyperparameters.optimizer_params)
@@ -164,6 +180,7 @@ def load_engines(training=True, **model_kwargs):
             ])
         else:
             optimizer = optimizer_class(**params)
+        """
 
         if cfg.hyperparameters.scheduler.lower() == "schedulefree":
             if cfg.hyperparameters.optimizer.lower() == "adamw":
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 75156f9..dead515 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -171,7 +171,7 @@ class AR_NAR(Base):
                     resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
 
             # only apply stop token for RVQ level 0
-            if (self.version < 7 and quant_level <= 0 and timesteps[i] is None) or (self.version >= 7 and timesteps[i] is None):
+            if (self.version < 7 and quant_level <= 0 and timesteps[i] is None) or (self.version >= 7):
                 # append stop tokens for AR
                 if task not in text_task:
                     resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
@@ -1434,7 +1434,9 @@ def example_usage():
     scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
     learning_rate = cfg.hyperparameters.learning_rate if cfg.yaml_path is not None else None
 
-    params = model.parameters()
+    params = {
+        "params": model.parameters()
+    }
     if cfg.optimizations.dadaptation:
         # do not combine the two
         if scheduler == "schedulefree":
@@ -1467,23 +1469,25 @@
         learning_rate = 0.01
 
         optimizer = ml.Apollo
-        params = [{'params': params, 'rank': 1, 'proj': 'random', 'scale_type': 'tensor', 'scale': 128,'update_proj_gap': 200, 'proj_type': 'std'}]
+        params["params"] = [{'params': params, 'rank': 1, 'proj': 'random', 'scale_type': 'tensor', 'scale': 128,'update_proj_gap': 200, 'proj_type': 'std'}]
+    elif optimizer == "muon":
+        del params["params"]
+        optimizer = ml.Muon
+
+        params["muon_params"] = [ param for name, param in model.model.named_parameters() if param.ndim >= 2 ]
+        params["adamw_params"] = [ param for name, param in model.model.named_parameters() if param.ndim < 2 ]
+        params["adamw_params"] += [ param for name, param in model.named_parameters() if not name.startswith('model.') ]
+
+        if cfg.hyperparameters.optimizer_params is not None:
+            params["adamw_betas"] = cfg.hyperparameters.optimizer_params.pop("adamw_betas", (0.95, 0.95))
+            params["adamw_eps"] = cfg.hyperparameters.optimizer_params.pop("adamw_eps", 1e-8)
     else:
         raise ValueError(f"Unrecognized optimizer: {optimizer}")
 
     _logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")
-
-    muon_params = cfg.hyperparameters.optimizer_params.pop("muon", None) if cfg.hyperparameters.optimizer_params is not None else None
-    if muon_params is not None:
-        muon_params["params"] = [ param for name, param in model.model.named_parameters() if param.ndim >= 2 ]
-        adam_params = [ param for name, param in model.model.named_parameters() if param.ndim < 2 ] + [ param for name, param in model.named_parameters() if not name.startswith('model.') ]
-
-        optimizer = ml.Optimizers([
-            ml.Muon(**muon_params),
-            optimizer(adam_params, lr=learning_rate)
-        ])
-    else:
-        optimizer = optimizer(params, lr=learning_rate)
+
+    params["lr"] = learning_rate
+    optimizer = optimizer(**params)
 
     if scheduler == "schedulefree":
         if isinstance(optimizer, ml.AdamW):
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index ca1aa54..c680a3d 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -1503,7 +1503,7 @@ class Base(nn.Module):
                 return None, None
 
             # shift if causal
-            if causal:
+            if causal or self.version >= 7:
                 l = self.causal_size
                 logit = logit[..., :-l, :] # shift the target so that token n...
                 sequence = sequence[..., l:] # ...predicts token n + 1
diff --git a/vall_e/utils/ext/muon.py b/vall_e/utils/ext/muon.py
new file mode 100644
index 0000000..cee740b
--- /dev/null
+++ b/vall_e/utils/ext/muon.py
@@ -0,0 +1,200 @@
+# From https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
+# because it combines both param types and makes life easier with DeepSpeed
+
+import os
+import math
+import torch
+import torch.distributed as dist
+
+@torch.compile
+def zeropower_via_newtonschulz5(G, steps):
+    """
+    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
+    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
+    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
+    zero even beyond the point where the iteration no longer converges all the way to one everywhere
+    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
+    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
+    performance at all relative to UV^T, where USV^T = G is the SVD.
+    """
+    assert len(G.shape) == 2
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G.bfloat16()
+    if G.size(0) > G.size(1):
+        X = X.T
+    # Ensure spectral norm is at most 1
+    X = X / (X.norm() + 1e-7)
+    # Perform the NS iterations
+    for _ in range(steps):
+        A = X @ X.T
+        B = (
+            b * A + c * A @ A
+        )  # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
+        X = a * X + B @ X
+
+    if G.size(0) > G.size(1):
+        X = X.T
+    return X
+
+
+class Muon(torch.optim.Optimizer):
+    """
+    Muon - MomentUm Orthogonalized by Newton-schulz
+
+    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
+    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
+    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
+    the advantage that it can be stably run in bfloat16 on the GPU.
+
+    Some warnings:
+    - We believe this optimizer is unlikely to work well for training with small batch size.
+    - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
+
+    Arguments:
+        muon_params: The parameters to be optimized by Muon.
+        lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
+        momentum: The momentum used by the internal SGD. (0.95 is a good default)
+        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
+        ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
+        adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are
+        {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
+        adamw_lr: The learning rate for the internal AdamW.
+        adamw_betas: The betas for the internal AdamW.
+        adamw_eps: The epsilon for the internal AdamW.
+        adamw_wd: The weight decay for the internal AdamW.
+    """
+
+    def __init__(
+        self,
+        lr=1e-3,
+        wd=0.1,
+        muon_params=None,
+        momentum=0.95,
+        nesterov=True,
+        ns_steps=5,
+        adamw_params=None,
+        adamw_betas=(0.95, 0.95),
+        adamw_eps=1e-8,
+    ):
+
+        defaults = dict(
+            lr=lr,
+            wd=wd,
+            momentum=momentum,
+            nesterov=nesterov,
+            ns_steps=ns_steps,
+            adamw_betas=adamw_betas,
+            adamw_eps=adamw_eps,
+        )
+
+        params = list(muon_params)
+        adamw_params = list(adamw_params) if adamw_params is not None else []
+        params.extend(adamw_params)
+        super().__init__(params, defaults)
+        # Sort parameters into those for which we will use Muon, and those for which we will not
+        for p in muon_params:
+            # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer
+            assert p.ndim == 2, p.ndim
+            self.state[p]["use_muon"] = True
+        for p in adamw_params:
+            # Do not use Muon for parameters in adamw_params
+            self.state[p]["use_muon"] = False
+
+    def adjust_lr_for_muon(self, lr, param_shape):
+        A, B = param_shape[:2]
+        # We adjust the learning rate and weight decay based on the size of the parameter matrix
+        # as describted in the paper
+        adjusted_ratio = 0.2 * math.sqrt(max(A, B))
+        adjusted_lr = lr * adjusted_ratio
+        return adjusted_lr
+
+    def step(self, closure=None):
+        """Perform a single optimization step.
+
+        Args:
+            closure (Callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+
+            ############################
+            #           Muon           #
+            ############################
+
+            params = [p for p in group["params"] if self.state[p]["use_muon"]]
+            # import pdb; pdb.set_trace()
+            lr = group["lr"]
+            wd = group["wd"]
+            momentum = group["momentum"]
+
+            # generate weight updates in distributed fashion
+            for p in params:
+                # sanity check
+                g = p.grad
+                if g is None:
+                    continue
+                if g.ndim > 2:
+                    g = g.view(g.size(0), -1)
+                assert g is not None
+
+                # calc update
+                state = self.state[p]
+                if "momentum_buffer" not in state:
+                    state["momentum_buffer"] = torch.zeros_like(g)
+                buf = state["momentum_buffer"]
+                buf.mul_(momentum).add_(g)
+                if group["nesterov"]:
+                    g = g.add(buf, alpha=momentum)
+                else:
+                    g = buf
+                u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
+
+                # scale update
+                adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
+
+                # apply weight decay
+                p.data.mul_(1 - lr * wd)
+
+                # apply update
+                p.data.add_(u, alpha=-adjusted_lr)
+
+            ############################
+            #       AdamW backup       #
+            ############################
+
+            params = [p for p in group["params"] if not self.state[p]["use_muon"]]
+            lr = group['lr']
+            beta1, beta2 = group["adamw_betas"]
+            eps = group["adamw_eps"]
+            weight_decay = group["wd"]
+
+            for p in params:
+                g = p.grad
+                if g is None:
+                    continue
+                state = self.state[p]
+                if "step" not in state:
+                    state["step"] = 0
+                    state["moment1"] = torch.zeros_like(g)
+                    state["moment2"] = torch.zeros_like(g)
+                state["step"] += 1
+                step = state["step"]
+                buf1 = state["moment1"]
+                buf2 = state["moment2"]
+                buf1.lerp_(g, 1 - beta1)
+                buf2.lerp_(g.square(), 1 - beta2)
+
+                g = buf1 / (eps + buf2.sqrt())
+
+                bias_correction1 = 1 - beta1**step
+                bias_correction2 = 1 - beta2**step
+                scale = bias_correction1 / bias_correction2**0.5
+                p.data.mul_(1 - lr * weight_decay)
+                p.data.add_(g, alpha=-lr / scale)
+
+        return loss
\ No newline at end of file
diff --git a/vall_e/utils/ml.py b/vall_e/utils/ml.py
index fa4502f..1b05feb 100755
--- a/vall_e/utils/ml.py
+++ b/vall_e/utils/ml.py
@@ -130,7 +130,7 @@ except Exception as e:
     pass
 
 try:
-    from muon import Muon as Muon
+    from .ext.muon import Muon
 except Exception as e:
     _logger.warning(f'Error while importing Muon: {str(e)}')
     pass