diagnosed the issues with both the hf and llama.cpp versions as probably just being due to a faulty export method (to-do: migrate the export code in vall_e.models.base to vall_e.export --hf)
commit 1e22519d94
parent c34763769a
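
Context for the diff below: the model's split per-task input embeddings and classifier heads have to be stitched back into one monolithic embed_tokens / lm_head before the weights are usable from HF transformers or llama.cpp, which is what the new kludge_export path does before torch_save-ing the result. A rough, self-contained sketch of that stitching idea follows; stitch_monolithic_weights and the sections dict are hypothetical names for illustration only, not the actual vall_e layout.

# minimal sketch (not the actual vall_e code): rebuild a monolithic
# embed_tokens / lm_head from per-section embedding and classifier weights
import torch
import torch.nn as nn

def stitch_monolithic_weights( sections, n_embd, device="cpu", dtype=torch.float32 ):
	# sections: name -> (embedding weight [vocab, n_embd], head weight [vocab, n_embd] or None)
	embd_weights = []
	head_weights = []
	offsets = {}
	f_it = 0
	for name, (embd, head) in sections.items():
		vocab = embd.shape[0]
		offsets[name] = (f_it, f_it + vocab)
		f_it += vocab
		embd_weights.append( embd )
		# sections without their own classifier still occupy rows in lm_head; pad with zeros
		head_weights.append( head if head is not None else torch.zeros_like( embd ) )

	embed_tokens = nn.Embedding( f_it, n_embd )
	lm_head = nn.Linear( n_embd, f_it, bias=False )
	with torch.no_grad():
		embed_tokens.weight[:] = torch.cat( embd_weights, dim=0 )
		lm_head.weight[:] = torch.cat( head_weights, dim=0 )

	return embed_tokens.to(device=device, dtype=dtype), lm_head.to(device=device, dtype=dtype), offsets

if __name__ == "__main__":
	# toy sizes only; the real offsets come from io_map in the diff below
	n_embd = 1024
	sections = {
		"phn": ( torch.randn(256, n_embd), torch.randn(256, n_embd) ),
		"rvq_l": ( torch.randn(8, n_embd), None ),
	}
	embed_tokens, lm_head, offsets = stitch_monolithic_weights( sections, n_embd )
	print( offsets ) # {'phn': (0, 256), 'rvq_l': (256, 264)}

In the actual diff the running offset is tracked with f_it and rows are copied in-place into hf_model's embed_tokens / lm_head weights.
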
@@ -580,7 +580,7 @@ std::vector<token_t> generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, i
 	int32_t seq_len = n_outputs;
 	int32_t top_k = 0;
 	float top_p = 1.0;
-	float temperature = 1.5f;
+	float temperature = 1.0f;
 	float cfg_strength = 3.0f;
 	float start_noise = 0.0f;
 	float end_noise = 1.0f;
@@ -573,6 +573,9 @@ class Model(LlamaPreTrainedModel):
 		self.vocab_size = config.vocab_size
 		self.layers_n = config.num_hidden_layers
 
+		if self.vocab_size:
+			self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+
 		self.layers = nn.ModuleList(
 			[DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
 		)
@@ -468,7 +468,7 @@ class Base(nn.Module):
 
 
 			self.model_config = LlamaConfig(
-				vocab_size=n_vocab,
+				vocab_size=0, # n_vocab,
 				hidden_size=d_model,
 				max_position_embeddings=max_position_embeddings,
 				intermediate_size=d_model*d_ffn,
@@ -1435,15 +1435,15 @@ if __name__ == "__main__":
 	from ..models import download_model, DEFAULT_MODEL_PATH
 
 	from ..emb.qnt import decode_to_file
-	from ..utils.io import torch_load
+	from ..utils.io import torch_load, torch_save
 
 	# hack in a non-causal mask
 	def _update_noncausal_mask(
 		attention_mask,
 		inputs_embeds,
 		cache_positions,
-		past_key_values_length,
-		output_attentions,
+		past_key_values_length=0,
+		output_attentions=False,
 	):
 		# create noncausal mask
 		# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
@@ -1464,18 +1464,31 @@ if __name__ == "__main__":
 	device = "cuda"
 	dtype = torch.bfloat16
 
+	# the pretrained model is botched
 	is_from_pretrained = True
+	kludge_export = True
 	if is_from_pretrained:
-		# tokenizer = LlamaTokenizer.from_pretrained("ecker/vall-e", revision="hf")
+		kludge_export = False
 		hf_model = LlamaForCausalLM.from_pretrained("ecker/vall-e", revision="hf")
 		hf_model.to(device=device, dtype=dtype)
 		hf_model.eval()
 
 		model = hf_model.model
 	else:
+		class LlamaForCausalLM(nn.Module):
+			def __init__(self, config):
+				super().__init__()
+
+				self.model = LlamaModel(config)
+				self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+			def forward( *args, **kwargs ):
+				return self.model( *args, **kwargs )
+
 		download_model()
-		model = LlamaModel(LlamaConfig(
-			vocab_size=1024,
+
+		hf_model = LlamaForCausalLM(LlamaConfig(
+			vocab_size=17702,
 			hidden_size=1024,
 			max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds
 			intermediate_size=1024*4,
@@ -1483,11 +1496,12 @@ if __name__ == "__main__":
 			num_attention_heads=16,
 			attention_dropout=0.0,
 			num_key_value_heads=16,
-			sliding_window=75 * 12, # 12 second context window
+			sliding_window=None, # 75 * 12, # 12 second context window
 			hidden_act="gelu",
 			is_encoder_decoder=False,
 			is_decoder=True,
 		))
+		model = hf_model.model
 
 	state_dict = torch_load(DEFAULT_MODEL_PATH)['module']
 	state_dict_model = {}
@@ -1497,8 +1511,9 @@ if __name__ == "__main__":
 			state_dict_model[k.replace("model.", "")] = v
 
 	model.load_state_dict( state_dict_model, strict=False )
-	model.to(device=device, dtype=dtype)
-	model.eval()
+
+	hf_model.to(device=device, dtype=dtype)
+	hf_model.eval()
 
 	model._original_update_causal_mask = model._update_causal_mask
 	model._update_noncausal_mask = _update_noncausal_mask
@@ -1531,7 +1546,8 @@ if __name__ == "__main__":
 
 	# name, (start, end), classifier, src_name
 	io_map = {
-		'text': [(0, 256), 9, "text_emb.weight"],
+		'phn': [(0, 256), 9, "text_emb.weight"],
+		#'text': [(0, 256), 9, "raw_text_emb.weight"],
 		'rvq_l': [(256, 264), -1, "rvq_l_emb.weight"],
 		'lang': [(264, 270), -1, "langs_emb.weight"],
 		'task': [(270, 279), -1, "tasks_emb.weight"],
@@ -1574,6 +1590,7 @@ if __name__ == "__main__":
 	heads = {}
 	n_embd = 1024
 
+	f_it = 0
 	with torch.no_grad():
 		for k, v in io_map.items():
 			start, end = v[0]
@@ -1583,7 +1600,7 @@ if __name__ == "__main__":
 			if is_from_pretrained:
 				n_vocab = end - start
 
-				embds[k] = torch.ml.Embedding( n_vocab, n_embd ).to(model.embed_tokens.weight)
+				embds[k] = torch.nn.Embedding( n_vocab, n_embd ).to(model.embed_tokens.weight)
 				embds[k].weight[:] = model.embed_tokens.weight[start:end, :]
 
 				if classifier_idx >= 0:
@@ -1595,15 +1612,34 @@ if __name__ == "__main__":
 					heads[k].weight[:] = hf_model.lm_head.weight[start:end, :]
 			else:
 				embd_weight = state_dict[embd_name].unsqueeze(0) if state_dict[embd_name].dim() == 1 else state_dict[embd_name]
-				embds[k] = torch.ml.Embedding( embd_weight.shape[0], embd_weight.shape[1] ).to(device=device, dtype=dtype)
+				embds[k] = torch.nn.Embedding( embd_weight.shape[0], embd_weight.shape[1] ).to(device=device, dtype=dtype)
 				embds[k].load_state_dict({ "weight": embd_weight })
 
+				vocab = embd_weight.shape[0]
+				start = f_it
+				f_it += vocab
+				end = f_it
+
+				if kludge_export:
+					model.embed_tokens.weight[start:end, :] = embds[k].weight
+
 				if classifier_idx >= 0:
+					# NAR:0:0 does not have a masked token output
+					if k == "resp|NAR:0:0":
+						end -= 1
+
 					head_weight = state_dict[f'classifiers.proj.{classifier_idx}.weight']
 
 					heads[k] = torch.nn.Linear( head_weight.shape[1], head_weight.shape[0], bias=False ).to(device=device, dtype=dtype)
 					heads[k].load_state_dict({ "weight": head_weight })
 
+					if kludge_export:
+						hf_model.lm_head.weight[start:end, :] = heads[k].weight
+
+	if kludge_export:
+		state_dict = hf_model.state_dict()
+		torch_save({ "module": state_dict, "format": "pt" }, "./data/model.safetensors")
+
 	def create_inputs( phn, prom, lang=0, seq=None, mode="AR:0:0" ):
 		rvq_l = mode_lvl_map[mode]
@@ -1637,7 +1673,9 @@ if __name__ == "__main__":
 					break
 				prom_embd += embds[f"prom|{i}"](p)
 
 		if isinstance( seq, list ) and not seq:
-		if seq is not None:
+			...
+		elif seq is not None:
 			if mode == "len":
 				seq_embd = embds["len"](seq)
 			elif mode == "AR:0:0":
@@ -1705,7 +1743,7 @@ if __name__ == "__main__":
 
 
 	# test len inferencing
-	print( "len:", generate( phn, prom, mode="len" ) )
+	#print( "len:", generate( phn, prom, mode="len" ) )
 
 	# test ar ouptut
 	if resp:
@@ -450,7 +450,7 @@ class Base_V2(nn.Module):
 			self.model = None
 		elif self.arch_type in ["llama"]:
 			self.model_config = LlamaConfig(
-				vocab_size=n_vocab,
+				vocab_size=0, # n_vocab,
 				hidden_size=d_model,
 				max_position_embeddings=max_position_embeddings,
 				intermediate_size=d_model*d_ffn,