diff --git a/vall_e.cpp/vall_e.cpp b/vall_e.cpp/vall_e.cpp
index 5f72148..6a7966e 100644
--- a/vall_e.cpp/vall_e.cpp
+++ b/vall_e.cpp/vall_e.cpp
@@ -580,7 +580,7 @@ std::vector generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, i
 	int32_t seq_len = n_outputs;
 	int32_t top_k = 0;
 	float top_p = 1.0;
-	float temperature = 1.5f;
+	float temperature = 1.0f;
 	float cfg_strength = 3.0f;
 	float start_noise = 0.0f;
 	float end_noise = 1.0f;
diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index 98b021f..c172dfd 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -573,6 +573,9 @@ class Model(LlamaPreTrainedModel):
 		self.vocab_size = config.vocab_size
 		self.layers_n = config.num_hidden_layers
 
+		if self.vocab_size:
+			self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+
 		self.layers = nn.ModuleList(
 			[DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
 		)
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 2453e8d..4bb2954 100644
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -468,7 +468,7 @@ class Base(nn.Module):
 			self.model_config = LlamaConfig(
-				vocab_size=n_vocab,
+				vocab_size=0, # n_vocab,
 				hidden_size=d_model,
 				max_position_embeddings=max_position_embeddings,
 				intermediate_size=d_model*d_ffn,
@@ -1435,15 +1435,15 @@ if __name__ == "__main__":
 	from ..models import download_model, DEFAULT_MODEL_PATH
 	from ..emb.qnt import decode_to_file
-	from ..utils.io import torch_load
+	from ..utils.io import torch_load, torch_save
 
 	# hack in a non-causal mask
 	def _update_noncausal_mask(
 		attention_mask,
 		inputs_embeds,
 		cache_positions,
-		past_key_values_length,
-		output_attentions,
+		past_key_values_length=0,
+		output_attentions=False,
 	):
 		# create noncausal mask
 		# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
@@ -1464,18 +1464,31 @@
 	device = "cuda"
 	dtype = torch.bfloat16
 
+	# the pretrained model is botched
 	is_from_pretrained = True
+	kludge_export = True
 	if is_from_pretrained:
-		# tokenizer = LlamaTokenizer.from_pretrained("ecker/vall-e", revision="hf")
+		kludge_export = False
 		hf_model = LlamaForCausalLM.from_pretrained("ecker/vall-e", revision="hf")
 		hf_model.to(device=device, dtype=dtype)
 		hf_model.eval()
 
 		model = hf_model.model
 	else:
+		class LlamaForCausalLM(nn.Module):
+			def __init__(self, config):
+				super().__init__()
+
+				self.model = LlamaModel(config)
+				self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+			def forward( *args, **kwargs ):
+				return self.model( *args, **kwargs )
+
 		download_model()
-		model = LlamaModel(LlamaConfig(
-			vocab_size=1024,
+
+		hf_model = LlamaForCausalLM(LlamaConfig(
+			vocab_size=17702,
 			hidden_size=1024,
 			max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds
 			intermediate_size=1024*4,
@@ -1483,11 +1496,12 @@
 			num_attention_heads=16,
 			attention_dropout=0.0,
 			num_key_value_heads=16,
-			sliding_window=75 * 12, # 12 second context window
+			sliding_window=None, # 75 * 12, # 12 second context window
 			hidden_act="gelu",
 			is_encoder_decoder=False,
 			is_decoder=True,
 		))
+		model = hf_model.model
 
 		state_dict = torch_load(DEFAULT_MODEL_PATH)['module']
 		state_dict_model = {}
@@ -1497,8 +1511,9 @@
 				state_dict_model[k.replace("model.", "")] = v
 
 		model.load_state_dict( state_dict_model, strict=False )
-		model.to(device=device, dtype=dtype)
-		model.eval()
+
+		hf_model.to(device=device, dtype=dtype)
+		hf_model.eval()
 
 	model._original_update_causal_mask = model._update_causal_mask
 	model._update_noncausal_mask = _update_noncausal_mask
@@ -1531,7 +1546,8 @@
 
 	# name, (start, end), classifier, src_name
 	io_map = {
-		'text': [(0, 256), 9, "text_emb.weight"],
+		'phn': [(0, 256), 9, "text_emb.weight"],
+		#'text': [(0, 256), 9, "raw_text_emb.weight"],
 		'rvq_l': [(256, 264), -1, "rvq_l_emb.weight"],
 		'lang': [(264, 270), -1, "langs_emb.weight"],
 		'task': [(270, 279), -1, "tasks_emb.weight"],
@@ -1574,6 +1590,7 @@
 	heads = {}
 	n_embd = 1024
 
+	f_it = 0
 	with torch.no_grad():
 		for k, v in io_map.items():
 			start, end = v[0]
@@ -1583,7 +1600,7 @@
 
 			if is_from_pretrained:
 				n_vocab = end - start
-				embds[k] = torch.ml.Embedding( n_vocab, n_embd ).to(model.embed_tokens.weight)
+				embds[k] = torch.nn.Embedding( n_vocab, n_embd ).to(model.embed_tokens.weight)
 				embds[k].weight[:] = model.embed_tokens.weight[start:end, :]
 
 				if classifier_idx >= 0:
@@ -1595,15 +1612,34 @@
 					heads[k].weight[:] = hf_model.lm_head.weight[start:end, :]
 			else:
 				embd_weight = state_dict[embd_name].unsqueeze(0) if state_dict[embd_name].dim() == 1 else state_dict[embd_name]
-				embds[k] = torch.ml.Embedding( embd_weight.shape[0], embd_weight.shape[1] ).to(device=device, dtype=dtype)
+				embds[k] = torch.nn.Embedding( embd_weight.shape[0], embd_weight.shape[1] ).to(device=device, dtype=dtype)
 				embds[k].load_state_dict({ "weight": embd_weight })
+
+				vocab = embd_weight.shape[0]
+				start = f_it
+				f_it += vocab
+				end = f_it
+
+				if kludge_export:
+					model.embed_tokens.weight[start:end, :] = embds[k].weight
 
 				if classifier_idx >= 0:
+					# NAR:0:0 does not have a masked token output
+					if k == "resp|NAR:0:0":
+						end -= 1
+
 					head_weight = state_dict[f'classifiers.proj.{classifier_idx}.weight']
 					heads[k] = torch.nn.Linear( head_weight.shape[1], head_weight.shape[0], bias=False ).to(device=device, dtype=dtype)
 					heads[k].load_state_dict({ "weight": head_weight })
 
+					if kludge_export:
+						hf_model.lm_head.weight[start:end, :] = heads[k].weight
+
+	if kludge_export:
+		state_dict = hf_model.state_dict()
+		torch_save({ "module": state_dict, "format": "pt" }, "./data/model.safetensors")
+
 	def create_inputs( phn, prom, lang=0, seq=None, mode="AR:0:0" ):
 		rvq_l = mode_lvl_map[mode]
@@ -1637,7 +1673,9 @@
 				break
 			prom_embd += embds[f"prom|{i}"](p)
 
-		if seq is not None:
+		if isinstance( seq, list ) and not seq:
+			...
+		elif seq is not None:
 			if mode == "len":
 				seq_embd = embds["len"](seq)
 			elif mode == "AR:0:0":
@@ -1705,7 +1743,7 @@
 	# test len inferencing
-	print( "len:", generate( phn, prom, mode="len" ) )
+	#print( "len:", generate( phn, prom, mode="len" ) )
 
 	# test ar ouptut
 	if resp:
diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index 1610b7e..ae262e9 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -450,7 +450,7 @@ class Base_V2(nn.Module):
 			self.model = None
 		elif self.arch_type in ["llama"]:
 			self.model_config = LlamaConfig(
-				vocab_size=n_vocab,
+				vocab_size=0, # n_vocab,
 				hidden_size=d_model,
 				max_position_embeddings=max_position_embeddings,
 				intermediate_size=d_model*d_ffn,