From 3da1518ace29879d494d94adfeef8284fb1309f5 Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 31 Jan 2024 21:48:36 -0600
Subject: [PATCH] added Mistral (non-Mixtral) backend, useless optimization
 when not training, proper adjustment of the LR for Prodigyopt through d_coef
 (maybe), recurrent sampling for LLaMA/Mistral/Mixtral backends (again,
 doesn't actually work)

---
 vall_e/engines/__init__.py  |  7 +++-
 vall_e/engines/base.py      |  8 +++-
 vall_e/engines/deepspeed.py |  2 +-
 vall_e/inference.py         |  4 +-
 vall_e/models/__init__.py   |  7 ++--
 vall_e/models/ar_nar.py     | 25 +++++++++----
 vall_e/models/base.py       | 75 +++++++++++++++++++++++++++++++------
 7 files changed, 98 insertions(+), 30 deletions(-)

diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index d105b02..940acf3 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -25,8 +25,8 @@ except Exception as e:
 from functools import cache
 
 @cache
-def load_engines():
-	models = get_models(cfg.models.get())
+def load_engines(training=True):
+	models = get_models(cfg.models.get(), training=training)
 	engines = dict()
 
 	for name, model in models.items():
@@ -59,6 +59,9 @@ def load_engines():
 			optimizer = ml.SGD
 		elif cfg.hyperparameters.optimizer.lower() == "prodigy":
 			optimizer_class = ml.Prodigy
+
+			params['d_coef'] = params['lr']
+			params['lr'] = 1.0
 		else:
 			raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')
 
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index d51a273..70a97ec 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -193,13 +193,17 @@ class Engine():
 	def get_lr(self):
 		lrs = []
 		for param_group in self.optimizer.param_groups:
-			if 'lr' in param_group:
+			if 'd_coef' in param_group:
+				lrs.append(param_group['d_coef'])
+			elif 'lr' in param_group:
 				lrs.append(param_group['lr'])
 		return lrs
 
 	def set_lr(self, lr):
 		for param_group in self.optimizer.param_groups:
-			if 'lr' in param_group:
+			if 'd_coef' in param_group:
+				param_group['d_coef'] = lr
+			elif 'lr' in param_group:
 				param_group['lr'] = lr
 
 	def get_global_grad_norm(self):
diff --git a/vall_e/engines/deepspeed.py b/vall_e/engines/deepspeed.py
index a00aa86..59acdea 100755
--- a/vall_e/engines/deepspeed.py
+++ b/vall_e/engines/deepspeed.py
@@ -99,7 +99,7 @@ class Engine(DeepSpeedEngine):
 		try:
 			if hasattr(self.optimizer, 'param_groups'):
 				for param_group in self.optimizer.param_groups:
-					param_group['lr'] = lr
+					param_group["d_coef" if "d_coef" in param_group else "lr"] = lr
 			else:
 				self.optimizer.set_lr(lr)
 		except Exception as e:
diff --git a/vall_e/inference.py b/vall_e/inference.py
index c45ea4e..0b72024 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -73,7 +73,7 @@ class TTS():
 
 			self.ar_ckpt = ar_ckpt
 			self.nar_ckpt = nar_ckpt
 
-			models = get_models(cfg.models.get())
+			models = get_models(cfg.models.get(), training=False)
 			for name, model in models.items():
 				if name.startswith("ar"):
@@ -101,7 +101,7 @@ class TTS():
 		self.loading = False
 
 	def load_models( self ):
-		engines = load_engines()
+		engines = load_engines(training=False)
 		for name, engine in engines.items():
 			if name.startswith("ar"):
 				self.ar = engine.module
diff --git a/vall_e/models/__init__.py b/vall_e/models/__init__.py
index 1fadf2c..8874d0b 100755
--- a/vall_e/models/__init__.py
+++ b/vall_e/models/__init__.py
@@ -2,7 +2,7 @@ from .ar import AR
 from .nar import NAR
 from .ar_nar import AR_NAR
 
-def get_model(cfg):
+def get_model(cfg, training=True):
 	if cfg.name == "ar":
 		Model = AR
 	elif cfg.name == "nar":
@@ -20,6 +20,7 @@ def get_model(cfg):
 		n_layers=cfg.layers,
 		n_experts=cfg.experts,
+		training=training,
 		config = cfg,
 	)
 
 	model._cfg = cfg
