diff --git a/scripts/hf_test.py b/scripts/hf_test.py new file mode 100644 index 0000000..9b2e1a0 --- /dev/null +++ b/scripts/hf_test.py @@ -0,0 +1,52 @@ +import torch +from transformers import LlamaForCausalLM, LlamaTokenizer + +# tokenizer = LlamaTokenizer.from_pretrained("./training/llama-encodec-ar+nar-len/hf/") +model = LlamaForCausalLM.from_pretrained("./training/llama-encodec-ar+nar-len/hf/") + +phns = [1,85,4,128,26,4,186,4,89,33,25,4,48,4,134,25,52,86,4,34,97,27,11,2] +proms = [ + [780,835,835,835,339,395,798,537,537,537,537,222,76,989,548,65,705,375,261,375,297,503,529,571,707,346,464,862,148,496,574,115,115,438,934,339,865,876,63,40,779,461,602,794,10,220,398,869,639,705,869,917,705,893,215,705,869,938,439,175,139,506,375,529,297,705,651,238,962,461,195,441,377,581,473,795,644,626,459,981,767,670,696,73,779,257,408,1017,1019,133,133,1017,835,604,699,626,67,92,707,92,179,179,772,869,441,799,917,238,745,904,904,904,106,133,1019,1017,1017,395,883,87,519,594,1002,682,996,540,186,1019,430,202,347,889,61,92,542,297,67,669,571,707,346,67,359,571,707,669,604,25,1008,810,35,621,67,600,333,123,284,568,817,243,778,464,638,610,359,538,464,975,321,700,377,484,179,284,284,621,538,464,745,171,171,159,744,159,287,461,69,15,529,67,92,669,464,515,605,24,822,865,293,62,172,638,359,562,138,839,846,775,556,688,1006,917,297,312,148,331,496,646,67,314,15,705,131,855,662,287,172,85,538,519,762,450,391,609,643,778,80,287,794,794,115,785,794,461,699,519,932,522,652,262,508,902,932,932,391,769,18,507,90,442,762,610,610,669,605,310,855,56,989,863,195,464,604,257,904,632,786,951,461,239,195,878,771,146,481,146,481,434,643,917,280,67,464,115,744,744,115,115,115,819,709,63,368,359,519,996,616,464,996,616,519,762,917,841,772,568,954,600,422,893,592,464,626,86,143,615,171,744,744,196,115,821,415,521,799,654,839,644,473,592,953,523,855,738,855,876,876,1017,63,329] +] +sep = [17685] +rvq_lvl = [17666] +lang = [17686] +len_seq = [17674] + +for i, t in enumerate( proms[0] ): + proms[0][i] = t + 256 + 1024 + +ids = torch.tensor(phns + sep + lang + sep + rvq_lvl + sep + proms[0] + sep + len_seq, device="cuda", dtype=torch.int32) +pos_ids = torch.tensor( [*range(len(phns)+1)] + [*range(2)] + [*range(2)] + [*range(len(proms[0])+1)] + [0], device="cuda", dtype=torch.int32) + +start = 17674 # 8448 +end = start + 10 # 1025 + +with torch.no_grad(): + original_lm_head = model.lm_head.weight + + model.lm_head = torch.nn.Linear(1024, end - start, bias=False) + model.lm_head.weight.copy_(original_lm_head[start:end]) + +model.to(device="cuda", dtype=torch.float16) +model.eval() + +n_decoded = 0 +while True: + out = model(input_ids=ids.unsqueeze(0), position_ids=pos_ids.unsqueeze(0)) + + #logits = out.logits[0, -1:, start:end] + logits = out.logits[0, -1:, :] + tokens = logits.argmax(dim=-1) + n_decoded += 1 + + print( n_decoded, tokens ) + + if end in tokens or n_decoded > 5: + break + + ids = torch.concat( [ ids, tokens + start ] ) + pos_ids = torch.concat( [ pos_ids, torch.tensor([n_decoded]).to(pos_ids) ] ) + +print( out ) +print( ids ) \ No newline at end of file diff --git a/vall_e.cpp/README.md b/vall_e.cpp/README.md index 30a3cf9..7035f8e 100644 --- a/vall_e.cpp/README.md +++ b/vall_e.cpp/README.md @@ -10,7 +10,9 @@ Populate `./include/` with the `llama.cpp` and `encodec.cpp` headers. Populate `./libs/` with the compiled libraries of `llama.cpp` and `encodec.cpp`. * `encodec.cpp` requires updating `ggml` to the latest version and doing a quick hack to make it work on the CPU backend. 
-* `llama.cpp` currently requires no hacks, but would be *very* nice to hack in a way to retrieve a model's `tok_embd`. +* `llama.cpp` currently requires no hacks, but: + * would be *very* nice to retrieve a model's `tok_embd` through the API. + * would be ***very*** nice to only specify a slice of the output head through the API. Run `make`. diff --git a/vall_e.cpp/vall_e.cpp b/vall_e.cpp/vall_e.cpp index 9edd8c0..5716e78 100644 --- a/vall_e.cpp/vall_e.cpp +++ b/vall_e.cpp/vall_e.cpp @@ -250,18 +250,19 @@ std::vector decode_audio( struct encodec_context* ectx, const std::vector } const int EMBEDDING_MODE_PROM = 0; -const int EMBEDDING_MODE_RESP_AR_NAR = 0; -const int EMBEDDING_MODE_RESP_NAR_LEN = 0; +const int EMBEDDING_MODE_RESP_AR_NAR = 1; +const int EMBEDDING_MODE_RESP_NAR_LEN = 2; const int INFERENCE_MODE_LEN = 0; const int INFERENCE_MODE_AR = 1; const int INFERENCE_MODE_NAR_DEMASK = 2; -const int INFERENCE_MODE_NAR = 4; +const int INFERENCE_MODE_NAR = 3; const int MODALITY_AR_NAR = 0; -const int MODALITY_NAR_LEN = 0; +const int MODALITY_NAR_LEN = 1; const int MAX_DURATION = 75; // * 12; +const int CTX_SIZE = 2048; // sums embeddings over a 2D "tensor" std::vector> sum_embeddings( const std::vector>& input, int n_embd, int rvq_l, float** embds, int mode = EMBEDDING_MODE_PROM ) { @@ -457,14 +458,14 @@ std::vector generate( llama_context* ctx, llama_model* model, llama int main(int argc, char ** argv) { // to-do: replace all of this with proper loading code int32_t ngl = 0; - int modality = MODALITY_AR_NAR; + int modality = MODALITY_NAR_LEN; input_t input{}; embeddings_t embeddings_map{}; // input.phonemes = "hˈɛloː ʋˈɔrlt"; input.phn = {1,85,4,128,26,4,186,4,89,33,25,4,48,4,134,25,52,86,4,34,97,27,11,2}; // hˈɛloː ʋˈɔrlt - std::string vall_e_model_path = "./data/vall_e-F16.gguf"; + std::string vall_e_model_path = "./data/vall_e-f16.gguf"; std::string encodec_model_path = "./data/encodec.bin"; std::string input_prompt_path = "./data/prom.wav"; std::string output_response_path = "./data/resp.wav"; @@ -497,9 +498,9 @@ int main(int argc, char ** argv) { // initialize the context llama_context_params ctx_params = llama_context_default_params(); - ctx_params.n_ctx = 22500; - ctx_params.n_batch = 22500; - ctx_params.n_ubatch = 22500; + ctx_params.n_ctx = CTX_SIZE; + ctx_params.n_batch = CTX_SIZE; + ctx_params.n_ubatch = CTX_SIZE; ctx_params.no_perf = false; ctx_params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; @@ -519,7 +520,7 @@ int main(int argc, char ** argv) { llama_sampler_chain_add(smpl_ar, llama_sampler_init_top_k(20)); llama_sampler_chain_add(smpl_ar, llama_sampler_init_top_p(0.9, 20)); llama_sampler_chain_add(smpl_ar, llama_sampler_init_temp (1.0)); - // llama_sampler_chain_add(smpl_ar, llama_sampler_init_dist (1130)); + llama_sampler_chain_add(smpl_ar, llama_sampler_init_dist (1130)); llama_sampler_chain_add(smpl_nar, llama_sampler_init_greedy()); @@ -542,13 +543,13 @@ int main(int argc, char ** argv) { if ( input.phonemes != "" ) { const int n_prompt = -llama_tokenize(model, input.phonemes.c_str(), input.phonemes.size(), NULL, 0, true, true); // allocate space for the tokens and tokenize the input.phonemes - input.phns.resize(n_prompt) - if (llama_tokenize(model, input.phonemes.c_str(), input.phonemes.size(), input.phns.data(), input.phns.size(), true, true) < 0) { + input.phn.resize(n_prompt); + if (llama_tokenize(model, input.phonemes.c_str(), input.phonemes.size(), input.phn.data(), input.phn.size(), true, true) < 0) { fprintf(stderr, "%s: error: failed to tokenize: %s\n", 
__func__, input.phonemes.c_str()); return 1; } - for ( auto& token : input.phns ) printf("%i ", token ); + for ( auto& token : input.phn ) printf("%i ", token ); printf("\n"); } diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index 089da0b..b5c2b38 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -114,6 +114,7 @@ def load_engines(training=True, **model_kwargs): "lr": cfg.hyperparameters.learning_rate, } + if cfg.hyperparameters.optimizer.lower() == "adamw": params["betas"] = (0.9, 0.96) params["eps"] = 1e-07 diff --git a/vall_e/export.py b/vall_e/export.py index d0dbe69..d740316 100755 --- a/vall_e/export.py +++ b/vall_e/export.py @@ -71,20 +71,25 @@ def convert_to_hf( state_dict, config = None, save_path = None ): "stt", ] - classifier_bias = False + classifier_bias = "classifiers.proj.0.bias" in state_dict['module'] # cfg.model.experimental.classifiers_bias + split_classifiers = "classifiers.proj.0.weight" in state_dict['module'] # cfg.model.experimental.split_classifiers embedding = torch.nn.Embedding( n_tokens, model_dim ) classifier = torch.nn.Linear( model_dim, n_tokens, bias=classifier_bias ) + if not split_classifiers: + classifier.weight[:] = state_dict['module']['classifier.weight'][:] + # to-do: ignore classifier for RVQ level 7 # inject text tokens token_start = 0 token_end = l_tokens[0] embedding.weight[token_start:token_end] = state_dict['module']['text_emb.weight'] - classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.9.weight'] - if classifier_bias: - classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.9.bias'] + if split_classifiers: + classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.9.weight'] + if classifier_bias: + classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.9.bias'] # tokenizer already has these tokens # inject prom tokens @@ -104,9 +109,10 @@ def convert_to_hf( state_dict, config = None, save_path = None ): token_start = token_end token_end += l_tokens[2] // 2 embedding.weight[token_start:token_end] = state_dict['module'][f'resps_emb.embeddings.0.weight'] - classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.0.weight'] - if classifier_bias: - classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.0.bias'] + if split_classifiers: + classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.0.weight'] + if classifier_bias: + classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.0.bias'] for t in range(n_audio_tokens): tokenizer_vocab[f'<|AR|0:0|{t}|>'] = token_start + t tokenizer_vocab[f''] = token_start + 1024 @@ -115,9 +121,10 @@ def convert_to_hf( state_dict, config = None, save_path = None ): token_start = token_end token_end += l_tokens[2] // 2 embedding.weight[token_start:token_end] = state_dict['module'][f'resps_emb.embeddings.8.weight'] - classifier.weight[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.weight'] - if classifier_bias: - classifier.bias[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.bias'] + if split_classifiers: + classifier.weight[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.weight'] + if classifier_bias: + classifier.bias[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.bias'] for t in range(n_audio_tokens): tokenizer_vocab[f'<|NAR|0:0|{t}|>'] = token_start + t 
tokenizer_vocab[f'<|NAR|0:0|STOP|>'] = token_start + 1024 @@ -129,9 +136,10 @@ def convert_to_hf( state_dict, config = None, save_path = None ): start = token_start + ((l-1) * n_audio_tokens) end = start + n_audio_tokens embedding.weight[start:end] = state_dict['module'][f'resps_emb.embeddings.{l}.weight'] - classifier.weight[start:end] = state_dict['module'][f'classifiers.proj.{l}.weight'] - if classifier_bias: - classifier.bias[start:end] = state_dict['module'][f'classifiers.proj.{l}.bias'] + if split_classifiers: + classifier.weight[start:end] = state_dict['module'][f'classifiers.proj.{l}.weight'] + if classifier_bias: + classifier.bias[start:end] = state_dict['module'][f'classifiers.proj.{l}.bias'] for t in range(n_audio_tokens): tokenizer_vocab[f'<|NAR|{l-1}:{l}|{t}|>'] = start + t @@ -147,9 +155,10 @@ def convert_to_hf( state_dict, config = None, save_path = None ): token_start = token_end token_end += l_tokens[5] embedding.weight[token_start:token_end] = state_dict['module'][f'len_emb.weight'] - classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.10.weight'][0:n_len_tokens] # erroneously sized as 256 - if classifier_bias: - classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.10.bias'][0:n_len_tokens] # erroneously sized as 256 + if split_classifiers: + classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.10.weight'][0:n_len_tokens] # erroneously sized as 256 + if classifier_bias: + classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.10.bias'][0:n_len_tokens] # erroneously sized as 256 for t in range(n_len_tokens): tokenizer_vocab[f'<|len:{t}|>'] = token_start + t @@ -197,7 +206,7 @@ def convert_to_hf( state_dict, config = None, save_path = None ): out_dir = cfg.rel_path / "hf" out_dir.mkdir(parents=True, exist_ok=True) # write weights - torch_save( model_dict, out_dir / "model.safetensors" ) + torch_save( { "module": model_dict, "format": "pt" }, out_dir / "model.safetensors" ) # write tokenizer.json tokenizer['model']['vocab'] |= tokenizer_vocab json_write(tokenizer, out_dir / "tokenizer.json", pretty=True) diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 9747507..0d47257 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -55,6 +55,38 @@ task_outputs = { "len": "len", } +# yuck +def _get_offsets(): + return { + "text": 0, # + "quant_level": 17666, # <|RVQ:0> + "len": 17674, # <|len:0|> + "lang": 17686, # <|lang:en|>" + "task": 17692, # <|task:tts|> + "sep": 17685, # <|sep|> + "prom": [ + 256 + (1024 * 0), # <|P|0:0|> + 256 + (1024 * 1), # <|P|1:0|> + 256 + (1024 * 2), # <|P|2:0|> + 256 + (1024 * 3), # <|P|3:0|> + 256 + (1024 * 4), # <|P|4:0|> + 256 + (1024 * 5), # <|P|5:0|> + 256 + (1024 * 6), # <|P|6:0|> + 256 + (1024 * 7), # <|P|7:0|> + ], + "resp": [ + 8448, # <|AR|0:0|> + 9473, # <|NAR|0:0|> + 10498 + (1024 * 0), # <|NAR|0:1|> + 10498 + (1024 * 1), # <|NAR|1:2|> + 10498 + (1024 * 2), # <|NAR|2:3|> + 10498 + (1024 * 3), # <|NAR|3:4|> + 10498 + (1024 * 4), # <|NAR|4:5|> + 10498 + (1024 * 5), # <|NAR|5:6|> + 10498 + (1024 * 6), # <|NAR|6:7|> + ] + } + def _dropout_mask( input, p=None ): # cosine scheduling if p is None: @@ -494,6 +526,9 @@ class Base(nn.Module): classifier_l_tokens += [ 11 ] classifier_l_names += ["len"] + n_vocab = 17701 if not split_classifiers else n_resp_tokens + 1 + + self.n_vocab = n_vocab self.unified_position_ids = unified_position_ids self.interleave = interleave self.layerskip = layerskip @@ -601,7 +636,7 @@ 
class Base(nn.Module): elif self.arch_type in ["mistral", "mixtral"]: if n_experts <= 1: self.model = MistralModel(MistralConfig( - vocab_size=n_resp_tokens, + vocab_size=n_vocab, hidden_size=d_model, max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds intermediate_size=d_model*4, @@ -647,7 +682,7 @@ class Base(nn.Module): if n_experts <= 1: self.model = LlamaClass(LlamaConfig( - vocab_size=n_resp_tokens, + vocab_size=n_vocab, hidden_size=d_model, max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds intermediate_size=d_model*4, @@ -700,7 +735,7 @@ class Base(nn.Module): )) elif self.arch_type == "retnet": kwargs = dict( - vocab_size=n_resp_tokens, + vocab_size=n_vocab, decoder_embed_dim=d_model, decoder_value_embed_dim =d_model * 2, decoder_retention_heads=n_heads, @@ -732,7 +767,7 @@ class Base(nn.Module): self.model = RetNetDecoder(RetNetConfig(**kwargs)) elif self.arch_type in ["mamba2"]: self.model = Mamba2Model(Mamba2Config( - vocab_size=n_resp_tokens, + vocab_size=n_vocab, hidden_size=d_model, expand=2, num_hidden_layers=n_layers*2, @@ -744,7 +779,7 @@ class Base(nn.Module): )) elif self.arch_type in ["mamba"]: self.model = MambaModel(MambaConfig( - vocab_size=n_resp_tokens, + vocab_size=n_vocab, hidden_size=d_model, expand=2, num_hidden_layers=n_layers*2, @@ -761,11 +796,11 @@ class Base(nn.Module): del self.model.embeddings if not split_classifiers: - self.classifier = nn.Linear(d_model, n_resp_tokens) + self.classifier = nn.Linear(d_model, n_vocab, bias=classifiers_bias) self.classifiers = None self.accuracy_metric = MulticlassAccuracy( - n_resp_tokens, + n_vocab, top_k=10, average="micro", multidim_average="global", @@ -773,7 +808,7 @@ class Base(nn.Module): ) self.precision_metric = MulticlassPrecision( - n_resp_tokens, + n_vocab, top_k=10, average="micro", multidim_average="global", @@ -1031,6 +1066,48 @@ class Base(nn.Module): raise Exception(f'Unrecognized task: {task_type}') return inputs + def offset_inputs( + self, + inputs: list, + direction: int = 1, # -1 to de-offset + ): + offsets = _get_offsets() + + for batch_index, batch_input in enumerate(inputs): + quant_level = None + classifier_level = None + # pre-iterate + for name, input in batch_input: + if name == "quant_level": + quant_level = input + elif name == "classifier_level": + classifier_level = input + + for name, input in batch_input: + if name not in offsets: + continue + + if not isinstance( input, torch.Tensor ): + continue + + offset = offsets[name] + if name in ["prom", "resp"]: + l = quant_level + if name == "resp": + if classifier_level == "AR:0:0": + l = 0 + elif classifier_level == "NAR:0:0": + l = 1 + else: + l = 2 + (quant_level-1) + + offset = offset[l] + + for i, t in enumerate( input ): + input[i] += offset * direction + + return inputs + def inputs_to_embeddings( self, inputs: list, @@ -1366,6 +1443,49 @@ class Base(nn.Module): if not isinstance(token, torch.Tensor): continue + # offset to flattened vocab ranges + if self.classifier is not None: + offsets = _get_offsets() + if name in offsets: + offset = offsets[name] + # yes there's a better way + if name == "prom": + offset = offset[quant_level] + elif name == "resp": + """ + if classifier_level == "AR:0:0": + offset = offset[0] + elif classifier_level == "NAR:0:0": + offset = offset[1] + elif classifier_level == "NAR:0:1": + offset = offset[2] + elif classifier_level == "NAR:1:2": + offset = offset[3] + elif classifier_level == "NAR:2:3": + offset = offset[4] + elif classifier_level == "NAR:3:4": + offset = offset[5] + 
elif classifier_level == "NAR:4:5": + offset = offset[6] + elif classifier_level == "NAR:5:6": + offset = offset[7] + elif classifier_level == "NAR:6:7": + offset = offset[8] + else: + continue + """ + if classifier_level == "AR:0:0": + offset = offset[0] + elif classifier_level == "NAR:0:0": + offset = offset[1] + else: + offset = offset[2 + (quant_level-1)] + + for i, t in enumerate( token ): + if t == self.ignore_index: + continue + token[i] += offset + if token.is_floating_point(): ignored = True @@ -1422,7 +1542,7 @@ class Base(nn.Module): # perofrm loss calculation on the entire sequence if not self.config.loss_factors: - target = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) ) + target = _join( target, torch.tensor(self.ignore_index if self.classifier is None else 17685, device=target[-1].device) ) logit = logits[batch_index] # shift if causal @@ -1606,6 +1726,37 @@ class Base(nn.Module): self.loss = None self.stats = None + + # de-offset if needed + if self.classifier is not None: + offsets = _get_offsets() + for batch_index, classifier_level in enumerate( classifier_levels ): + # yes there's a better way + if classifier_level == "len": + offset = offsets["len"], 11 + elif classifier_level == "AR:0:0": + offset = offsets["resp"][0], 1025 + elif classifier_level == "NAR:0:0": + offset = offsets["resp"][1], 1024 + elif classifier_level == "NAR:0:1": + offset = offsets["resp"][2], 1024 + elif classifier_level == "NAR:1:2": + offset = offsets["resp"][3], 1024 + elif classifier_level == "NAR:2:3": + offset = offsets["resp"][4], 1024 + elif classifier_level == "NAR:3:4": + offset = offsets["resp"][5], 1024 + elif classifier_level == "NAR:4:5": + offset = offsets["resp"][6], 1024 + elif classifier_level == "NAR:5:6": + offset = offsets["resp"][7], 1024 + elif classifier_level == "NAR:6:7": + offset = offsets["resp"][8], 1024 + else: + continue + + logits[batch_index] = logits[batch_index][offset[0]:offset[0]+offset[1], :] + else: loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels ) diff --git a/vall_e/utils/io.py b/vall_e/utils/io.py index 7ab73bb..5b1239c 100644 --- a/vall_e/utils/io.py +++ b/vall_e/utils/io.py @@ -60,14 +60,13 @@ def is_dict_of( d, t ): # handles converting the usual pth state_dict into just the dict with the tensors + a dict of JSON strings, for safetensors def state_dict_to_tensor_metadata( data: dict, module_key=None ): - metadata = None + metadata = {} # is a state_dict, no need to coerce if is_dict_of( data, torch.Tensor ): return data, metadata # is maybe a dict with a state dict + metadata, coerce it - metadata = {} target = module_key if not target: for k, v in data.items(): @@ -78,7 +77,8 @@ def state_dict_to_tensor_metadata( data: dict, module_key=None ): # not a dict of tensors, put it as metadata try: - metadata[k] = json.dumps(v) + metadata[k] = json_stringify(v) if any([isinstance( v, dict ), isinstance( v, list )]) else v + if isinstance( metadata[k], bytes ): metadata[k] = metadata[k].decode('utf-8') except Exception as e: @@ -96,6 +96,9 @@ def torch_save( data, path, module_key=None ): if ext in [".safetensor", ".safetensors", ".sft"]: data, metadata = state_dict_to_tensor_metadata( data, module_key=module_key ) + if metadata is None: + metadata = {} + return sft_save( data, path, metadata ) return torch.save( data, path ) @@ -112,13 +115,12 @@ def torch_load( path, device="cpu", framework="pt", unsafe=True, load_metadata=T if load_metadata: metadata = f.metadata() - if metadata is not None: 
-					for k, v in metadata.items():
-						try:
-							metadata[k] = json.loads( v )
-						except Exception as e:
-							pass
-					state_dict = { module_key: state_dict } | metadata
+				for k, v in metadata.items():
+					try:
+						metadata[k] = json.loads( v )
+					except Exception as e:
+						pass
+				state_dict = { module_key: state_dict } | metadata
 
 	return state_dict
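The patch above works by flattening every per-level embedding and classifier into a single vocabulary (`n_vocab = 17701` in the diff), with `_get_offsets()` recording where each input type starts in that flattened space and `offset_inputs()` / the loss path shifting tokens into and out of it. A minimal sketch of that mapping follows, assuming only the offset values shown in the diff; the helper names below are hypothetical and not part of the repo:

```python
# Sketch only: mirrors the offsets hard-coded in _get_offsets()
# (vall_e/models/base.py). Helper names are illustrative, not repo API.
OFFSETS = {
	"text": 0,
	"quant_level": 17666,
	"len": 17674,
	"sep": 17685,
	"lang": 17686,
	"task": 17692,
	"prom": [256 + 1024 * l for l in range(8)],    # <|P|0:0|> .. <|P|7:0|>
	"resp": [8448, 9473]                           # <|AR|0:0|>, <|NAR|0:0|>
		+ [10498 + 1024 * l for l in range(7)],    # <|NAR|0:1|> .. <|NAR|6:7|>
}

def to_flattened( name: str, token: int, level: int = 0 ) -> int:
	"""Offset a raw token id into the flattened vocab (direction=1 in offset_inputs)."""
	offset = OFFSETS[name]
	if isinstance( offset, list ):
		offset = offset[level]
	return token + offset

def from_flattened( name: str, token: int, level: int = 0 ) -> int:
	"""Undo the offset (direction=-1 in offset_inputs)."""
	offset = OFFSETS[name]
	if isinstance( offset, list ):
		offset = offset[level]
	return token - offset
```

For example, `to_flattened("prom", 780, level=1)` yields `780 + 256 + 1024`, the same shift `scripts/hf_test.py` applies to its prompt tokens, while the `start = 17674` that script uses to slice the `lm_head` down to the duration logits is just `OFFSETS["len"]`.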