diagnosed both the hf and llama.cpp versions as probably just suffering from a faulty export method (to-do: migrate the export code in vall_e.models.base to vall_e.export --hf)
parent c34763769a
commit 1e22519d94
@@ -580,7 +580,7 @@ std::vector<token_t> generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, i
 	int32_t seq_len = n_outputs;
 	int32_t top_k = 0;
 	float top_p = 1.0;
-	float temperature = 1.5f;
+	float temperature = 1.0f;
 	float cfg_strength = 3.0f;
 	float start_noise = 0.0f;
 	float end_noise = 1.0f;
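Note: this hunk only lowers the default sampling temperature from 1.5 to 1.0. For reference, a minimal Python sketch of what that knob does to the logits before sampling (illustrative only, not the llama.cpp sampler code):

    import torch

    def sample_token(logits: torch.Tensor, temperature: float = 1.0) -> int:
        # temperature == 1.0 samples from the raw distribution; lower values
        # sharpen it toward the argmax, higher values flatten it
        if temperature <= 0.0:
            return int(logits.argmax().item())
        probs = torch.softmax(logits / temperature, dim=-1)
        return int(torch.multinomial(probs, num_samples=1).item())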
@@ -573,6 +573,9 @@ class Model(LlamaPreTrainedModel):
 		self.vocab_size = config.vocab_size
+		self.layers_n = config.num_hidden_layers
 
+		if self.vocab_size:
+			self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
 
 		self.layers = nn.ModuleList(
 			[DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
 		)
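Note: guarding embed_tokens behind self.vocab_size matters because this exported model is meant to be driven with pre-summed inputs_embeds rather than input_ids. A rough usage sketch under that assumption (the head name is illustrative):

    # with vocab_size == 0 there is no embed_tokens table, so embeddings are
    # built externally from the per-task tables and passed in directly
    out = model(
        inputs_embeds=inputs_embeds,    # [bsz, seq_len, hidden_size]
        attention_mask=attention_mask,  # [bsz, seq_len]
    )
    logits = heads["resp|NAR:0:0"](out.last_hidden_state)  # per-task classifier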
@@ -468,7 +468,7 @@ class Base(nn.Module):
 
 
 		self.model_config = LlamaConfig(
-			vocab_size=n_vocab,
+			vocab_size=0, # n_vocab,
 			hidden_size=d_model,
 			max_position_embeddings=max_position_embeddings,
 			intermediate_size=d_model*d_ffn,
@@ -1435,15 +1435,15 @@ if __name__ == "__main__":
 	from ..models import download_model, DEFAULT_MODEL_PATH
 
 	from ..emb.qnt import decode_to_file
-	from ..utils.io import torch_load
+	from ..utils.io import torch_load, torch_save
 
 	# hack in a non-causal mask
 	def _update_noncausal_mask(
 		attention_mask,
 		inputs_embeds,
 		cache_positions,
-		past_key_values_length,
-		output_attentions,
+		past_key_values_length=0,
+		output_attentions=False,
 	):
 		# create noncausal mask
 		# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
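Note: the hunk only gives past_key_values_length and output_attentions defaults; the mask body itself is not shown. A minimal sketch of the [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] expansion the comment describes, assuming a padding mask of ones and zeros:

    import torch

    def make_noncausal_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        # attention_mask: [bsz, seq_len], 1 for real tokens, 0 for padding
        bsz, seq_len = attention_mask.shape
        mask = attention_mask[:, None, None, :].to(dtype)  # [bsz, 1, 1, seq_len]
        mask = mask.expand(bsz, 1, seq_len, seq_len)       # every position sees every non-pad position
        return (1.0 - mask) * torch.finfo(dtype).min       # additive mask: 0 keeps, large negative blocks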
@@ -1464,18 +1464,31 @@ if __name__ == "__main__":
 	device = "cuda"
 	dtype = torch.bfloat16
 
+	# the pretrained model is botched
 	is_from_pretrained = True
+	kludge_export = True
 	if is_from_pretrained:
 		# tokenizer = LlamaTokenizer.from_pretrained("ecker/vall-e", revision="hf")
+		kludge_export = False
 		hf_model = LlamaForCausalLM.from_pretrained("ecker/vall-e", revision="hf")
 		hf_model.to(device=device, dtype=dtype)
 		hf_model.eval()
 
 		model = hf_model.model
 	else:
+		class LlamaForCausalLM(nn.Module):
+			def __init__(self, config):
+				super().__init__()
+
+				self.model = LlamaModel(config)
+				self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+			def forward( *args, **kwargs ):
+				return self.model( *args, **kwargs )
+
 		download_model()
-		model = LlamaModel(LlamaConfig(
-			vocab_size=1024,
+		hf_model = LlamaForCausalLM(LlamaConfig(
+			vocab_size=17702,
 			hidden_size=1024,
 			max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds
 			intermediate_size=1024*4,
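Note: in the fallback branch, the shim's forward( *args, **kwargs ) is missing self, so self.model would be undefined when it is actually called. A corrected sketch of the same wrapper, assuming it only exists to hang an lm_head off a LlamaModel for export:

    import torch.nn as nn
    from transformers import LlamaModel

    class CausalLMShim(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.model = LlamaModel(config)
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        def forward(self, *args, **kwargs):
            # explicit self; the version in the diff drops it and would raise a NameError
            return self.model(*args, **kwargs)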
@@ -1483,11 +1496,12 @@ if __name__ == "__main__":
 			num_attention_heads=16,
 			attention_dropout=0.0,
 			num_key_value_heads=16,
-			sliding_window=75 * 12, # 12 second context window
+			sliding_window=None, # 75 * 12, # 12 second context window
 			hidden_act="gelu",
 			is_encoder_decoder=False,
 			is_decoder=True,
 		))
+		model = hf_model.model
 
 	state_dict = torch_load(DEFAULT_MODEL_PATH)['module']
 	state_dict_model = {}
@@ -1497,8 +1511,9 @@ if __name__ == "__main__":
 		state_dict_model[k.replace("model.", "")] = v
 
 	model.load_state_dict( state_dict_model, strict=False )
-	model.to(device=device, dtype=dtype)
-	model.eval()
+
+	hf_model.to(device=device, dtype=dtype)
+	hf_model.eval()
 
 	model._original_update_causal_mask = model._update_causal_mask
 	model._update_noncausal_mask = _update_noncausal_mask
@@ -1531,7 +1546,8 @@ if __name__ == "__main__":
 
 	# name, (start, end), classifier, src_name
 	io_map = {
-		'text': [(0, 256), 9, "text_emb.weight"],
+		'phn': [(0, 256), 9, "text_emb.weight"],
+		#'text': [(0, 256), 9, "raw_text_emb.weight"],
 		'rvq_l': [(256, 264), -1, "rvq_l_emb.weight"],
 		'lang': [(264, 270), -1, "langs_emb.weight"],
 		'task': [(270, 279), -1, "tasks_emb.weight"],
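Note: each io_map entry reads name -> [(start, end) slice of the merged vocabulary, classifier index (-1 when the input has no output head), source tensor name in the original checkpoint]. A small sketch of how such a slice maps back out of the monolithic tables (illustrative helper, not part of the script):

    def slice_entry(name, io_map, embed_tokens_weight, lm_head_weight):
        (start, end), classifier_idx, src_name = io_map[name]
        embd = embed_tokens_weight[start:end, :]  # rows for this sub-vocabulary
        head = lm_head_weight[start:end, :] if classifier_idx >= 0 else None
        return embd, head

    # e.g. slice_entry('phn', ...) recovers the 256 rows originally stored as "text_emb.weight"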
@@ -1574,6 +1590,7 @@ if __name__ == "__main__":
 	heads = {}
 	n_embd = 1024
 
+	f_it = 0
 	with torch.no_grad():
 		for k, v in io_map.items():
 			start, end = v[0]
@@ -1583,7 +1600,7 @@ if __name__ == "__main__":
 			if is_from_pretrained:
 				n_vocab = end - start
 
-				embds[k] = torch.ml.Embedding( n_vocab, n_embd ).to(model.embed_tokens.weight)
+				embds[k] = torch.nn.Embedding( n_vocab, n_embd ).to(model.embed_tokens.weight)
 				embds[k].weight[:] = model.embed_tokens.weight[start:end, :]
 
 				if classifier_idx >= 0:
@@ -1595,15 +1612,34 @@ if __name__ == "__main__":
 					heads[k].weight[:] = hf_model.lm_head.weight[start:end, :]
 			else:
 				embd_weight = state_dict[embd_name].unsqueeze(0) if state_dict[embd_name].dim() == 1 else state_dict[embd_name]
-				embds[k] = torch.ml.Embedding( embd_weight.shape[0], embd_weight.shape[1] ).to(device=device, dtype=dtype)
+				embds[k] = torch.nn.Embedding( embd_weight.shape[0], embd_weight.shape[1] ).to(device=device, dtype=dtype)
 				embds[k].load_state_dict({ "weight": embd_weight })
 
+				vocab = embd_weight.shape[0]
+				start = f_it
+				f_it += vocab
+				end = f_it
+
+				if kludge_export:
+					model.embed_tokens.weight[start:end, :] = embds[k].weight
+
 				if classifier_idx >= 0:
+					# NAR:0:0 does not have a masked token output
+					if k == "resp|NAR:0:0":
+						end -= 1
+
 					head_weight = state_dict[f'classifiers.proj.{classifier_idx}.weight']
 
 					heads[k] = torch.nn.Linear( head_weight.shape[1], head_weight.shape[0], bias=False ).to(device=device, dtype=dtype)
 					heads[k].load_state_dict({ "weight": head_weight })
 
+					if kludge_export:
+						hf_model.lm_head.weight[start:end, :] = heads[k].weight
+
+	if kludge_export:
+		state_dict = hf_model.state_dict()
+		torch_save({ "module": state_dict, "format": "pt" }, "./data/model.safetensors")
+
 	def create_inputs( phn, prom, lang=0, seq=None, mode="AR:0:0" ):
 		rvq_l = mode_lvl_map[mode]
 
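Note: the kludge_export branch is the export path the commit message suspects: it packs each per-task embedding (and classifier, where one exists) into the single embed_tokens / lm_head using the running f_it offset, then dumps the result with torch_save. A condensed sketch of that packing pattern, keeping only the offset bookkeeping (illustrative; details like the NAR:0:0 masked-token trim are omitted):

    offset = 0
    with torch.no_grad():
        for name, (_, classifier_idx, src_name) in io_map.items():
            embd_weight = state_dict[src_name]
            rows = embd_weight.shape[0]
            model.embed_tokens.weight[offset:offset + rows, :] = embd_weight  # pack embedding rows
            if classifier_idx >= 0:
                head = state_dict[f'classifiers.proj.{classifier_idx}.weight']
                hf_model.lm_head.weight[offset:offset + head.shape[0], :] = head  # pack classifier rows
            offset += rows

    torch_save({ "module": hf_model.state_dict(), "format": "pt" }, "./data/model.safetensors")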
@@ -1637,7 +1673,9 @@ if __name__ == "__main__":
 				break
 			prom_embd += embds[f"prom|{i}"](p)
 
-		if seq is not None:
+		if isinstance( seq, list ) and not seq:
+			...
+		elif seq is not None:
 			if mode == "len":
 				seq_embd = embds["len"](seq)
 			elif mode == "AR:0:0":
@@ -1705,7 +1743,7 @@ if __name__ == "__main__":
 
 
 	# test len inferencing
-	print( "len:", generate( phn, prom, mode="len" ) )
+	#print( "len:", generate( phn, prom, mode="len" ) )
 
 	# test ar ouptut
 	if resp:
@@ -450,7 +450,7 @@ class Base_V2(nn.Module):
 			self.model = None
 		elif self.arch_type in ["llama"]:
 			self.model_config = LlamaConfig(
-				vocab_size=n_vocab,
+				vocab_size=0, # n_vocab,
 				hidden_size=d_model,
 				max_position_embeddings=max_position_embeddings,
 				intermediate_size=d_model*d_ffn,