@@ -28,5 +29,5 @@ def get_model(cfg):
 	return model
 
 
-def get_models(models):
-	return { model.full_name: get_model(model) for model in models }
+def get_models(models, training=True):
+	return { model.full_name: get_model(model, training=training) for model in models }
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 05808e5..7b7ee9a 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -120,7 +120,7 @@ class AR_NAR(Base):
 
 		if n_levels == self.n_resp_levels:
 			# might be better to have this decided on the dataloader level
-			if cfg.experimental:
+			if cfg.experimental and False:
 				# makes higher levels less likely
 				def generate( lo=0, hi=8 ):
 					index = lo
@@ -228,13 +228,22 @@ class AR_NAR(Base):
 			else:
 				resps_list = self._unsqueeze_list(sequence_list)
 
-			logits = super().forward(
-				text_list=text_list,
-				proms_list=proms_list,
-				resps_list=resps_list,
-				lang_list=lang_list,
-				state=recurrent_state
-			)
+			if recurrent_state is not None:
+				logits, recurrent_state = super().forward(
+					text_list=text_list,
+					proms_list=proms_list,
+					resps_list=resps_list,
+					lang_list=lang_list,
+					state=recurrent_state
+				)
+			else:
+				logits = super().forward(
+					text_list=text_list,
+					proms_list=proms_list,
+					resps_list=resps_list,
+					lang_list=lang_list,
+					state=recurrent_state
+				)
 
 			r = super().sample(
 				logits=logits,
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index e872687..220751c 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -34,6 +34,12 @@ except Exception as e:
 	print("Error importing `llama` arch:", e)
 	pass
 
+try:
+	from transformers import MistralModel, MistralConfig
+except Exception as e:
+	print("Error importing `mistral` arch:", e)
+	pass
+
 try:
 	from transformers import MixtralModel, MixtralConfig
 	from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock
@@ -269,9 +275,11 @@ class Base(nn.Module):
 
 		n_experts: int=1,
 
+		training = True,
 		config = None,
 	):
 		super().__init__()
+		self.training = training
 		self.config = config
 		self.activation_checkpointing = self.config.activation_checkpointing if self.config is not None else True
 
@@ -312,21 +320,21 @@
 			self.blocks = nn.ModuleList([TransformerBlock(
 				d_model=d_model,
 				n_heads=n_heads,
-				p_dropout=p_dropout,
+				p_dropout=p_dropout if training else 0.0,
 				causal=self.causal,
 				norm_type=self.norm_type,
 				n_levels=self.n_resp_levels,
 			) for _ in range(n_layers) ])
-		elif self.arch_type == "llama":
+		elif self.arch_type == "mistral":
 			if n_experts <= 1:
