diagnosed the problems with both the hf/llama.cpp versions as probably just a faulty export method (to-do: migrate vall_e.models.base to vall_e.export --hf)

mrq 2025-04-05 22:05:39 -05:00
parent c34763769a
commit 1e22519d94
4 changed files with 58 additions and 17 deletions


@@ -580,7 +580,7 @@ std::vector<token_t> generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, i
int32_t seq_len = n_outputs;
int32_t top_k = 0;
float top_p = 1.0;
float temperature = 1.5f;
float temperature = 1.0f;
float cfg_strength = 3.0f;
float start_noise = 0.0f;
float end_noise = 1.0f;


@@ -573,6 +573,9 @@ class Model(LlamaPreTrainedModel):
self.vocab_size = config.vocab_size
self.layers_n = config.num_hidden_layers
if self.vocab_size:
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
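This hunk only allocates embed_tokens when the config actually carries a vocabulary, which pairs with the vocab_size=0 configs further down: the backbone is then expected to receive precomputed inputs_embeds instead of token ids. A minimal illustrative sketch of that calling convention (TinyDecoder is a hypothetical stand-in, not the patched Model class):

```python
import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    """Hypothetical stand-in: only builds an embedding table when vocab_size is non-zero."""
    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size) if vocab_size else None
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, input_ids=None, inputs_embeds=None):
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        return self.proj(inputs_embeds)

# with vocab_size=0 the caller owns its own (per-task) embeddings and
# hands the backbone pre-embedded inputs instead of token ids
model = TinyDecoder(vocab_size=0, hidden_size=1024)
text_emb = nn.Embedding(256, 1024)
out = model(inputs_embeds=text_emb(torch.randint(0, 256, (1, 32))))
```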


@@ -468,7 +468,7 @@ class Base(nn.Module):
self.model_config = LlamaConfig(
vocab_size=n_vocab,
vocab_size=0, # n_vocab,
hidden_size=d_model,
max_position_embeddings=max_position_embeddings,
intermediate_size=d_model*d_ffn,
@@ -1435,15 +1435,15 @@ if __name__ == "__main__":
from ..models import download_model, DEFAULT_MODEL_PATH
from ..emb.qnt import decode_to_file
from ..utils.io import torch_load
from ..utils.io import torch_load, torch_save
# hack in a non-causal mask
def _update_noncausal_mask(
attention_mask,
inputs_embeds,
cache_positions,
past_key_values_length,
output_attentions,
past_key_values_length=0,
output_attentions=False,
):
# create noncausal mask
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
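The hunk above only shows the signature and leading comment of the non-causal mask hack. For reference, a self-contained sketch of what such an expansion does under the stated shapes (an independent re-derivation, not the exact body of _update_noncausal_mask): every query may attend to every non-padded key, and padded keys are filled with the dtype minimum.

```python
import torch

def make_noncausal_mask(attention_mask, dtype=torch.float32):
    # attention_mask: [bsz, seq_len], 1 for real tokens, 0 for padding
    bsz, seq_len = attention_mask.shape
    # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]: no lower-triangular
    # (causal) constraint, only padded key positions are masked out
    expanded = attention_mask[:, None, None, :].expand(bsz, 1, seq_len, seq_len).to(dtype)
    return (1.0 - expanded) * torch.finfo(dtype).min

mask = make_noncausal_mask(torch.tensor([[1, 1, 1, 0]]))  # last position is padding
```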
@@ -1464,18 +1464,31 @@ if __name__ == "__main__":
device = "cuda"
dtype = torch.bfloat16
# the pretrained model is botched
is_from_pretrained = True
kludge_export = True
if is_from_pretrained:
# tokenizer = LlamaTokenizer.from_pretrained("ecker/vall-e", revision="hf")
kludge_export = False
hf_model = LlamaForCausalLM.from_pretrained("ecker/vall-e", revision="hf")
hf_model.to(device=device, dtype=dtype)
hf_model.eval()
model = hf_model.model
else:
class LlamaForCausalLM(nn.Module):
def __init__(self, config):
super().__init__()
self.model = LlamaModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def forward( self, *args, **kwargs ):
return self.model( *args, **kwargs )
download_model()
model = LlamaModel(LlamaConfig(
vocab_size=1024,
hf_model = LlamaForCausalLM(LlamaConfig(
vocab_size=17702,
hidden_size=1024,
max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds
intermediate_size=1024*4,
@@ -1483,11 +1496,12 @@ if __name__ == "__main__":
num_attention_heads=16,
attention_dropout=0.0,
num_key_value_heads=16,
sliding_window=75 * 12, # 12 second context window
sliding_window=None, # 75 * 12, # 12 second context window
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
))
model = hf_model.model
state_dict = torch_load(DEFAULT_MODEL_PATH)['module']
state_dict_model = {}
@@ -1497,8 +1511,9 @@ if __name__ == "__main__":
state_dict_model[k.replace("model.", "")] = v
model.load_state_dict( state_dict_model, strict=False )
model.to(device=device, dtype=dtype)
model.eval()
hf_model.to(device=device, dtype=dtype)
hf_model.eval()
model._original_update_causal_mask = model._update_causal_mask
model._update_noncausal_mask = _update_noncausal_mask
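Stashing the original _update_causal_mask alongside the non-causal variant suggests the two get swapped per pass; the switching code itself is outside this hunk. A plausible sketch of the idea, assuming causal masking for AR decoding and non-causal masking for the len/NAR passes:

```python
# assumption: AR passes keep the stock causal mask, while len/NAR passes
# use the non-causal variant patched in above
def set_causal(model, causal: bool):
    model._update_causal_mask = (
        model._original_update_causal_mask if causal
        else model._update_noncausal_mask
    )

set_causal(model, causal=False)  # e.g. before a NAR:0:0 forward
set_causal(model, causal=True)   # restore before AR:0:0 decoding
```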
@@ -1531,7 +1546,8 @@ if __name__ == "__main__":
# name, (start, end), classifier, src_name
io_map = {
'text': [(0, 256), 9, "text_emb.weight"],
'phn': [(0, 256), 9, "text_emb.weight"],
#'text': [(0, 256), 9, "raw_text_emb.weight"],
'rvq_l': [(256, 264), -1, "rvq_l_emb.weight"],
'lang': [(264, 270), -1, "langs_emb.weight"],
'task': [(270, 279), -1, "tasks_emb.weight"],
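Each io_map entry ties one input/output stream to its slice of the flattened HF vocabulary, the index of its classifier head, and the source tensor name in the original checkpoint (the map continues past this hunk). A small hypothetical helper to show how a stream-local token id lands in the unified id space under that layout:

```python
# hypothetical helper, using the ranges shown above; in the non-pretrained
# kludge path the ranges are recomputed on the fly instead
def to_unified_id(io_map, name, local_id):
    start, end = io_map[name][0]
    assert 0 <= local_id < end - start
    return start + local_id

# e.g. language token 2 becomes id 266 in the flattened embed_tokens table:
# to_unified_id(io_map, 'lang', 2) == 264 + 2
```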
@@ -1574,6 +1590,7 @@ if __name__ == "__main__":
heads = {}
n_embd = 1024
f_it = 0
with torch.no_grad():
for k, v in io_map.items():
start, end = v[0]
@@ -1583,7 +1600,7 @@ if __name__ == "__main__":
if is_from_pretrained:
n_vocab = end - start
embds[k] = torch.ml.Embedding( n_vocab, n_embd ).to(model.embed_tokens.weight)
embds[k] = torch.nn.Embedding( n_vocab, n_embd ).to(model.embed_tokens.weight)
embds[k].weight[:] = model.embed_tokens.weight[start:end, :]
if classifier_idx >= 0:
@@ -1595,15 +1612,34 @@ if __name__ == "__main__":
heads[k].weight[:] = hf_model.lm_head.weight[start:end, :]
else:
embd_weight = state_dict[embd_name].unsqueeze(0) if state_dict[embd_name].dim() == 1 else state_dict[embd_name]
embds[k] = torch.ml.Embedding( embd_weight.shape[0], embd_weight.shape[1] ).to(device=device, dtype=dtype)
embds[k] = torch.nn.Embedding( embd_weight.shape[0], embd_weight.shape[1] ).to(device=device, dtype=dtype)
embds[k].load_state_dict({ "weight": embd_weight })
vocab = embd_weight.shape[0]
start = f_it
f_it += vocab
end = f_it
if kludge_export:
model.embed_tokens.weight[start:end, :] = embds[k].weight
if classifier_idx >= 0:
# NAR:0:0 does not have a masked token output
if k == "resp|NAR:0:0":
end -= 1
head_weight = state_dict[f'classifiers.proj.{classifier_idx}.weight']
heads[k] = torch.nn.Linear( head_weight.shape[1], head_weight.shape[0], bias=False ).to(device=device, dtype=dtype)
heads[k].load_state_dict({ "weight": head_weight })
if kludge_export:
hf_model.lm_head.weight[start:end, :] = heads[k].weight
if kludge_export:
state_dict = hf_model.state_dict()
torch_save({ "module": state_dict, "format": "pt" }, "./data/model.safetensors")
def create_inputs( phn, prom, lang=0, seq=None, mode="AR:0:0" ):
rvq_l = mode_lvl_map[mode]
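The kludge export earlier in this hunk can be sanity-checked by reading the re-saved file back with the same helper used above and confirming the stitched tables have the expected shapes. A sketch, assuming the shim LlamaForCausalLM wrapper defined earlier (hence the model. key prefix) and that torch_load mirrors this torch_save convention:

```python
# sketch: reload the kludge-exported checkpoint and eyeball the stitched tables
reloaded = torch_load("./data/model.safetensors")['module']
print(reloaded['model.embed_tokens.weight'].shape)  # expected: (17702, 1024)
print(reloaded['lm_head.weight'].shape)             # expected: (17702, 1024)
```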
@@ -1637,7 +1673,9 @@ if __name__ == "__main__":
break
prom_embd += embds[f"prom|{i}"](p)
if seq is not None:
if isinstance( seq, list ) and not seq:
...
elif seq is not None:
if mode == "len":
seq_embd = embds["len"](seq)
elif mode == "AR:0:0":
@@ -1705,7 +1743,7 @@ if __name__ == "__main__":
# test len inferencing
print( "len:", generate( phn, prom, mode="len" ) )
#print( "len:", generate( phn, prom, mode="len" ) )
# test ar output
if resp:


@@ -450,7 +450,7 @@ class Base_V2(nn.Module):
self.model = None
elif self.arch_type in ["llama"]:
self.model_config = LlamaConfig(
vocab_size=n_vocab,
vocab_size=0, # n_vocab,
hidden_size=d_model,
max_position_embeddings=max_position_embeddings,
intermediate_size=d_model*d_ffn,