diagnosed both hf/llama.cpp versions as probably just suffering from a faulty export method (to-do: migrate vall_e.models.base to vall_e.export --hf)

mrq 2025-04-05 22:05:39 -05:00
parent c34763769a
commit 1e22519d94
4 changed files with 58 additions and 17 deletions
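
The changes below wire up a "kludge_export" path that copies the per-modality embedding tables and classifier heads back into the backbone's monolithic embed_tokens / lm_head before saving, which is the suspected fix for the faulty export. As a rough illustration of that re-merge, here is a minimal sketch assuming per-modality weight tables keyed by name; merge_embeddings and tables are illustrative names only, not code from the repo:

import torch
import torch.nn as nn

def merge_embeddings( tables, n_embd ):
    # `tables`: dict of modality name -> [n_tokens, n_embd] weight tensor,
    # concatenated in insertion order into one monolithic embedding table,
    # the same way the loop in the diff advances `f_it` over io_map entries
    n_vocab = sum( t.shape[0] for t in tables.values() )
    merged = nn.Embedding( n_vocab, n_embd )
    offset = 0
    with torch.no_grad():
        for weight in tables.values():
            merged.weight[offset:offset + weight.shape[0], :] = weight
            offset += weight.shape[0]
    return merged

# e.g. merge_embeddings( { "phn": phn_weight, "rvq_l": rvq_l_weight }, 1024 )
# where phn_weight / rvq_l_weight are placeholder [n_tokens, 1024] tensors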

View File

@@ -580,7 +580,7 @@ std::vector<token_t> generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, i
     int32_t seq_len = n_outputs;
     int32_t top_k = 0;
     float top_p = 1.0;
-    float temperature = 1.5f;
+    float temperature = 1.0f;
     float cfg_strength = 3.0f;
     float start_noise = 0.0f;
     float end_noise = 1.0f;

View File

@@ -573,6 +573,9 @@ class Model(LlamaPreTrainedModel):
         self.vocab_size = config.vocab_size
         self.layers_n = config.num_hidden_layers
+        if self.vocab_size:
+            self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
         self.layers = nn.ModuleList(
             [DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )

View File

@@ -468,7 +468,7 @@ class Base(nn.Module):
            self.model_config = LlamaConfig(
-                vocab_size=n_vocab,
+                vocab_size=0, # n_vocab,
                hidden_size=d_model,
                max_position_embeddings=max_position_embeddings,
                intermediate_size=d_model*d_ffn,
@@ -1435,15 +1435,15 @@ if __name__ == "__main__":
    from ..models import download_model, DEFAULT_MODEL_PATH
    from ..emb.qnt import decode_to_file
-    from ..utils.io import torch_load
+    from ..utils.io import torch_load, torch_save
    # hack in a non-causal mask
    def _update_noncausal_mask(
        attention_mask,
        inputs_embeds,
        cache_positions,
-        past_key_values_length,
-        output_attentions,
+        past_key_values_length=0,
+        output_attentions=False,
    ):
        # create noncausal mask
        # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
@@ -1464,18 +1464,31 @@ if __name__ == "__main__":
    device = "cuda"
    dtype = torch.bfloat16
+    # the pretrained model is botched
    is_from_pretrained = True
+    kludge_export = True
    if is_from_pretrained:
-        # tokenizer = LlamaTokenizer.from_pretrained("ecker/vall-e", revision="hf")
+        kludge_export = False
        hf_model = LlamaForCausalLM.from_pretrained("ecker/vall-e", revision="hf")
        hf_model.to(device=device, dtype=dtype)
        hf_model.eval()
        model = hf_model.model
    else:
+        class LlamaForCausalLM(nn.Module):
+            def __init__(self, config):
+                super().__init__()
+                self.model = LlamaModel(config)
+                self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+            def forward( *args, **kwargs ):
+                return self.model( *args, **kwargs )
        download_model()
-        model = LlamaModel(LlamaConfig(
-            vocab_size=1024,
+        hf_model = LlamaForCausalLM(LlamaConfig(
+            vocab_size=17702,
            hidden_size=1024,
            max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds
            intermediate_size=1024*4,
@@ -1483,11 +1496,12 @@ if __name__ == "__main__":
            num_attention_heads=16,
            attention_dropout=0.0,
            num_key_value_heads=16,
-            sliding_window=75 * 12, # 12 second context window
+            sliding_window=None, # 75 * 12, # 12 second context window
            hidden_act="gelu",
            is_encoder_decoder=False,
            is_decoder=True,
        ))
+        model = hf_model.model
        state_dict = torch_load(DEFAULT_MODEL_PATH)['module']
        state_dict_model = {}
@@ -1497,8 +1511,9 @@ if __name__ == "__main__":
            state_dict_model[k.replace("model.", "")] = v
        model.load_state_dict( state_dict_model, strict=False )
-        model.to(device=device, dtype=dtype)
-        model.eval()
+        hf_model.to(device=device, dtype=dtype)
+        hf_model.eval()
    model._original_update_causal_mask = model._update_causal_mask
    model._update_noncausal_mask = _update_noncausal_mask
@@ -1531,7 +1546,8 @@ if __name__ == "__main__":
    # name, (start, end), classifier, src_name
    io_map = {
-        'text': [(0, 256), 9, "text_emb.weight"],
+        'phn': [(0, 256), 9, "text_emb.weight"],
+        #'text': [(0, 256), 9, "raw_text_emb.weight"],
        'rvq_l': [(256, 264), -1, "rvq_l_emb.weight"],
        'lang': [(264, 270), -1, "langs_emb.weight"],
        'task': [(270, 279), -1, "tasks_emb.weight"],
@@ -1574,6 +1590,7 @@ if __name__ == "__main__":
    heads = {}
    n_embd = 1024
+    f_it = 0
    with torch.no_grad():
        for k, v in io_map.items():
            start, end = v[0]
@@ -1583,7 +1600,7 @@ if __name__ == "__main__":
            if is_from_pretrained:
                n_vocab = end - start
-                embds[k] = torch.ml.Embedding( n_vocab, n_embd ).to(model.embed_tokens.weight)
+                embds[k] = torch.nn.Embedding( n_vocab, n_embd ).to(model.embed_tokens.weight)
                embds[k].weight[:] = model.embed_tokens.weight[start:end, :]
                if classifier_idx >= 0:
@@ -1595,15 +1612,34 @@ if __name__ == "__main__":
                    heads[k].weight[:] = hf_model.lm_head.weight[start:end, :]
            else:
                embd_weight = state_dict[embd_name].unsqueeze(0) if state_dict[embd_name].dim() == 1 else state_dict[embd_name]
-                embds[k] = torch.ml.Embedding( embd_weight.shape[0], embd_weight.shape[1] ).to(device=device, dtype=dtype)
+                embds[k] = torch.nn.Embedding( embd_weight.shape[0], embd_weight.shape[1] ).to(device=device, dtype=dtype)
                embds[k].load_state_dict({ "weight": embd_weight })
+                vocab = embd_weight.shape[0]
+                start = f_it
+                f_it += vocab
+                end = f_it
+                if kludge_export:
+                    model.embed_tokens.weight[start:end, :] = embds[k].weight
                if classifier_idx >= 0:
+                    # NAR:0:0 does not have a masked token output
+                    if k == "resp|NAR:0:0":
+                        end -= 1
                    head_weight = state_dict[f'classifiers.proj.{classifier_idx}.weight']
                    heads[k] = torch.nn.Linear( head_weight.shape[1], head_weight.shape[0], bias=False ).to(device=device, dtype=dtype)
                    heads[k].load_state_dict({ "weight": head_weight })
+                    if kludge_export:
+                        hf_model.lm_head.weight[start:end, :] = heads[k].weight
+    if kludge_export:
+        state_dict = hf_model.state_dict()
+        torch_save({ "module": state_dict, "format": "pt" }, "./data/model.safetensors")
    def create_inputs( phn, prom, lang=0, seq=None, mode="AR:0:0" ):
        rvq_l = mode_lvl_map[mode]
@@ -1637,7 +1673,9 @@ if __name__ == "__main__":
                    break
                prom_embd += embds[f"prom|{i}"](p)
-        if seq is not None:
+        if isinstance( seq, list ) and not seq:
+            ...
+        elif seq is not None:
            if mode == "len":
                seq_embd = embds["len"](seq)
            elif mode == "AR:0:0":
@@ -1705,7 +1743,7 @@ if __name__ == "__main__":
    # test len inferencing
-    print( "len:", generate( phn, prom, mode="len" ) )
+    #print( "len:", generate( phn, prom, mode="len" ) )
    # test ar ouptut
    if resp:

View File

@@ -450,7 +450,7 @@ class Base_V2(nn.Module):
            self.model = None
        elif self.arch_type in ["llama"]:
            self.model_config = LlamaConfig(
-                vocab_size=n_vocab,
+                vocab_size=0, # n_vocab,
                hidden_size=d_model,
                max_position_embeddings=max_position_embeddings,
                intermediate_size=d_model*d_ffn,
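
A note on the vocab_size=0 changes above (base.py, base_v2.py) and the new `if self.vocab_size:` guard in the HF Model class: the diff suggests the backbone is only ever driven through pre-computed inputs_embeds built from the per-modality embeddings, so the shared token-embedding table can be skipped entirely when vocab_size is 0. A minimal sketch of that pattern, with Backbone as an illustrative stand-in rather than the repo's actual class:

import torch.nn as nn

class Backbone( nn.Module ):
    def __init__( self, vocab_size, hidden_size, padding_idx=None ):
        super().__init__()
        # mirror of the guarded construction: skip the token embedding entirely
        # when vocab_size is 0, since callers supply inputs_embeds themselves
        self.embed_tokens = nn.Embedding( vocab_size, hidden_size, padding_idx ) if vocab_size else None

    def forward( self, input_ids=None, inputs_embeds=None ):
        # when inputs_embeds is given, the embedding table is never touched
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens( input_ids )
        return inputs_embeds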