diff --git a/vall_e.cpp/vall_e.cpp b/vall_e.cpp/vall_e.cpp
index 0c8dae8..ede485b 100644
--- a/vall_e.cpp/vall_e.cpp
+++ b/vall_e.cpp/vall_e.cpp
@@ -14,7 +14,7 @@
 #include 
 #include 
 
-#define LLAMA_CPP_EXTENDED 1 // whether the underlying llama.cpp has some extra functions
+#define LLAMA_CPP_EXTENDED 0 // whether the underlying llama.cpp has some extra functions
 #define LLAMA_CPP_USE_VALL_E_ARCH 0 // whether the underlying llama.cpp is to use the VALL_E arch
 
 #if !LLAMA_CPP_EXTENDED
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index b5c2b38..a5bfd39 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -109,6 +109,8 @@ def load_engines(training=True, **model_kwargs):
 		optimizer_class = None
 		scheduler_class = None
 
+		# model.config.frozen_params = ['sep', 'dropout_token', 'text_emb.weight', 'proms_emb.embeddings.0.weight', 'proms_emb.embeddings.1.weight', 'proms_emb.embeddings.2.weight', 'proms_emb.embeddings.3.weight', 'proms_emb.embeddings.4.weight', 'proms_emb.embeddings.5.weight', 'proms_emb.embeddings.6.weight', 'proms_emb.embeddings.7.weight', 'resps_emb.embeddings.0.weight', 'resps_emb.embeddings.1.weight', 'resps_emb.embeddings.2.weight', 'resps_emb.embeddings.3.weight', 'resps_emb.embeddings.4.weight', 'resps_emb.embeddings.5.weight', 'resps_emb.embeddings.6.weight', 'resps_emb.embeddings.7.weight', 'resps_emb.embeddings.8.weight', 'langs_emb.weight', 'tasks_emb.weight', 'tones_emb.weight', 'rvq_l_emb.weight', 'len_emb.weight', 'model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.0.mlp.gate_proj.weight', 'model.layers.0.mlp.up_proj.weight', 'model.layers.0.mlp.down_proj.weight', 'model.layers.0.input_layernorm.weight', 'model.layers.0.post_attention_layernorm.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.1.mlp.gate_proj.weight', 'model.layers.1.mlp.up_proj.weight', 'model.layers.1.mlp.down_proj.weight', 'model.layers.1.input_layernorm.weight', 'model.layers.1.post_attention_layernorm.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.2.self_attn.o_proj.weight', 'model.layers.2.mlp.gate_proj.weight', 'model.layers.2.mlp.up_proj.weight', 'model.layers.2.mlp.down_proj.weight', 'model.layers.2.input_layernorm.weight', 'model.layers.2.post_attention_layernorm.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.3.mlp.gate_proj.weight', 'model.layers.3.mlp.up_proj.weight', 'model.layers.3.mlp.down_proj.weight', 'model.layers.3.input_layernorm.weight', 'model.layers.3.post_attention_layernorm.weight', 'model.layers.4.self_attn.q_proj.weight', 'model.layers.4.self_attn.k_proj.weight', 'model.layers.4.self_attn.v_proj.weight', 'model.layers.4.self_attn.o_proj.weight', 'model.layers.4.mlp.gate_proj.weight', 'model.layers.4.mlp.up_proj.weight', 'model.layers.4.mlp.down_proj.weight', 'model.layers.4.input_layernorm.weight', 'model.layers.4.post_attention_layernorm.weight', 'model.layers.5.self_attn.q_proj.weight', 'model.layers.5.self_attn.k_proj.weight', 'model.layers.5.self_attn.v_proj.weight', 'model.layers.5.self_attn.o_proj.weight', 'model.layers.5.mlp.gate_proj.weight', 'model.layers.5.mlp.up_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.5.input_layernorm.weight', 'model.layers.5.post_attention_layernorm.weight', 'model.layers.6.self_attn.q_proj.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.6.self_attn.o_proj.weight', 'model.layers.6.mlp.gate_proj.weight', 'model.layers.6.mlp.up_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.6.input_layernorm.weight', 'model.layers.6.post_attention_layernorm.weight', 'model.layers.7.self_attn.q_proj.weight', 'model.layers.7.self_attn.k_proj.weight', 'model.layers.7.self_attn.v_proj.weight', 'model.layers.7.self_attn.o_proj.weight', 'model.layers.7.mlp.gate_proj.weight', 'model.layers.7.mlp.up_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.7.input_layernorm.weight', 'model.layers.7.post_attention_layernorm.weight', 'model.layers.8.self_attn.q_proj.weight', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.8.self_attn.o_proj.weight', 'model.layers.8.mlp.gate_proj.weight', 'model.layers.8.mlp.up_proj.weight', 'model.layers.8.mlp.down_proj.weight', 'model.layers.8.input_layernorm.weight', 'model.layers.8.post_attention_layernorm.weight', 'model.layers.9.self_attn.q_proj.weight', 'model.layers.9.self_attn.k_proj.weight', 'model.layers.9.self_attn.v_proj.weight', 'model.layers.9.self_attn.o_proj.weight', 'model.layers.9.mlp.gate_proj.weight', 'model.layers.9.mlp.up_proj.weight', 'model.layers.9.mlp.down_proj.weight', 'model.layers.9.input_layernorm.weight', 'model.layers.9.post_attention_layernorm.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.10.mlp.gate_proj.weight', 'model.layers.10.mlp.up_proj.weight', 'model.layers.10.mlp.down_proj.weight', 'model.layers.10.input_layernorm.weight', 'model.layers.10.post_attention_layernorm.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.11.mlp.gate_proj.weight', 'model.layers.11.mlp.up_proj.weight', 'model.layers.11.mlp.down_proj.weight', 'model.layers.11.input_layernorm.weight', 'model.layers.11.post_attention_layernorm.weight', 'model.norm.weight']
+
 		params = {
 			"params": [ param for name, param in model.named_parameters() if name not in model.config.frozen_params ],
 			"lr": cfg.hyperparameters.learning_rate,
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 42c3770..e2a3b42 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -1436,6 +1436,9 @@ class Base(nn.Module):
 					continue
 
 			# offset to flattened vocab ranges
+			if self.classifier is not None:
+				compute_acc = False
+			"""
 			if self.classifier is not None:
 				offsets = _get_offsets()
 
@@ -1454,6 +1457,7 @@
 					if t == self.ignore_index:
 						continue
 					token[i] += start
+			"""
 
 			if token.is_floating_point():
 				ignored = True
@@ -1679,31 +1683,30 @@
 		for i, state in enumerate( hidden_states ):
 			hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]
 
+		# de-offset if needed
+		if self.classifier is not None:
+			offsets = _get_offsets()
+			for batch_index, classifier_level in enumerate( classifier_levels ):
+				if classifier_level == "stt":
+					k = "text"
+				elif classifier_level == "len":
+					k = "len"
+				else:
+					k = f'resps|{classifier_level}'
+
+				if k not in offsets:
+					continue
+
+				start, end = offsets[k]
+
+				logits[batch_index] = logits[batch_index][:, start:start+end]
+
 		if not training:
 			loss = None
 			stats = None
 
 			self.loss = None
 			self.stats = None
-
-			# de-offset if needed
-			if self.classifier is not None:
-				offsets = _get_offsets()
-				for batch_index, classifier_level in enumerate( classifier_levels ):
-					if classifier_level == "stt":
-						k = "text"
-					elif classifier_level == "len":
-						k = "len"
-					else:
-						k = f'resps|{classifier_level}'
-
-					if k not in offsets:
-						continue
-
-					start, end = offsets[k]
-
-					logits[batch_index] = logits[batch_index][:, start:start+end]
-
 		# compute loss if the target is given
 		else:
 			loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )