diff --git a/README.md b/README.md
index d423c09..bd2d96d 100755
--- a/README.md
+++ b/README.md
@@ -52,6 +52,7 @@ Training is very dependent on:
 * the quality of your dataset.
 * how much data you have.
 * the bandwidth you quantized your audio to.
+* the underlying model architecture used.
 
 ### Pre-Processed Dataset
 
diff --git a/setup.py b/setup.py
index 8bf22e6..2b48f40 100755
--- a/setup.py
+++ b/setup.py
@@ -48,6 +48,7 @@ setup(
 		"omegaconf==2.0.6",
 		"tqdm>=4.64.1",
 		"humanize>=4.4.0",
+		"transformers>4.36.0",
 
 		"pandas>=1.5.0",
 		"torch>=1.13.0",
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 7103409..eec9c34 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -132,10 +132,13 @@ class AR_NAR(Base):
 
 				quant_levels = torch.Tensor([ generate(0, self.n_resp_levels) for _ in range(batch_size) ]).to(dtype=torch.int16)
 			else:
+				quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
+				"""
 				if cfg.models.ar_nar.p_ar_level == "auto" or cfg.models.ar_nar.p_ar_level is None:
 					quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
 				else:
 					quant_levels = torch.Tensor([ 0 if random.random() < cfg.models.ar_nar.p_ar_level else random.randint(1, self.n_resp_levels) for _ in range(batch_size) ])
+				"""
 
 			targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)] # ensures we only have 1 RVQ-bin (our target)
 			resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)] # r[..., 0] is technically correct, but only r[:, 0] gets passed through the embedding
@@ -338,7 +341,7 @@ def example_usage():
 		'd_model': 256,
 		'n_heads': 4,
 		'n_layers': 12,
-		'n_experts': 1,
+		'n_experts': 8,
 	}
 
 	"""
@@ -349,7 +352,7 @@ def example_usage():
 	"""
 
 	model = AR_NAR(**kwargs).to(device)
-	steps = 250
+	steps = 500
 	optimizer = ml.Prodigy(model.parameters(), lr=1.0)
 	#optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
 	engine = Engine(model=model, optimizer=optimizer)
@@ -385,6 +388,10 @@ def example_usage():
 
 			tqdm.write(f"{stats}")
 
+	torch.save( {
+		'module': model.state_dict()
+	}, "./data/test.pth" )
+
 	sample("init", 5)
 	train()
 	sample("final")
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index f97c5fc..48b2c7b 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -14,10 +14,32 @@ from torch.nn.utils.rnn import pad_sequence
 from torch.utils.checkpoint import checkpoint
 from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
 
-from .retnet import RetNetDecoder, RetNetConfig
-from .transformer import SinusoidalEmbedding, Block as TransformerBlock
 from ..samplers import reptition_penalize, length_penalize, top_k_top_p_filtering, dynamic_temperature, top_k_logits_list, mirostat_sample
 
+try:
+	from .transformer import SinusoidalEmbedding, Block as TransformerBlock
+except Exception as e:
+	print("Error importing `transformer` arch:", e)
+	pass
+
+try:
+	from .retnet import RetNetDecoder, RetNetConfig
+except Exception as e:
+	print("Error importing `retnet` arch:", e)
+	pass
+
+try:
+	from transformers import LlamaModel, LlamaConfig
+except Exception as e:
+	print("Error importing `llama` arch:", e)
+	pass
+
+try:
+	from transformers import MixtralModel, MixtralConfig
+	from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
+except Exception as e:
+	print("Error importing `mixtral` arch:", e)
+
 def _create_mask(l, device):
 	"""1 is valid region and 0 is invalid."""
 	seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
@@ -254,9 +276,40 @@ class Base(nn.Module):
 				norm_type=self.norm_type,
 				n_levels=self.n_resp_levels,
 			) for _ in range(n_layers) ])
