agony

parent 2542ed067d
commit 353e478e68

scripts/hf_test.py (new file, 52 lines)
@@ -0,0 +1,52 @@
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

# tokenizer = LlamaTokenizer.from_pretrained("./training/llama-encodec-ar+nar-len/hf/")
model = LlamaForCausalLM.from_pretrained("./training/llama-encodec-ar+nar-len/hf/")

phns = [1,85,4,128,26,4,186,4,89,33,25,4,48,4,134,25,52,86,4,34,97,27,11,2]
proms = [
    [780,835,835,835,339,395,798,537,537,537,537,222,76,989,548,65,705,375,261,375,297,503,529,571,707,346,464,862,148,496,574,115,115,438,934,339,865,876,63,40,779,461,602,794,10,220,398,869,639,705,869,917,705,893,215,705,869,938,439,175,139,506,375,529,297,705,651,238,962,461,195,441,377,581,473,795,644,626,459,981,767,670,696,73,779,257,408,1017,1019,133,133,1017,835,604,699,626,67,92,707,92,179,179,772,869,441,799,917,238,745,904,904,904,106,133,1019,1017,1017,395,883,87,519,594,1002,682,996,540,186,1019,430,202,347,889,61,92,542,297,67,669,571,707,346,67,359,571,707,669,604,25,1008,810,35,621,67,600,333,123,284,568,817,243,778,464,638,610,359,538,464,975,321,700,377,484,179,284,284,621,538,464,745,171,171,159,744,159,287,461,69,15,529,67,92,669,464,515,605,24,822,865,293,62,172,638,359,562,138,839,846,775,556,688,1006,917,297,312,148,331,496,646,67,314,15,705,131,855,662,287,172,85,538,519,762,450,391,609,643,778,80,287,794,794,115,785,794,461,699,519,932,522,652,262,508,902,932,932,391,769,18,507,90,442,762,610,610,669,605,310,855,56,989,863,195,464,604,257,904,632,786,951,461,239,195,878,771,146,481,146,481,434,643,917,280,67,464,115,744,744,115,115,115,819,709,63,368,359,519,996,616,464,996,616,519,762,917,841,772,568,954,600,422,893,592,464,626,86,143,615,171,744,744,196,115,821,415,521,799,654,839,644,473,592,953,523,855,738,855,876,876,1017,63,329]
]
sep = [17685]
rvq_lvl = [17666]
lang = [17686]
len_seq = [17674]

for i, t in enumerate( proms[0] ):
    proms[0][i] = t + 256 + 1024

ids = torch.tensor(phns + sep + lang + sep + rvq_lvl + sep + proms[0] + sep + len_seq, device="cuda", dtype=torch.int32)
pos_ids = torch.tensor( [*range(len(phns)+1)] + [*range(2)] + [*range(2)] + [*range(len(proms[0])+1)] + [0], device="cuda", dtype=torch.int32)

start = 17674 # 8448
end = start + 10 # 1025

with torch.no_grad():
    original_lm_head = model.lm_head.weight

    model.lm_head = torch.nn.Linear(1024, end - start, bias=False)
    model.lm_head.weight.copy_(original_lm_head[start:end])

    model.to(device="cuda", dtype=torch.float16)
    model.eval()

    n_decoded = 0
    while True:
        out = model(input_ids=ids.unsqueeze(0), position_ids=pos_ids.unsqueeze(0))

        #logits = out.logits[0, -1:, start:end]
        logits = out.logits[0, -1:, :]
        tokens = logits.argmax(dim=-1)
        n_decoded += 1

        print( n_decoded, tokens )

        if end in tokens or n_decoded > 5:
            break

        ids = torch.concat( [ ids, tokens + start ] )
        pos_ids = torch.concat( [ pos_ids, torch.tensor([n_decoded]).to(pos_ids) ] )

    print( out )
    print( ids )

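For reference, a minimal way to exercise the script above once the HF export exists; the paths and the CUDA requirement come straight from the script itself:

# assumes ./training/llama-encodec-ar+nar-len/hf/ was produced by the HF export step
# and that a CUDA device is available, since the script hardcodes device="cuda"
python3 scripts/hf_test.py
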
@@ -10,7 +10,9 @@ Populate `./include/` with the `llama.cpp` and `encodec.cpp` headers.
 
 Populate `./libs/` with the compiled libraries of `llama.cpp` and `encodec.cpp`.
 * `encodec.cpp` requires updating `ggml` to the latest version and doing a quick hack to make it work on the CPU backend.
-* `llama.cpp` currently requires no hacks, but would be *very* nice to hack in a way to retrieve a model's `tok_embd`.
+* `llama.cpp` currently requires no hacks, but:
+  * would be *very* nice to retrieve a model's `tok_embd` through the API.
+  * would be ***very*** nice to only specify a slice of the output head through the API.
 
 Run `make`.
 
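For reference, a minimal sketch of the setup this README section describes; the exact header and library file names are assumptions and depend on how `llama.cpp` and `encodec.cpp` were built:

# sketch only: stage headers and prebuilt libraries, then build
mkdir -p include libs
cp /path/to/llama.cpp/include/llama.h include/
cp /path/to/encodec.cpp/encodec.h include/
cp /path/to/llama.cpp/build/src/libllama.so libs/
cp /path/to/encodec.cpp/build/libencodec.a libs/
make
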
@@ -250,18 +250,19 @@ std::vector<float> decode_audio( struct encodec_context* ectx, const std::vector
 }
 
 const int EMBEDDING_MODE_PROM = 0;
-const int EMBEDDING_MODE_RESP_AR_NAR = 0;
-const int EMBEDDING_MODE_RESP_NAR_LEN = 0;
+const int EMBEDDING_MODE_RESP_AR_NAR = 1;
+const int EMBEDDING_MODE_RESP_NAR_LEN = 2;
 
 const int INFERENCE_MODE_LEN = 0;
 const int INFERENCE_MODE_AR = 1;
 const int INFERENCE_MODE_NAR_DEMASK = 2;
-const int INFERENCE_MODE_NAR = 4;
+const int INFERENCE_MODE_NAR = 3;
 
 const int MODALITY_AR_NAR = 0;
-const int MODALITY_NAR_LEN = 0;
+const int MODALITY_NAR_LEN = 1;
 
 const int MAX_DURATION = 75; // * 12;
+const int CTX_SIZE = 2048;
 
 // sums embeddings over a 2D "tensor"
 std::vector<std::vector<float>> sum_embeddings( const std::vector<std::vector<llama_token>>& input, int n_embd, int rvq_l, float** embds, int mode = EMBEDDING_MODE_PROM ) {
@@ -457,14 +458,14 @@ std::vector<llama_token> generate( llama_context* ctx, llama_model* model, llama
 int main(int argc, char ** argv) {
     // to-do: replace all of this with proper loading code
     int32_t ngl = 0;
-    int modality = MODALITY_AR_NAR;
+    int modality = MODALITY_NAR_LEN;
     input_t input{};
     embeddings_t embeddings_map{};
 
     // input.phonemes = "hˈɛloː ʋˈɔrlt";
     input.phn = {1,85,4,128,26,4,186,4,89,33,25,4,48,4,134,25,52,86,4,34,97,27,11,2}; // <bos>hˈɛloː ʋˈɔrlt</eos>
 
-    std::string vall_e_model_path = "./data/vall_e-F16.gguf";
+    std::string vall_e_model_path = "./data/vall_e-f16.gguf";
     std::string encodec_model_path = "./data/encodec.bin";
     std::string input_prompt_path = "./data/prom.wav";
     std::string output_response_path = "./data/resp.wav";
 
@@ -497,9 +498,9 @@ int main(int argc, char ** argv) {
 
     // initialize the context
     llama_context_params ctx_params = llama_context_default_params();
-    ctx_params.n_ctx = 22500;
-    ctx_params.n_batch = 22500;
-    ctx_params.n_ubatch = 22500;
+    ctx_params.n_ctx = CTX_SIZE;
+    ctx_params.n_batch = CTX_SIZE;
+    ctx_params.n_ubatch = CTX_SIZE;
     ctx_params.no_perf = false;
     ctx_params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL;
 
@@ -519,7 +520,7 @@ int main(int argc, char ** argv) {
     llama_sampler_chain_add(smpl_ar, llama_sampler_init_top_k(20));
     llama_sampler_chain_add(smpl_ar, llama_sampler_init_top_p(0.9, 20));
     llama_sampler_chain_add(smpl_ar, llama_sampler_init_temp (1.0));
-    // llama_sampler_chain_add(smpl_ar, llama_sampler_init_dist (1130));
+    llama_sampler_chain_add(smpl_ar, llama_sampler_init_dist (1130));
 
     llama_sampler_chain_add(smpl_nar, llama_sampler_init_greedy());
 
@@ -542,13 +543,13 @@ int main(int argc, char ** argv) {
     if ( input.phonemes != "" ) {
         const int n_prompt = -llama_tokenize(model, input.phonemes.c_str(), input.phonemes.size(), NULL, 0, true, true);
         // allocate space for the tokens and tokenize the input.phonemes
-        input.phns.resize(n_prompt)
-        if (llama_tokenize(model, input.phonemes.c_str(), input.phonemes.size(), input.phns.data(), input.phns.size(), true, true) < 0) {
+        input.phn.resize(n_prompt);
+        if (llama_tokenize(model, input.phonemes.c_str(), input.phonemes.size(), input.phn.data(), input.phn.size(), true, true) < 0) {
             fprintf(stderr, "%s: error: failed to tokenize: %s\n", __func__, input.phonemes.c_str());
             return 1;
         }
 
-        for ( auto& token : input.phns ) printf("%i ", token );
+        for ( auto& token : input.phn ) printf("%i ", token );
         printf("\n");
     }
 
@@ -114,6 +114,7 @@ def load_engines(training=True, **model_kwargs):
         "lr": cfg.hyperparameters.learning_rate,
     }
 
+
     if cfg.hyperparameters.optimizer.lower() == "adamw":
         params["betas"] = (0.9, 0.96)
         params["eps"] = 1e-07
@@ -71,20 +71,25 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
         "stt",
     ]
 
-    classifier_bias = False
+    classifier_bias = "classifiers.proj.0.bias" in state_dict['module'] # cfg.model.experimental.classifiers_bias
+    split_classifiers = "classifiers.proj.0.weight" in state_dict['module'] # cfg.model.experimental.split_classifiers
 
     embedding = torch.nn.Embedding( n_tokens, model_dim )
     classifier = torch.nn.Linear( model_dim, n_tokens, bias=classifier_bias )
 
+    if not split_classifiers:
+        classifier.weight[:] = state_dict['module']['classifier.weight'][:]
+
+    # to-do: ignore classifier for RVQ level 7
+
     # inject text tokens
     token_start = 0
     token_end = l_tokens[0]
     embedding.weight[token_start:token_end] = state_dict['module']['text_emb.weight']
-    classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.9.weight']
-    if classifier_bias:
-        classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.9.bias']
+    if split_classifiers:
+        classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.9.weight']
+        if classifier_bias:
+            classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.9.bias']
     # tokenizer already has these tokens
 
     # inject prom tokens
@@ -104,9 +109,10 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
     token_start = token_end
     token_end += l_tokens[2] // 2
     embedding.weight[token_start:token_end] = state_dict['module'][f'resps_emb.embeddings.0.weight']
-    classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.0.weight']
-    if classifier_bias:
-        classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.0.bias']
+    if split_classifiers:
+        classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.0.weight']
+        if classifier_bias:
+            classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.0.bias']
     for t in range(n_audio_tokens):
         tokenizer_vocab[f'<|AR|0:0|{t}|>'] = token_start + t
     tokenizer_vocab[f'<AR|0:0|STOP|>'] = token_start + 1024
@@ -115,9 +121,10 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
     token_start = token_end
     token_end += l_tokens[2] // 2
     embedding.weight[token_start:token_end] = state_dict['module'][f'resps_emb.embeddings.8.weight']
-    classifier.weight[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.weight']
-    if classifier_bias:
-        classifier.bias[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.bias']
+    if split_classifiers:
+        classifier.weight[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.weight']
+        if classifier_bias:
+            classifier.bias[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.bias']
     for t in range(n_audio_tokens):
         tokenizer_vocab[f'<|NAR|0:0|{t}|>'] = token_start + t
     tokenizer_vocab[f'<|NAR|0:0|STOP|>'] = token_start + 1024
@@ -129,9 +136,10 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
         start = token_start + ((l-1) * n_audio_tokens)
         end = start + n_audio_tokens
         embedding.weight[start:end] = state_dict['module'][f'resps_emb.embeddings.{l}.weight']
-        classifier.weight[start:end] = state_dict['module'][f'classifiers.proj.{l}.weight']
-        if classifier_bias:
-            classifier.bias[start:end] = state_dict['module'][f'classifiers.proj.{l}.bias']
+        if split_classifiers:
+            classifier.weight[start:end] = state_dict['module'][f'classifiers.proj.{l}.weight']
+            if classifier_bias:
+                classifier.bias[start:end] = state_dict['module'][f'classifiers.proj.{l}.bias']
         for t in range(n_audio_tokens):
            tokenizer_vocab[f'<|NAR|{l-1}:{l}|{t}|>'] = start + t
 
@@ -147,9 +155,10 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
     token_start = token_end
     token_end += l_tokens[5]
     embedding.weight[token_start:token_end] = state_dict['module'][f'len_emb.weight']
-    classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.10.weight'][0:n_len_tokens] # erroneously sized as 256
-    if classifier_bias:
-        classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.10.bias'][0:n_len_tokens] # erroneously sized as 256
+    if split_classifiers:
+        classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.10.weight'][0:n_len_tokens] # erroneously sized as 256
+        if classifier_bias:
+            classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.10.bias'][0:n_len_tokens] # erroneously sized as 256
     for t in range(n_len_tokens):
         tokenizer_vocab[f'<|len:{t}|>'] = token_start + t
 
@@ -197,7 +206,7 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
     out_dir = cfg.rel_path / "hf"
     out_dir.mkdir(parents=True, exist_ok=True)
     # write weights
-    torch_save( model_dict, out_dir / "model.safetensors" )
+    torch_save( { "module": model_dict, "format": "pt" }, out_dir / "model.safetensors" )
     # write tokenizer.json
     tokenizer['model']['vocab'] |= tokenizer_vocab
     json_write(tokenizer, out_dir / "tokenizer.json", pretty=True)
@@ -55,6 +55,38 @@ task_outputs = {
     "len": "len",
 }
 
+# yuck
+def _get_offsets():
+    return {
+        "text": 0, # <unk>
+        "quant_level": 17666, # <|RVQ:0>
+        "len": 17674, # <|len:0|>
+        "lang": 17686, # <|lang:en|>"
+        "task": 17692, # <|task:tts|>
+        "sep": 17685, # <|sep|>
+        "prom": [
+            256 + (1024 * 0), # <|P|0:0|>
+            256 + (1024 * 1), # <|P|1:0|>
+            256 + (1024 * 2), # <|P|2:0|>
+            256 + (1024 * 3), # <|P|3:0|>
+            256 + (1024 * 4), # <|P|4:0|>
+            256 + (1024 * 5), # <|P|5:0|>
+            256 + (1024 * 6), # <|P|6:0|>
+            256 + (1024 * 7), # <|P|7:0|>
+        ],
+        "resp": [
+            8448, # <|AR|0:0|>
+            9473, # <|NAR|0:0|>
+            10498 + (1024 * 0), # <|NAR|0:1|>
+            10498 + (1024 * 1), # <|NAR|1:2|>
+            10498 + (1024 * 2), # <|NAR|2:3|>
+            10498 + (1024 * 3), # <|NAR|3:4|>
+            10498 + (1024 * 4), # <|NAR|4:5|>
+            10498 + (1024 * 5), # <|NAR|5:6|>
+            10498 + (1024 * 6), # <|NAR|6:7|>
+        ]
+    }
+
 def _dropout_mask( input, p=None ):
     # cosine scheduling
     if p is None:
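For reference, a small worked example of how these offsets flatten per-level token values into the single exported vocabulary (token values chosen arbitrarily):

offsets = _get_offsets()
print( 500 + offsets["prom"][2] )  # prom token 500 at RVQ level 2 -> 500 + 256 + 1024*2 = 2804
print( 42 + offsets["resp"][0] )   # AR:0:0 response token 42 -> 42 + 8448 = 8490
print( 7 + offsets["len"] )        # len digit 7 -> 7 + 17674 = 17681
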
@@ -494,6 +526,9 @@ class Base(nn.Module):
             classifier_l_tokens += [ 11 ]
             classifier_l_names += ["len"]
 
+        n_vocab = 17701 if not split_classifiers else n_resp_tokens + 1
+
+        self.n_vocab = n_vocab
         self.unified_position_ids = unified_position_ids
         self.interleave = interleave
         self.layerskip = layerskip
@@ -601,7 +636,7 @@ class Base(nn.Module):
         elif self.arch_type in ["mistral", "mixtral"]:
             if n_experts <= 1:
                 self.model = MistralModel(MistralConfig(
-                    vocab_size=n_resp_tokens,
+                    vocab_size=n_vocab,
                     hidden_size=d_model,
                     max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds
                     intermediate_size=d_model*4,
@@ -647,7 +682,7 @@ class Base(nn.Module):
 
             if n_experts <= 1:
                 self.model = LlamaClass(LlamaConfig(
-                    vocab_size=n_resp_tokens,
+                    vocab_size=n_vocab,
                     hidden_size=d_model,
                     max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds
                     intermediate_size=d_model*4,
@@ -700,7 +735,7 @@ class Base(nn.Module):
             ))
         elif self.arch_type == "retnet":
             kwargs = dict(
-                vocab_size=n_resp_tokens,
+                vocab_size=n_vocab,
                 decoder_embed_dim=d_model,
                 decoder_value_embed_dim =d_model * 2,
                 decoder_retention_heads=n_heads,
@@ -732,7 +767,7 @@ class Base(nn.Module):
             self.model = RetNetDecoder(RetNetConfig(**kwargs))
         elif self.arch_type in ["mamba2"]:
             self.model = Mamba2Model(Mamba2Config(
-                vocab_size=n_resp_tokens,
+                vocab_size=n_vocab,
                 hidden_size=d_model,
                 expand=2,
                 num_hidden_layers=n_layers*2,
@@ -744,7 +779,7 @@ class Base(nn.Module):
             ))
         elif self.arch_type in ["mamba"]:
             self.model = MambaModel(MambaConfig(
-                vocab_size=n_resp_tokens,
+                vocab_size=n_vocab,
                 hidden_size=d_model,
                 expand=2,
                 num_hidden_layers=n_layers*2,
@@ -761,11 +796,11 @@ class Base(nn.Module):
             del self.model.embeddings
 
         if not split_classifiers:
-            self.classifier = nn.Linear(d_model, n_resp_tokens)
+            self.classifier = nn.Linear(d_model, n_vocab, bias=classifiers_bias)
             self.classifiers = None
 
         self.accuracy_metric = MulticlassAccuracy(
-            n_resp_tokens,
+            n_vocab,
             top_k=10,
             average="micro",
             multidim_average="global",
@@ -773,7 +808,7 @@ class Base(nn.Module):
         )
 
         self.precision_metric = MulticlassPrecision(
-            n_resp_tokens,
+            n_vocab,
             top_k=10,
             average="micro",
             multidim_average="global",
@@ -1031,6 +1066,48 @@ class Base(nn.Module):
             raise Exception(f'Unrecognized task: {task_type}')
         return inputs
 
+    def offset_inputs(
+        self,
+        inputs: list,
+        direction: int = 1, # -1 to de-offset
+    ):
+        offsets = _get_offsets()
+
+        for batch_index, batch_input in enumerate(inputs):
+            quant_level = None
+            classifier_level = None
+            # pre-iterate
+            for name, input in batch_input:
+                if name == "quant_level":
+                    quant_level = input
+                elif name == "classifier_level":
+                    classifier_level = input
+
+            for name, input in batch_input:
+                if name not in offsets:
+                    continue
+
+                if not isinstance( input, torch.Tensor ):
+                    continue
+
+                offset = offsets[name]
+                if name in ["prom", "resp"]:
+                    l = quant_level
+                    if name == "resp":
+                        if classifier_level == "AR:0:0":
+                            l = 0
+                        elif classifier_level == "NAR:0:0":
+                            l = 1
+                        else:
+                            l = 2 + (quant_level-1)
+
+                    offset = offset[l]
+
+                for i, t in enumerate( input ):
+                    input[i] += offset * direction
+
+        return inputs
+
     def inputs_to_embeddings(
         self,
         inputs: list,
@@ -1366,6 +1443,49 @@ class Base(nn.Module):
             if not isinstance(token, torch.Tensor):
                 continue
 
+            # offset to flattened vocab ranges
+            if self.classifier is not None:
+                offsets = _get_offsets()
+                if name in offsets:
+                    offset = offsets[name]
+                    # yes there's a better way
+                    if name == "prom":
+                        offset = offset[quant_level]
+                    elif name == "resp":
+                        """
+                        if classifier_level == "AR:0:0":
+                            offset = offset[0]
+                        elif classifier_level == "NAR:0:0":
+                            offset = offset[1]
+                        elif classifier_level == "NAR:0:1":
+                            offset = offset[2]
+                        elif classifier_level == "NAR:1:2":
+                            offset = offset[3]
+                        elif classifier_level == "NAR:2:3":
+                            offset = offset[4]
+                        elif classifier_level == "NAR:3:4":
+                            offset = offset[5]
+                        elif classifier_level == "NAR:4:5":
+                            offset = offset[6]
+                        elif classifier_level == "NAR:5:6":
+                            offset = offset[7]
+                        elif classifier_level == "NAR:6:7":
+                            offset = offset[8]
+                        else:
+                            continue
+                        """
+                        if classifier_level == "AR:0:0":
+                            offset = offset[0]
+                        elif classifier_level == "NAR:0:0":
+                            offset = offset[1]
+                        else:
+                            offset = offset[2 + (quant_level-1)]
+
+                    for i, t in enumerate( token ):
+                        if t == self.ignore_index:
+                            continue
+                        token[i] += offset
+
             if token.is_floating_point():
                 ignored = True
 
@@ -1422,7 +1542,7 @@ class Base(nn.Module):
 
             # perform loss calculation on the entire sequence
             if not self.config.loss_factors:
-                target = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
+                target = _join( target, torch.tensor(self.ignore_index if self.classifier is None else 17685, device=target[-1].device) )
                 logit = logits[batch_index]
 
                 # shift if causal
@@ -1606,6 +1726,37 @@ class Base(nn.Module):
 
             self.loss = None
             self.stats = None
+
+            # de-offset if needed
+            if self.classifier is not None:
+                offsets = _get_offsets()
+                for batch_index, classifier_level in enumerate( classifier_levels ):
+                    # yes there's a better way
+                    if classifier_level == "len":
+                        offset = offsets["len"], 11
+                    elif classifier_level == "AR:0:0":
+                        offset = offsets["resp"][0], 1025
+                    elif classifier_level == "NAR:0:0":
+                        offset = offsets["resp"][1], 1024
+                    elif classifier_level == "NAR:0:1":
+                        offset = offsets["resp"][2], 1024
+                    elif classifier_level == "NAR:1:2":
+                        offset = offsets["resp"][3], 1024
+                    elif classifier_level == "NAR:2:3":
+                        offset = offsets["resp"][4], 1024
+                    elif classifier_level == "NAR:3:4":
+                        offset = offsets["resp"][5], 1024
+                    elif classifier_level == "NAR:4:5":
+                        offset = offsets["resp"][6], 1024
+                    elif classifier_level == "NAR:5:6":
+                        offset = offsets["resp"][7], 1024
+                    elif classifier_level == "NAR:6:7":
+                        offset = offsets["resp"][8], 1024
+                    else:
+                        continue
+
+                    logits[batch_index] = logits[batch_index][offset[0]:offset[0]+offset[1], :]
+
         else:
             loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
 
|
@ -60,14 +60,13 @@ def is_dict_of( d, t ):
|
|||
|
||||
# handles converting the usual pth state_dict into just the dict with the tensors + a dict of JSON strings, for safetensors
|
||||
def state_dict_to_tensor_metadata( data: dict, module_key=None ):
|
||||
metadata = None
|
||||
metadata = {}
|
||||
|
||||
# is a state_dict, no need to coerce
|
||||
if is_dict_of( data, torch.Tensor ):
|
||||
return data, metadata
|
||||
|
||||
# is maybe a dict with a state dict + metadata, coerce it
|
||||
metadata = {}
|
||||
target = module_key
|
||||
if not target:
|
||||
for k, v in data.items():
|
||||
|
@@ -78,7 +77,8 @@ def state_dict_to_tensor_metadata( data: dict, module_key=None ):
 
         # not a dict of tensors, put it as metadata
         try:
-            metadata[k] = json.dumps(v)
+            metadata[k] = json_stringify(v) if any([isinstance( v, dict ), isinstance( v, list )]) else v
 
+            if isinstance( metadata[k], bytes ):
+                metadata[k] = metadata[k].decode('utf-8')
         except Exception as e:
@@ -96,6 +96,9 @@ def torch_save( data, path, module_key=None ):
     if ext in [".safetensor", ".safetensors", ".sft"]:
         data, metadata = state_dict_to_tensor_metadata( data, module_key=module_key )
 
+        if metadata is None:
+            metadata = {}
+
         return sft_save( data, path, metadata )
 
     return torch.save( data, path )
@@ -112,13 +115,12 @@ def torch_load( path, device="cpu", framework="pt", unsafe=True, load_metadata=T
 
     if load_metadata:
         metadata = f.metadata()
-        if metadata is not None:
-            for k, v in metadata.items():
-                try:
-                    metadata[k] = json.loads( v )
-                except Exception as e:
-                    pass
-            state_dict = { module_key: state_dict } | metadata
+        for k, v in metadata.items():
+            try:
+                metadata[k] = json.loads( v )
+            except Exception as e:
+                pass
+        state_dict = { module_key: state_dict } | metadata
 
     return state_dict
 
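For reference, a rough sketch of how the two helpers above round-trip a checkpoint plus JSON-able metadata through safetensors; the import path and the example key names are assumptions, only `torch_save` / `torch_load` and the module-key handling come from the code above:

import torch
from vall_e.utils.io import torch_save, torch_load  # assumed import path

state_dict = {
    "module": { "weight": torch.zeros(4, 4) },  # tensors become the safetensors payload
    "config": { "dim": 4 },                     # non-tensor entries are JSON-stringified into the file's metadata
}

torch_save( state_dict, "./data/example.sft" )

loaded = torch_load( "./data/example.sft" )
print( loaded["module"]["weight"].shape )  # torch.Size([4, 4])
print( loaded["config"] )                  # {'dim': 4}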