maybe final tweaks, I really needed to unify my json read/write and orjson is proven to be fast enough for me to try and rely on it more

This commit is contained in:
mrq 2024-09-17 22:57:04 -05:00
parent 6ceed866b5
commit ebac1db16c
5 changed files with 41 additions and 7 deletions

View File

@ -579,7 +579,7 @@ class DeepSpeed:
del ds_cfg[k]
if os.path.exists("./data/ds_config.json"):
ds_cfg.update(json.load(open("./data/ds_config.json", "r", encoding="utf-8")))
ds_cfg.update(json.loads(open("./data/ds_config.json", "r", encoding="utf-8")).read())
else:
ds_cfg.update(self.config)

View File

@ -15,7 +15,7 @@ from .emb.qnt import trim, trim_random, repeat_extend_audio, concat_audio, merge
from .emb.g2p import encode as encode_phns
from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
from .utils.distributed import global_rank, local_rank, world_size
from .utils.io import torch_save, torch_load
from .utils.io import torch_save, torch_load, json_read, json_write
from collections import defaultdict
from functools import cache, cached_property
@ -472,6 +472,7 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
metadata = {}
if cfg.dataset.use_metadata and metadata_path.exists():
#metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
if len(metadata) == 0:
@ -886,8 +887,9 @@ class Dataset(_Dataset):
return None
if len(reference_metadata["similar"]) >= offset:
offset = -1
return reference_metadata["similar"][offset][0]
metadata_keys = list(metadata.keys())
index = reference_metadata["similar"][offset]
return metadata_keys[index]
def sample_prompts(self, spkr_name, reference, should_trim=True):
if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0:

View File

@ -4,7 +4,6 @@
"""
import os
import orjson as json
import argparse
import torch
import torchaudio
@ -23,6 +22,7 @@ import torchaudio.transforms as T
from ..config import cfg
from ..utils import truncate_json
from ..utils.io import json_read, json_write
from .g2p import encode as phonemize
from .qnt import encode as quantize, trim, convert_audio
@ -255,7 +255,8 @@ def main():
faiss.write_index(similarities, str(metadata_path.with_suffix(".faiss")))
return
metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read()) if metadata_path.exists() else {}
#metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read()) if metadata_path.exists() else {}
metadata = json_read( metadata_path, default={} )
metadata_keys = list(metadata.keys()) if metadata else list(similarities.keys())
for filename, sim in similarities.items():
@ -264,9 +265,13 @@ def main():
metadata[filename]["similar"] = sim
json_write( metadata, metadata_path )
"""
with open(str(metadata_path), "wb") as f:
f.write( json.dumps( metadata ) )
#f.write( truncate_json( json.dumps( metadata ) ) )
"""
# training
for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):

View File

@ -6,7 +6,7 @@ from transformers.configuration_utils import PretrainedConfig
def load_config_from_json(config_file):
with open(config_file, 'r') as f:
config = json.load(f)
config = json.loads(f.read())
config = RetNetConfig.from_dict(config)
return config

View File

@ -5,6 +5,33 @@ from pathlib import Path
from safetensors import safe_open as sft_load
from safetensors.torch import save_file as sft_save
try:
use_orjson = True
import orjson as json
except:
import json
def json_stringify( data ):
return json.dumps( data )
def json_parse( string ):
return json.loads( string )
def json_read( path, default=None ):
path = coerce_path( path )
if not path.exists():
return default
with (open( str(path), "rb" ) if use_orjson else open( str(path), "r", encoding="utf-8" ) ) as f:
return json_parse( f.read() )
def json_write( data, path ):
path = coerce_path( path )
with (open( str(path), "wb" ) if use_orjson else open( str(path), "w", encoding="utf-8" ) ) as f:
f.write( json_stringify( data ) )
def coerce_path( path ):
return path if isinstance( path, Path ) else Path(path)