diff --git a/vall_e.cpp/vall_e.cpp b/vall_e.cpp/vall_e.cpp
index 0c8dae8..ede485b 100644
--- a/vall_e.cpp/vall_e.cpp
+++ b/vall_e.cpp/vall_e.cpp
@@ -14,7 +14,7 @@
 #include
 #include
 
-#define LLAMA_CPP_EXTENDED 1 // whether the underlying llama.cpp has some extra functions
+#define LLAMA_CPP_EXTENDED 0 // whether the underlying llama.cpp has some extra functions
 #define LLAMA_CPP_USE_VALL_E_ARCH 0 // whether the underlying llama.cpp is to use the VALL_E arch
 
 #if !LLAMA_CPP_EXTENDED
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index b5c2b38..144f10c 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -109,6 +109,8 @@ def load_engines(training=True, **model_kwargs):
 	optimizer_class = None
 	scheduler_class = None
 
+	# model.config.frozen_params = ['sep', 'dropout_token', 'text_emb.weight', 'proms_emb.embeddings.0.weight', 'proms_emb.embeddings.1.weight', 'proms_emb.embeddings.2.weight', 'proms_emb.embeddings.3.weight', 'proms_emb.embeddings.4.weight', 'proms_emb.embeddings.5.weight', 'proms_emb.embeddings.6.weight', 'proms_emb.embeddings.7.weight', 'resps_emb.embeddings.0.weight', 'resps_emb.embeddings.1.weight', 'resps_emb.embeddings.2.weight', 'resps_emb.embeddings.3.weight', 'resps_emb.embeddings.4.weight', 'resps_emb.embeddings.5.weight', 'resps_emb.embeddings.6.weight', 'resps_emb.embeddings.7.weight', 'resps_emb.embeddings.8.weight', 'langs_emb.weight', 'tasks_emb.weight', 'tones_emb.weight', 'rvq_l_emb.weight', 'len_emb.weight', 'model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.0.mlp.gate_proj.weight', 'model.layers.0.mlp.up_proj.weight', 'model.layers.0.mlp.down_proj.weight', 'model.layers.0.input_layernorm.weight', 'model.layers.0.post_attention_layernorm.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.1.mlp.gate_proj.weight', 'model.layers.1.mlp.up_proj.weight', 'model.layers.1.mlp.down_proj.weight', 'model.layers.1.input_layernorm.weight', 'model.layers.1.post_attention_layernorm.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.2.self_attn.o_proj.weight', 'model.layers.2.mlp.gate_proj.weight', 'model.layers.2.mlp.up_proj.weight', 'model.layers.2.mlp.down_proj.weight', 'model.layers.2.input_layernorm.weight', 'model.layers.2.post_attention_layernorm.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.3.mlp.gate_proj.weight', 'model.layers.3.mlp.up_proj.weight', 'model.layers.3.mlp.down_proj.weight', 'model.layers.3.input_layernorm.weight', 'model.layers.3.post_attention_layernorm.weight', 'model.layers.4.self_attn.q_proj.weight', 'model.layers.4.self_attn.k_proj.weight', 'model.layers.4.self_attn.v_proj.weight', 'model.layers.4.self_attn.o_proj.weight', 'model.layers.4.mlp.gate_proj.weight', 'model.layers.4.mlp.up_proj.weight', 'model.layers.4.mlp.down_proj.weight', 'model.layers.4.input_layernorm.weight', 'model.layers.4.post_attention_layernorm.weight', 'model.layers.5.self_attn.q_proj.weight', 'model.layers.5.self_attn.k_proj.weight', 'model.layers.5.self_attn.v_proj.weight', 'model.layers.5.self_attn.o_proj.weight', 'model.layers.5.mlp.gate_proj.weight', 'model.layers.5.mlp.up_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.5.input_layernorm.weight', 'model.layers.5.post_attention_layernorm.weight', 'model.layers.6.self_attn.q_proj.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.6.self_attn.o_proj.weight', 'model.layers.6.mlp.gate_proj.weight', 'model.layers.6.mlp.up_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.6.input_layernorm.weight', 'model.layers.6.post_attention_layernorm.weight', 'model.layers.7.self_attn.q_proj.weight', 'model.layers.7.self_attn.k_proj.weight', 'model.layers.7.self_attn.v_proj.weight', 'model.layers.7.self_attn.o_proj.weight', 'model.layers.7.mlp.gate_proj.weight', 'model.layers.7.mlp.up_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.7.input_layernorm.weight', 'model.layers.7.post_attention_layernorm.weight', 'model.layers.8.self_attn.q_proj.weight', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.8.self_attn.o_proj.weight', 'model.layers.8.mlp.gate_proj.weight', 'model.layers.8.mlp.up_proj.weight', 'model.layers.8.mlp.down_proj.weight', 'model.layers.8.input_layernorm.weight', 'model.layers.8.post_attention_layernorm.weight', 'model.layers.9.self_attn.q_proj.weight', 'model.layers.9.self_attn.k_proj.weight', 'model.layers.9.self_attn.v_proj.weight', 'model.layers.9.self_attn.o_proj.weight', 'model.layers.9.mlp.gate_proj.weight', 'model.layers.9.mlp.up_proj.weight', 'model.layers.9.mlp.down_proj.weight', 'model.layers.9.input_layernorm.weight', 'model.layers.9.post_attention_layernorm.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.10.mlp.gate_proj.weight', 'model.layers.10.mlp.up_proj.weight', 'model.layers.10.mlp.down_proj.weight', 'model.layers.10.input_layernorm.weight', 'model.layers.10.post_attention_layernorm.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.11.mlp.gate_proj.weight', 'model.layers.11.mlp.up_proj.weight', 'model.layers.11.mlp.down_proj.weight', 'model.layers.11.input_layernorm.weight', 'model.layers.11.post_attention_layernorm.weight', 'model.norm.weight']
+
 	params = {
 		"params": [ param for name, param in model.named_parameters() if name not in model.config.frozen_params ],
 		"lr": cfg.hyperparameters.learning_rate,
@@ -235,8 +237,10 @@ def load_engines(training=True, **model_kwargs):
 			("rvq_l_emb.weight", model.config.resp_levels ),
 			("resps_emb.embeddings.0.weight", model.config.audio_tokens + uses_stop_token ),
 			("model.embed_tokens.weight", model.config.audio_tokens + uses_stop_token ),
-			("classifiers.proj.0.weight" if model.config.experimental.split_classifiers else 'classifier.weight', model.config.audio_tokens + uses_stop_token ),
-			("classifiers.proj.0.bias" if model.config.experimental.split_classifiers else 'classifier.bias', model.config.audio_tokens + uses_stop_token ),
+			("classifiers.proj.0.weight", model.config.audio_tokens + uses_stop_token ),
+			("classifiers.proj.0.bias", model.config.audio_tokens + uses_stop_token ),
+			("classifier.weight", model.n_vocab ),
+			("classifier.bias", model.n_vocab ),
 		]
 
 		# correcting an oversight
diff --git a/vall_e/export.py b/vall_e/export.py
index ab0d03c..e7fd3b0 100755
--- a/vall_e/export.py
+++ b/vall_e/export.py
@@ -105,7 +105,8 @@ def convert_to_hf_llama( state_dict, config = None, save_path = None ):
 	classifier = torch.nn.Linear( model_dim, n_tokens, bias=classifier_bias )
 
 	if not split_classifiers:
-		classifier.weight[:] = state_dict['module']['classifier.weight'][:]
+		src = state_dict['module']['classifier.weight'][:]
+		classifier.weight[:src.shape[0], ] = src
 
 	# update ranges
 	start = 0
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 42c3770..6fe527c 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -523,7 +523,7 @@ class Base(nn.Module):
 			classifier_l_tokens += [ 11 ]
 			classifier_l_names += ["len"]
 
-		n_vocab = 17701 if not split_classifiers else n_resp_tokens + 1
+		n_vocab = 17702 if not split_classifiers else n_resp_tokens + 1
 
 		self.n_vocab = n_vocab
 		self.unified_position_ids = unified_position_ids
@@ -1436,6 +1436,9 @@ class Base(nn.Module):
 					continue
 
 				# offset to flattened vocab ranges
+				if self.classifier is not None:
+					compute_acc = False
+				"""
 				if self.classifier is not None:
 					offsets = _get_offsets()
 
@@ -1454,6 +1457,7 @@
 						if t == self.ignore_index:
 							continue
 						token[i] += start
+				"""
 
 				if token.is_floating_point():
 					ignored = True
@@ -1511,7 +1515,7 @@
 
 			# perofrm loss calculation on the entire sequence
 			if not self.config.loss_factors:
-				target = _join( target, torch.tensor(self.ignore_index if self.classifier is None else 17685, device=target[-1].device) )
+				target = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
 				logit = logits[batch_index]
 
 				# shift if causal
@@ -1679,31 +1683,30 @@
 			for i, state in enumerate( hidden_states ):
 				hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]
 
+		# de-offset if needed
+		if self.classifier is not None:
+			offsets = _get_offsets()
+			for batch_index, classifier_level in enumerate( classifier_levels ):
+				if classifier_level == "stt":
+					k = "text"
+				elif classifier_level == "len":
+					k = "len"
+				else:
+					k = f'resps|{classifier_level}'
+
+				if k not in offsets:
+					continue
+
+				start, end = offsets[k]
+
+				logits[batch_index] = logits[batch_index][:, start:end]
+
 		if not training:
 			loss = None
 			stats = None
 
 			self.loss = None
 			self.stats = None
-
-			# de-offset if needed
-			if self.classifier is not None:
-				offsets = _get_offsets()
-				for batch_index, classifier_level in enumerate( classifier_levels ):
-					if classifier_level == "stt":
-						k = "text"
-					elif classifier_level == "len":
-						k = "len"
-					else:
-						k = f'resps|{classifier_level}'
-
-					if k not in offsets:
-						continue
-
-					start, end = offsets[k]
-
-					logits[batch_index] = logits[batch_index][:, start:start+end]
-
 		# compute loss if the target is given
 		else:
 			loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
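
A side note on the `convert_to_hf_llama` change above: when the rebuilt classifier ends up with more rows than the checkpointed `classifier.weight` (plausibly because this patch bumps the flattened `n_vocab` from 17701 to 17702), a full-tensor assignment no longer fits, so only the leading rows are copied and any extra rows keep their fresh initialization. The following is a minimal standalone sketch of that pattern only; the sizes and the `old_head`/`new_head` names are made up for the example and are not the model's real dimensions.

import torch

# Illustrative sizes: a smaller checkpointed head copied into a larger rebuilt one.
model_dim, old_n_tokens, new_n_tokens = 1024, 17701, 17702

old_head = torch.nn.Linear( model_dim, old_n_tokens, bias=True )  # stands in for the checkpointed classifier
new_head = torch.nn.Linear( model_dim, new_n_tokens, bias=True )  # rebuilt classifier with the larger vocab

with torch.no_grad():
	src = old_head.weight                      # shape: [old_n_tokens, model_dim]
	# copying only the rows that exist avoids the size-mismatch error that
	# `new_head.weight[:] = src` would raise; rows past src.shape[0] keep their init
	new_head.weight[:src.shape[0]] = src
	new_head.bias[:old_head.bias.shape[0]] = old_head.bias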