diff --git a/docs/train.md b/docs/train.md
index df562cc..c033f52 100644
--- a/docs/train.md
+++ b/docs/train.md
@@ -19,10 +19,9 @@ A training paradigm that works for me is:
 * additional training for sampling per speaker, to better help diversify how well it can perform for a range of speakers, rather than just speaking itself
   * I don't think this is crucial, but speaker-based sampling seems to be a placebo if anything.
 
-Training under `float16` should be fairly simple, but care is required to keep the loss scaling factor above 8K, and probably even 16K.
-* At the very least for pre-trained models, low enough loss scales will irreparably fry the model, and no amount of training afterwards seems to "fix" it.
-* The current DeepSpeed configuration should keep the loss scale capped to 32K; normal training does not seem to have the loss scale ever want to dip below this at least.
-* Training under `bfloat16` does not have to worry about this as there's no need for loss scaling, but I feel the model performs better when trained under `float16`+AMP rather than `bfloat16` (with or without AMP).
+Training under `float16` (+AMP) should be fairly simple, but it's practically required to use the `deepspeed` backend.
+* This is because `deepspeed` will automatically wrap the optimizer to handle training under `float16`, while the `local` backend does not do this. Training will *not* converge.
+* Training under `bfloat16` does not have to worry about this.
 
 When training from scratch, maybe 30% of the time spent training is getting coherent speech, with a loose following of the prompt. The remaining bulk of the work is getting the model to closely-er resemble the input prompt.
 * an accuracy of at least 50% seems to be where coherent speech emerges.
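The `float16` note above comes down to loss scaling: the `deepspeed` backend wraps the optimizer so gradients are scaled up before `backward()` and un-scaled (with overflow checks) before the step, which the `local` backend skips. As a rough illustration only — this is stock PyTorch AMP, not the project's actual backend code, and `model`, `batch`, `loss_fn` are placeholders:

```python
# Hedged sketch of what "wrapping the optimizer" for float16 training means,
# using plain PyTorch AMP. Names here are placeholders, not vall_e's trainer API.
import torch

scaler = torch.cuda.amp.GradScaler()  # maintains a dynamic loss scale

def train_step(model, optimizer, batch, loss_fn):
	optimizer.zero_grad()
	with torch.autocast(device_type="cuda", dtype=torch.float16):
		loss = loss_fn(model(batch))
	scaler.scale(loss).backward()  # scale up so float16 gradients don't underflow to zero
	scaler.step(optimizer)         # un-scales gradients, skips the step on inf/NaN
	scaler.update()                # grows/shrinks the loss scale over time
	return loss.detach()
```

Without this (or DeepSpeed's equivalent), small `float16` gradients underflow and the run never converges, which is the failure mode the documentation change above describes.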
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index d745864..955f50a 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -146,42 +146,22 @@ def load_engines(training=True, **model_kwargs):
 		elif cfg.hyperparameters.optimizer.lower() == "adagrad":
 			optimizer_class = ml.Adagrad
 		elif cfg.hyperparameters.optimizer.lower() == "muon":
-			del params["params"]
-			optimizer_class = ml.Muon
-
+			optimizer = ml.Muon
-			params["muon_params"] = [ param for name, param in model.model.named_parameters() if param.ndim >= 2 and f'model.{name}' not in model.config.frozen_params ]
-			params["adamw_params"] = [ param for name, param in model.model.named_parameters() if param.ndim < 2 and f'model.{name}' not in model.config.frozen_params ]
-			params["adamw_params"] += [ param for name, param in model.named_parameters() if not name.startswith('model.') and name not in model.config.frozen_params ]
+			muon_params = [ param for name, param in model.model.named_parameters() if param.ndim >= 2 ]
+			adamw_params = [ param for name, param in model.model.named_parameters() if param.ndim < 2 ]
+			adamw_params += [ param for name, param in model.named_parameters() if not name.startswith('model.') ]
-			if cfg.hyperparameters.optimizer_params is not None:
-				params["adamw_betas"] = cfg.hyperparameters.optimizer_params.pop("adamw_betas", (0.95, 0.95))
-				params["adamw_eps"] = cfg.hyperparameters.optimizer_params.pop("adamw_eps", 1e-8)
+			params["params"] = [
+				{ "params": muon_params, "muon": True },
+				{ "params": adamw_params, "muon": False, "betas": (0.95, 0.95), "eps": 1e-8 },
+			]
 		else:
 			raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')
 
 		params.update(cfg.hyperparameters.optimizer_params)
 		optimizer = optimizer_class(**params)
-		"""
-		if cfg.hyperparameters.optimizer_params is not None:
-			muon_params = cfg.hyperparameters.optimizer_params.pop("muon", None)
-			params.update(cfg.hyperparameters.optimizer_params)
-
-			if muon_params is not None:
-				muon_params["params"] = [ param for name, param in model.model.named_parameters() if param.ndim >= 2 and f'model.{name}' not in model.config.frozen_params ]
-
-				params["params"] = [ param for name, param in model.model.named_parameters() if param.ndim < 2 and f'model.{name}' not in model.config.frozen_params ]
-				params["params"] += [ param for name, param in model.named_parameters() if not name.startswith('model.') and name not in model.config.frozen_params ]
-
-				optimizer = ml.Optimizers([
-					ml.Muon(**muon_params),
-					optimizer_class(**params),
-				])
-			else:
-				optimizer = optimizer_class(**params)
-		"""
-
 		if cfg.hyperparameters.scheduler.lower() == "schedulefree":
 			if cfg.hyperparameters.optimizer.lower() == "adamw":
 				scheduler_class = ml.schedulefree.AdamWScheduleFree
@@ -233,81 +213,9 @@ def load_engines(training=True, **model_kwargs):
 			for k in erase:
 				del state[k]
 
-			# converts an AR+NAR model into an AR+NAR-len model
-			"""
-			if True:
-				# move STT one over
-				state['classifiers.proj.9.weight'] = state['classifiers.proj.8.weight'].clone()
-				state['classifiers.proj.9.bias'] = state['classifiers.proj.8.bias'].clone()
-				# copy from AR:0:0 classifier
-				if True:
-					state['classifiers.proj.8.weight'] = state['classifiers.proj.0.weight'].clone()
-					state['classifiers.proj.8.bias'] = state['classifiers.proj.0.bias'].clone()
-					# copy from AR:0:0 embeddings
-					state['resps_emb.embeddings.8.weight'] = state['resps_emb.embeddings.0.weight'].clone()
-				# remove
-				else:
-					if 'classifiers.proj.8.weight' in state:
-						del state['classifiers.proj.8.weight']
-					if 'classifiers.proj.8.bias' in state:
-						del state['classifiers.proj.8.bias']
-					if 'resps_emb.embeddings.8.weight' in state:
-						del state['resps_emb.embeddings.8.weight']
-			"""
-
 			# resize modules if I'm doing experiments and can't be assed to manually trim things
 			if cfg.trainer.resize_modules:
-				uses_stop_token = 1 if ("ar" in model.capabilities or "len" in model.capabilities) > 0 else 0
-				keys = [
-					("text_emb.weight", model.config.text_tokens ),
-					("tasks_emb.weight", model.config.tasks ),
-					("langs_emb.weight", model.config.langs ),
-					("rvq_l_emb.weight", model.config.resp_levels ),
-					("resps_emb.embeddings.0.weight", model.config.audio_tokens + uses_stop_token ),
-					("model.embed_tokens.weight", model.config.audio_tokens + uses_stop_token ),
-					("classifiers.proj.0.weight", model.config.audio_tokens + uses_stop_token ),
-					("classifiers.proj.0.bias", model.config.audio_tokens + uses_stop_token ),
-					("classifier.weight", model.n_vocab ),
-					("classifier.bias", model.n_vocab ),
-				]
-
-				last_embedding_keys = {}
-
-				# correcting an oversight
-				"""
-				if model.config.experimental.split_classifiers and "len" in model.capabilities:
-					len_idx, nar_0_idx = model.classifiers.indices(["len", "NAR:0:0"])
-					keys.append((f"classifiers.proj.{len_idx}.weight", 11))
-					keys.append((f"classifiers.proj.{len_idx}.bias", 11))
-
-					keys.append((f"classifiers.proj.{nar_0_idx}.weight", model.config.audio_tokens))
-					keys.append((f"classifiers.proj.{nar_0_idx}.bias", model.config.audio_tokens))
-				"""
-
-				# correcting an oversight
-				"""
-				if True:
-					keys.append((f"classifiers.proj.0.weight", model.config.audio_tokens+1))
-					for i in range(1,9):
-						keys.append((f"classifiers.proj.{i}.weight", model.config.audio_tokens))
-
keys.append((f"resps_emb.embeddings.0.weight", model.config.audio_tokens+1)) - keys.append((f"resps_emb.embeddings.8.weight", model.config.audio_tokens+1)) - - for i in range(1,8): - keys.append((f"resps_emb.embeddings.{i}.weight", model.config.audio_tokens)) - - for i in range(8): - keys.append((f"proms_emb.embeddings.{i}.weight", model.config.audio_tokens)) - - last_embedding_keys = { - "classifiers.proj.0.weight": state["classifiers.proj.0.weight"][-1].clone().detach(), - "resps_emb.embeddings.0.weight": state["resps_emb.embeddings.0.weight"][-1].clone().detach(), - "resps_emb.embeddings.8.weight": state["resps_emb.embeddings.8.weight"][-1].clone().detach(), - } - """ - - + keys = [] for k, tokens in keys: if k not in state: continue @@ -316,50 +224,6 @@ def load_engines(training=True, **model_kwargs): for k, v in last_embedding_keys.items(): state[k][-1] = v - # stuff to inject new layers into an existing model train over (not recommended, it doesnt amount to anything) - """ - if True: - remapped_dict = {} - remapped_indices = [ - (0, 1), - (1, 2), - (2, 3), - (3, 5), - (4, 6), - (5, 7), - (6, 9), - (7, 10), - (8, 11), - (9, 13), - (10, 14), - (11, 15), - ] - - for src, dst in remapped_indices: - remapped_dict[f"model.layers.{dst}.input_layernorm.weight"] = state[f"model.layers.{src}.input_layernorm.weight"] - remapped_dict[f"model.layers.{dst}.self_attn.k_proj.weight"] = state[f"model.layers.{src}.self_attn.k_proj.weight"] - remapped_dict[f"model.layers.{dst}.self_attn.q_proj.weight"] = state[f"model.layers.{src}.self_attn.q_proj.weight"] - remapped_dict[f"model.layers.{dst}.self_attn.v_proj.weight"] = state[f"model.layers.{src}.self_attn.v_proj.weight"] - remapped_dict[f"model.layers.{dst}.self_attn.o_proj.weight"] = state[f"model.layers.{src}.self_attn.o_proj.weight"] - remapped_dict[f"model.layers.{dst}.post_attention_layernorm.weight"] = state[f"model.layers.{src}.post_attention_layernorm.weight"] - remapped_dict[f"model.layers.{dst}.mlp.down_proj.weight"] = state[f"model.layers.{src}.mlp.down_proj.weight"] - remapped_dict[f"model.layers.{dst}.mlp.gate_proj.weight"] = state[f"model.layers.{src}.mlp.gate_proj.weight"] - remapped_dict[f"model.layers.{dst}.mlp.up_proj.weight"] = state[f"model.layers.{src}.mlp.up_proj.weight"] - - del state[f"model.layers.{src}.input_layernorm.weight"] - del state[f"model.layers.{src}.self_attn.k_proj.weight"] - del state[f"model.layers.{src}.self_attn.q_proj.weight"] - del state[f"model.layers.{src}.self_attn.v_proj.weight"] - del state[f"model.layers.{src}.self_attn.o_proj.weight"] - del state[f"model.layers.{src}.post_attention_layernorm.weight"] - del state[f"model.layers.{src}.mlp.down_proj.weight"] - del state[f"model.layers.{src}.mlp.gate_proj.weight"] - del state[f"model.layers.{src}.mlp.up_proj.weight"] - - for k, v in remapped_dict.items(): - state[k] = v - """ - model.load_state_dict(state, strict=cfg.trainer.strict_loading) # load lora weights if exists diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 795d868..f6bc4d6 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -162,6 +162,7 @@ class AR_NAR(Base): quant_levels[i] = prom.shape[-1] - 1 # apply token dropout error compensation + """ if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]): steps = resps.shape[0] for l in range( quant_level ): @@ -171,6 +172,7 @@ class AR_NAR(Base): if random.random() < token_dropout_error: offset = 1 * ( 1 if random.random() < 0.5 else -1 ) 
 							resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
+			"""
 
 			# only apply stop token for RVQ level 0
 			if (self.version < 7 and quant_level <= 0 and timesteps[i] is None) or (self.version >= 7 and timesteps[i] is None) or (self.predict_causally):
@@ -1471,18 +1473,22 @@ def example_usage():
 		learning_rate = 0.01
 
 		optimizer = ml.Apollo
-		params["params"] = [{'params': params, 'rank': 1, 'proj': 'random', 'scale_type': 'tensor', 'scale': 128,'update_proj_gap': 200, 'proj_type': 'std'}]
+		params["params"] = [
+			{'params': params, 'rank': 1, 'proj': 'random', 'scale_type': 'tensor', 'scale': 128,'update_proj_gap': 200, 'proj_type': 'std'}
+		]
 	elif optimizer == "muon":
-		del params["params"]
 		optimizer = ml.Muon
-		params["muon_params"] = [ param for name, param in model.model.named_parameters() if param.ndim >= 2 ]
-		params["adamw_params"] = [ param for name, param in model.model.named_parameters() if param.ndim < 2 ]
-		params["adamw_params"] += [ param for name, param in model.named_parameters() if not name.startswith('model.') ]
+		muon_params = [ param for name, param in model.model.named_parameters() if param.ndim >= 2 ]
+		adamw_params = [ param for name, param in model.model.named_parameters() if param.ndim < 2 ]
+		adamw_params += [ param for name, param in model.named_parameters() if not name.startswith('model.') ]
-		if cfg.hyperparameters.optimizer_params is not None:
-			params["adamw_betas"] = cfg.hyperparameters.optimizer_params.pop("adamw_betas", (0.95, 0.95))
-			params["adamw_eps"] = cfg.hyperparameters.optimizer_params.pop("adamw_eps", 1e-8)
+		params["params"] = [
+			{ "params": muon_params, "muon": True },
+			{ "params": adamw_params, "muon": False, "betas": (0.95, 0.95), "eps": 1e-8 },
+		]
+	elif optimizer == "cosmos":
+		optimizer = ml.COSMOS
 	else:
 		raise ValueError(f"Unrecognized optimizer: {optimizer}")
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 7be5b00..750a10c 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -91,15 +91,8 @@ def _get_offsets():
 		"resps|NAR:0:0": (16677, 17702),
 	}
 
-def _dropout_mask( input, p=None ):
-	# cosine scheduling
-	if p is None:
-		t = random.random()
-		p = math.cos(t * math.pi * 0.5)
-
-	seq = [ random.random() < p for _ in range( input.shape[0] ) ]
-	mask = torch.tensor( seq, dtype=torch.bool, device=input.device )
-	return mask
+def _dropout_mask( input, p ):
+	return (torch.rand(input.shape[0], device=input.device) < p)
 
 def _create_mask(l, device):
 	"""1 is valid region and 0 is invalid."""
@@ -1383,13 +1376,14 @@ class Base(nn.Module):
 				)
 
 			# apply token dropout
+			"""
 			if token_dropout_rate > 0.0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
 				steps = embedding.shape[0] - (1 if quant_level == 0 else 0) # do not mess with stop token
 				for i in range( steps ):
 					if random.random() > token_dropout_rate:
 						continue
-					embedding[i] = self.dropout_token
+			"""
 		elif name == "timestep" and self.time_emb is not None:
 			embedding = self.time_emb( input )
 		elif name == "len" and self.len_emb is not None:
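A side note on the `_dropout_mask` rewrite in `vall_e/models/base.py` above: the cosine schedule the old helper applied when `p` was `None` is gone from the helper itself, so any caller that relied on it presumably has to compute `p` before calling. A minimal, self-contained sketch of that split — the `timestep_to_p` name is hypothetical, not from the codebase:

```python
# Sketch only: contrasts the old implicit cosine schedule with the new explicit-p helper.
import math
import random
import torch

def timestep_to_p(t: float) -> float:
	# the cosine schedule the old _dropout_mask applied internally when p was None
	return math.cos(t * math.pi * 0.5)

def dropout_mask(input: torch.Tensor, p: float) -> torch.Tensor:
	# new behavior: one vectorized draw against an explicit drop probability
	return torch.rand(input.shape[0], device=input.device) < p

tokens = torch.zeros(12)   # stand-in for a token sequence
t = random.random()        # timestep in [0, 1)
mask = dropout_mask(tokens, timestep_to_p(t))
print(mask.sum().item(), "of", tokens.shape[0], "positions masked")
```

Functionally both versions are the same per-position Bernoulli draw; the rewrite vectorizes it and leaves the choice of `p` to the caller.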
diff --git a/vall_e/utils/ext/muon.py b/vall_e/utils/ext/muon.py
index 0bb4a9f..83292e8 100644
--- a/vall_e/utils/ext/muon.py
+++ b/vall_e/utils/ext/muon.py
@@ -66,15 +66,14 @@ class Muon(torch.optim.Optimizer):
 
     def __init__(
         self,
+        params=None,
         lr=1e-3,
         wd=0.1,
-        muon_params=None,
         momentum=0.95,
         nesterov=True,
         ns_steps=5,
-        adamw_params=None,
-        adamw_betas=(0.95, 0.95),
-        adamw_eps=1e-8,
+        betas=(0.95, 0.95),
+        eps=1e-8,
     ):
 
         defaults = dict(
@@ -83,22 +82,12 @@
             momentum=momentum,
             nesterov=nesterov,
             ns_steps=ns_steps,
-            adamw_betas=adamw_betas,
-            adamw_eps=adamw_eps,
+            betas=betas,
+            eps=eps,
+            muon=False,
         )
 
-        params = list(muon_params)
-        adamw_params = list(adamw_params) if adamw_params is not None else []
-        params.extend(adamw_params)
         super().__init__(params, defaults)
-        # Sort parameters into those for which we will use Muon, and those for which we will not
-        for p in muon_params:
-            # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer
-            assert p.ndim == 2, p.ndim
-            self.state[p]["use_muon"] = True
-        for p in adamw_params:
-            # Do not use Muon for parameters in adamw_params
-            self.state[p]["use_muon"] = False
 
     def adjust_lr_for_muon(self, lr, param_shape):
         A, B = param_shape[:2]
@@ -125,80 +114,74 @@
             ############################
             #           Muon           #
             ############################
+            if group["muon"]:
+                # import pdb; pdb.set_trace()
+                lr = group["lr"]
+                wd = group["wd"]
+                momentum = group["momentum"]
-
-            # this actually doesn't work with deepspeed for the same reason APOLLO required modifications:
-            # deepspeed's BF16/F16 optimizer wrapper modifies the tensors, so self.state loses the right mapping
-            # can't be assed to figure it out right now since it's not easy to fix like APOLLO
-
-            params = [p for p in group["params"] if self.state[p]["use_muon"]]
-            # import pdb; pdb.set_trace()
-            lr = group["lr"]
-            wd = group["wd"]
-            momentum = group["momentum"]
+
+                # generate weight updates in distributed fashion
+                for p in group["params"]:
+                    # sanity check
+                    g = p.grad
+                    if g is None:
+                        continue
+                    if g.ndim > 2:
+                        g = g.view(g.size(0), -1)
+                    assert g is not None
-
-            # generate weight updates in distributed fashion
-            for p in params:
-                # sanity check
-                g = p.grad
-                if g is None:
-                    continue
-                if g.ndim > 2:
-                    g = g.view(g.size(0), -1)
-                assert g is not None
+
+                    # calc update
+                    state = self.state[p]
+                    if "momentum_buffer" not in state:
+                        state["momentum_buffer"] = torch.zeros_like(g)
+                    buf = state["momentum_buffer"]
+                    buf.mul_(momentum).add_(g)
+                    if group["nesterov"]:
+                        g = g.add(buf, alpha=momentum)
+                    else:
+                        g = buf
+                    u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
-
-                # calc update
-                state = self.state[p]
-                if "momentum_buffer" not in state:
-                    state["momentum_buffer"] = torch.zeros_like(g)
-                buf = state["momentum_buffer"]
-                buf.mul_(momentum).add_(g)
-                if group["nesterov"]:
-                    g = g.add(buf, alpha=momentum)
-                else:
-                    g = buf
-                u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
+
+                    # scale update
+                    adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
-
-                # scale update
-                adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
+
+                    # apply weight decay
+                    p.data.mul_(1 - lr * wd)
-
-                # apply weight decay
-                p.data.mul_(1 - lr * wd)
-
-                # apply update
-                p.data.add_(u, alpha=-adjusted_lr)
+
+                    # apply update
+                    p.data.add_(u, alpha=-adjusted_lr)
 
             ############################
             #       AdamW backup       #
             ############################
+            else:
+                lr = group['lr']
+                beta1, beta2 = group["betas"]
+                eps = group["eps"]
+                weight_decay = group["wd"]
-
-            params = [p for p in group["params"] if not self.state[p]["use_muon"]]
-            lr = group['lr']
-            beta1, beta2 = group["adamw_betas"]
-            eps = group["adamw_eps"]
-            weight_decay = group["wd"]
+
+                for p in group["params"]:
+                    g = p.grad
+                    if g is None:
+                        continue
+                    state = self.state[p]
+                    if "step" not in state:
+                        state["step"] = 0
+                        state["moment1"] = torch.zeros_like(g)
+                        state["moment2"] = torch.zeros_like(g)
+                    state["step"] += 1
+                    step = state["step"]
+                    buf1 = state["moment1"]
buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - if "step" not in state: - state["step"] = 0 - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - state["step"] += 1 - step = state["step"] - buf1 = state["moment1"] - buf2 = state["moment2"] - buf1.lerp_(g, 1 - beta1) - buf2.lerp_(g.square(), 1 - beta2) + g = buf1 / (eps + buf2.sqrt()) - g = buf1 / (eps + buf2.sqrt()) - - bias_correction1 = 1 - beta1**step - bias_correction2 = 1 - beta2**step - scale = bias_correction1 / bias_correction2**0.5 - p.data.mul_(1 - lr * weight_decay) - p.data.add_(g, alpha=-lr / scale) + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) return loss \ No newline at end of file