+		elif self.arch_type == "llama":
+			if n_experts <= 1:
+				self.model = LlamaModel(LlamaConfig(
+					vocab_size=n_resp_tokens,
+					hidden_size=d_model,
+					max_position_embeddings=75 * 60, # max-length of 60 seconds
+					intermediate_size=d_model*4,
+					num_hidden_layers=n_layers,
+					num_attention_heads=n_heads,
+					attention_dropout=p_dropout,
+					num_key_value_heads=n_heads,
+					hidden_act="gelu",
+					is_encoder_decoder=False,
+					is_decoder=True,
+				))
+			else:
+				self.model = MixtralModel(MixtralConfig(
+					vocab_size =n_resp_tokens,
+					hidden_size=d_model,
+					max_position_embeddings=75 * 60, # max-length of 60 seconds
+					intermediate_size=d_model*4,
+					num_hidden_layers=n_layers,
+					num_attention_heads=n_heads,
+					attention_dropout=p_dropout,
+					num_key_value_heads=n_heads,
+					hidden_act="gelu",
+					is_encoder_decoder=False,
+					is_decoder=True,
+					num_local_experts=n_experts,
+					num_experts_per_tok=min(2, n_experts),
+				))
 		elif self.arch_type == "retnet":
-			self.retnet = RetNetDecoder(RetNetConfig(
-				vocab_size=n_tokens,
+			self.model = RetNetDecoder(RetNetConfig(
+				vocab_size=n_resp_tokens,
 				decoder_embed_dim=d_model,
 				decoder_value_embed_dim =d_model * 2,
 				decoder_retention_heads=n_heads,
@@ -278,8 +331,9 @@ class Base(nn.Module):
 
 				# MoE
 				use_xmoe=n_experts>1,
-				moe_freq=2,
+				moe_freq=1,
 				moe_expert_count=n_experts,
+				moe_gating_use_fp32=False,
 			))
 
 		self.classifier = nn.Linear(d_model, n_resp_tokens)
@@ -326,6 +380,7 @@ class Base(nn.Module):
 		)
 
 		x, m = list_to_tensor(x_list)
+		aux_loss = None
 
 		device = x.device
 
@@ -336,12 +391,26 @@ class Base(nn.Module):
 				# run the initial prompt to fill the KV cache
 				for n in range(prefill_size):
 					xi = x[:, n, :].unsqueeze(1)
-					self.retnet(xi, incremental_state=state, token_embeddings=xi, features_only=True)
+					self.model(xi, incremental_state=state, token_embeddings=xi, features_only=True)
 
 				# grab last token(s)
 				x = x[:, -1, :].unsqueeze(1)
+		# HF transformer derived model
+		elif self.arch_type == "llama":
+			kwargs = dict(
+				#attention_mask=m,
+				inputs_embeds=x,
+			)
+			if self.n_experts > 1:
+				kwargs["output_router_logits"] = True
 
-		if self.arch_type == "transformer":
+			t = self.model(**kwargs)
+			x = t[0]
+
+			if self.n_experts > 1:
+				router_logits = t[-1]
+				aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok )
+		elif self.arch_type == "transformer":
 			# ensures we specify a quant_level for the transformer implementation's AdaLN
 			l = torch.zeros((batch_size,), dtype=torch.int32) if quant_levels is None else quant_levels
 			l = l.to(device)
@@ -352,8 +421,9 @@ class Base(nn.Module):
 				x = block(x, m, l)
 		elif self.arch_type == "retnet":
 			# pass our inputs through the RetNet
-			x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)
-
+			x, _ = self.model(x, incremental_state=state, token_embeddings=x, features_only=True)
+			if _ is not None and "l_aux" in _:
+				aux_loss = torch.sum(torch.stack([ t for t in _["l_aux"] if t is not None])) * 0.001
 		# output projection layer with masking
 		x = self.classifier(x) * m
 
@@ -390,6 +460,9 @@ class Base(nn.Module):
 				acc = self.accuracy_metric( inputs, target ),
 				precision = self.precision_metric( inputs, target ),
 			)
+
+			if aux_loss is not None:
+				self.loss["nll"] += aux_loss
 
 		return logits
 