diff --git a/vall_e/config.py b/vall_e/config.py index 3723413..8383454 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -579,7 +579,7 @@ class DeepSpeed: del ds_cfg[k] if os.path.exists("./data/ds_config.json"): - ds_cfg.update(json.load(open("./data/ds_config.json", "r", encoding="utf-8"))) + ds_cfg.update(json.loads(open("./data/ds_config.json", "r", encoding="utf-8").read())) else: ds_cfg.update(self.config) diff --git a/vall_e/data.py b/vall_e/data.py index 0bd23f6..e0d626d 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -15,7 +15,7 @@ from .emb.qnt import trim, trim_random, repeat_extend_audio, concat_audio, merge from .emb.g2p import encode as encode_phns from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler from .utils.distributed import global_rank, local_rank, world_size -from .utils.io import torch_save, torch_load +from .utils.io import torch_save, torch_load, json_read, json_write from collections import defaultdict from functools import cache, cached_property @@ -472,6 +472,7 @@ def _load_paths_from_metadata(group_name, type="training", validate=False): metadata = {} if cfg.dataset.use_metadata and metadata_path.exists(): + #metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read()) metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read()) if len(metadata) == 0: @@ -886,8 +887,9 @@ class Dataset(_Dataset): return None if len(reference_metadata["similar"]) >= offset: offset = -1 - - return reference_metadata["similar"][offset][0] + metadata_keys = list(metadata.keys()) + index = reference_metadata["similar"][offset] + return metadata_keys[index] def sample_prompts(self, spkr_name, reference, should_trim=True): if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0: diff --git a/vall_e/emb/similar.py b/vall_e/emb/similar.py index 32fb3a2..5f68877 100644 --- a/vall_e/emb/similar.py +++ b/vall_e/emb/similar.py @@ -4,7 +4,6 @@ """ import os 
-import orjson as json import argparse import torch import torchaudio @@ -23,6 +22,7 @@ import torchaudio.transforms as T from ..config import cfg from ..utils import truncate_json +from ..utils.io import json_read, json_write from .g2p import encode as phonemize from .qnt import encode as quantize, trim, convert_audio @@ -255,7 +255,8 @@ def main(): faiss.write_index(similarities, str(metadata_path.with_suffix(".faiss"))) return - metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read()) if metadata_path.exists() else {} + #metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read()) if metadata_path.exists() else {} + metadata = json_read( metadata_path, default={} ) metadata_keys = list(metadata.keys()) if metadata else list(similarities.keys()) for filename, sim in similarities.items(): @@ -264,9 +265,13 @@ def main(): metadata[filename]["similar"] = sim + json_write( metadata, metadata_path ) + + """ with open(str(metadata_path), "wb") as f: f.write( json.dumps( metadata ) ) #f.write( truncate_json( json.dumps( metadata ) ) ) + """ # training for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"): diff --git a/vall_e/ext/retnet_hf/configuration_retnet.py b/vall_e/ext/retnet_hf/configuration_retnet.py index b842606..d409c39 100644 --- a/vall_e/ext/retnet_hf/configuration_retnet.py +++ b/vall_e/ext/retnet_hf/configuration_retnet.py @@ -6,7 +6,7 @@ from transformers.configuration_utils import PretrainedConfig def load_config_from_json(config_file): with open(config_file, 'r') as f: - config = json.load(f) + config = json.loads(f.read()) config = RetNetConfig.from_dict(config) return config diff --git a/vall_e/utils/io.py b/vall_e/utils/io.py index afc2033..cba071f 100644 --- a/vall_e/utils/io.py +++ b/vall_e/utils/io.py @@ -5,6 +5,34 @@ from pathlib import Path from safetensors import safe_open as sft_load from safetensors.torch import save_file as sft_save +try: + use_orjson = True + import orjson as json 
+except ImportError: + use_orjson = False + import json + +def json_stringify( data ): + return json.dumps( data ) + +def json_parse( string ): + return json.loads( string ) + +def json_read( path, default=None ): + path = coerce_path( path ) + + if not path.exists(): + return default + + with (open( str(path), "rb" ) if use_orjson else open( str(path), "r", encoding="utf-8" ) ) as f: + return json_parse( f.read() ) + +def json_write( data, path ): + path = coerce_path( path ) + + with (open( str(path), "wb" ) if use_orjson else open( str(path), "w", encoding="utf-8" ) ) as f: + f.write( json_stringify( data ) ) + def coerce_path( path ): return path if isinstance( path, Path ) else Path(path)