-				self.model = LlamaModel(LlamaConfig(
+				self.model = MistralModel(MistralConfig(
 					vocab_size=n_resp_tokens,
 					hidden_size=d_model,
 					max_position_embeddings=75 * 60, # max-length of 60 seconds
 					intermediate_size=d_model*4,
 					num_hidden_layers=n_layers,
 					num_attention_heads=n_heads,
-					attention_dropout=p_dropout,
+					attention_dropout=p_dropout if training else 0.0,
 					num_key_value_heads=n_heads,
 					hidden_act="gelu",
 					is_encoder_decoder=False,
@@ -340,14 +348,51 @@
 					intermediate_size=d_model*4,
 					num_hidden_layers=n_layers,
 					num_attention_heads=n_heads,
-					attention_dropout=p_dropout,
+					attention_dropout=p_dropout if training else 0.0,
 					num_key_value_heads=n_heads,
+					sliding_window=75 * 12, # 12 second context window
+					output_router_logits=training,
 					hidden_act="gelu",
 					is_encoder_decoder=False,
 					is_decoder=True,
 					num_local_experts=n_experts,
 					num_experts_per_tok=min(2, n_experts),
 				))
+		elif self.arch_type == "llama":
+			if n_experts <= 1:
+				self.model = LlamaModel(LlamaConfig(
+					vocab_size=n_resp_tokens,
+					hidden_size=d_model,
+					max_position_embeddings=75 * 60, # max-length of 60 seconds
+					intermediate_size=d_model*4,
+					num_hidden_layers=n_layers,
+					num_attention_heads=n_heads,
+					attention_dropout=p_dropout if training else 0.0,
+					num_key_value_heads=n_heads,
+					sliding_window=75 * 12, # 12 second context window
+					hidden_act="gelu",
+					is_encoder_decoder=False,
+					is_decoder=True,
+				))
+			else:
+				self.model = MixtralModel(MixtralConfig(
+					vocab_size =n_resp_tokens,
+					hidden_size=d_model,
+					max_position_embeddings=75 * 60, # max-length of 60 seconds
+					intermediate_size=d_model*4,
+					num_hidden_layers=n_layers,
+					num_attention_heads=n_heads,
+					attention_dropout=p_dropout if training else 0.0,
+					num_key_value_heads=n_heads,
+					sliding_window=75 * 12, # 12 second context window
+					output_router_logits=training,
+					hidden_act="gelu",
+					is_encoder_decoder=False,
+					is_decoder=True,
+					num_local_experts=n_experts,
+					num_experts_per_tok=min(2, n_experts),
+				))
+
 		elif self.arch_type == "retnet":
 			kwargs = dict(
 				vocab_size=n_resp_tokens,
@@ -356,7 +401,7 @@
 				decoder_retention_heads=n_heads,
 				decoder_ffn_embed_dim=d_model * 4,
 				decoder_layers=n_layers,
-				dropout=p_dropout,
+				dropout=p_dropout if training else 0.0,
 				checkpoint_activations=self.activation_checkpointing,
 				activation_fn="gelu",
 				use_layernorm=True, # self.version < 3,
@@ -409,7 +454,7 @@
 		lang_list: list[Tensor] | None = None,
 
 		quant_levels: Tensor | None = None,
-		state: dict | None = None,
+		state: dict | list | None = None,
 	):
 		batch_size = len(text_list)
 
@@ -441,18 +486,25 @@
 			# grab last token(s)
 			x = x[:, -1, :].unsqueeze(1)
 		# HF transformer derived model
-		elif self.arch_type == "llama":
+		elif self.arch_type == "llama" or self.arch_type == "mistral":
 			kwargs = dict(
 				#attention_mask=m,
 				inputs_embeds=x,
+				past_key_values=state,
+				use_cache=state is not None,
+				# return_dict=True,
 			)
-			if self.n_experts > 1:
+			if self.n_experts > 1 and targ_list is not None:
 				kwargs["output_router_logits"] = True
 
 			t = self.model(**kwargs)
+
 			x = t[0]
 
-			if self.n_experts > 1:
+			if state is not None:
+				state = t[1]
+
+			if self.n_experts > 1 and targ_list is not None:
 				router_logits = t[-1]
 				aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok )
 		elif self.arch_type == "transformer":
@@ -477,7 +529,6 @@
 
 		# compute loss if the target is given
 		if targ_list is not None:
-
 			target_list = self._samplewise_merge_tensors(
 				text_list,
 				lang_list,
@@ -509,7 +560,7 @@
 			if aux_loss is not None:
 				self.loss["nll"] += aux_loss
 
-		return logits
+		return (logits, state) if state is not None else logits
 
 	def sample(
 		self,