maybe final tweaks; I really needed to unify my JSON read/write, and orjson has proven fast enough for me to rely on it more
This commit is contained in:
parent
6ceed866b5
commit
ebac1db16c
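On the "fast enough" claim: a quick, hypothetical micro-benchmark is one way to sanity-check orjson against the stdlib. The payload shape is invented (loosely modeled on the speaker metadata touched below) and the numbers are machine-dependent:

import json, orjson, time

# invented payload: { utterance_id: { "similar": [int, ...] }, ... }
data = { f"utt_{i}": { "similar": list(range(32)) } for i in range(10000) }

for name, dumps, loads in [
	("stdlib", lambda d: json.dumps(d).encode(), json.loads),
	("orjson", orjson.dumps, orjson.loads),
]:
	t0 = time.perf_counter()
	blob = dumps(data)       # serialize to bytes
	parsed = loads(blob)     # parse back
	print(name, round(time.perf_counter() - t0, 4), "s")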
@@ -579,7 +579,7 @@ class DeepSpeed:
 					del ds_cfg[k]
 
 		if os.path.exists("./data/ds_config.json"):
-			ds_cfg.update(json.load(open("./data/ds_config.json", "r", encoding="utf-8")))
+			ds_cfg.update(json.loads(open("./data/ds_config.json", "r", encoding="utf-8").read()))
 		else:
 			ds_cfg.update(self.config)
 
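Why this call changes shape at all: orjson deliberately ships only dumps()/loads(), so stdlib-style json.load(fp) / json.dump(obj, fp) call sites have to become read-then-parse. A sketch (ds_overrides is just an illustrative name):

import orjson as json

# orjson has no json.load(); read the bytes yourself, then parse
with open("./data/ds_config.json", "rb") as f:
	ds_overrides = json.loads( f.read() )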
@@ -15,7 +15,7 @@ from .emb.qnt import trim, trim_random, repeat_extend_audio, concat_audio, merge
 from .emb.g2p import encode as encode_phns
 from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
 from .utils.distributed import global_rank, local_rank, world_size
-from .utils.io import torch_save, torch_load
+from .utils.io import torch_save, torch_load, json_read, json_write
 
 from collections import defaultdict
 from functools import cache, cached_property
@@ -472,6 +472,7 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
 	metadata = {}
 
 	if cfg.dataset.use_metadata and metadata_path.exists():
+		#metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
 		metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
 
 	if len(metadata) == 0:
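This call site still parses the file by hand; the newly added comment reads like a bookmark for the same swap made further down in this commit. If so, the follow-up would presumably be:

# hypothetical follow-up, mirroring the json_read change in the main() hunk below
metadata = json_read( metadata_path, default={} )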
@@ -886,8 +887,9 @@ class Dataset(_Dataset):
 			return None
 		if len(reference_metadata["similar"]) >= offset:
 			offset = -1
-
-		return reference_metadata["similar"][offset][0]
+		metadata_keys = list(metadata.keys())
+		index = reference_metadata["similar"][offset]
+		return metadata_keys[index]
 
 	def sample_prompts(self, spkr_name, reference, should_trim=True):
 		if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0:
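The replacement implies the "similar" list now stores bare integer indices into the metadata key order rather than (key, score) pairs, hence the extra lookup through metadata_keys. A toy illustration with invented data:

# invented toy data; dict keys keep insertion order (Python 3.7+)
metadata = { "utt_a": {}, "utt_b": {}, "utt_c": {} }
reference_metadata = { "similar": [2, 0] }   # bare indices, not (key, score) pairs

metadata_keys = list(metadata.keys())        # ["utt_a", "utt_b", "utt_c"]
index = reference_metadata["similar"][0]     # 2
print( metadata_keys[index] )                # "utt_c"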
@@ -4,7 +4,6 @@
 """
 
 import os
-import orjson as json
 import argparse
 import torch
 import torchaudio
@@ -23,6 +22,7 @@ import torchaudio.transforms as T
 
 from ..config import cfg
 from ..utils import truncate_json
+from ..utils.io import json_read, json_write
 
 from .g2p import encode as phonemize
 from .qnt import encode as quantize, trim, convert_audio
@@ -255,7 +255,8 @@ def main():
 		faiss.write_index(similarities, str(metadata_path.with_suffix(".faiss")))
 		return
 
-	metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read()) if metadata_path.exists() else {}
+	#metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read()) if metadata_path.exists() else {}
+	metadata = json_read( metadata_path, default={} )
 	metadata_keys = list(metadata.keys()) if metadata else list(similarities.keys())
 
 	for filename, sim in similarities.items():
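The default= parameter is what lets the exists-check ternary collapse into a single call; the equivalence, restated outside the diff:

# before: the caller guards against a missing file
metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read()) if metadata_path.exists() else {}

# after: the helper owns the guard (defined in the utils/io hunk at the end)
metadata = json_read( metadata_path, default={} )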
@@ -264,9 +265,13 @@ def main():
 
 		metadata[filename]["similar"] = sim
 
+	json_write( metadata, metadata_path )
+
+	"""
 	with open(str(metadata_path), "wb") as f:
 		f.write( json.dumps( metadata ) )
 		#f.write( truncate_json( json.dumps( metadata ) ) )
+	"""
 
 	# training
 	for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):
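One small usage note, since the two helpers order their parameters differently: json_read takes the path first, while json_write takes the data first, as the call above does:

metadata = json_read( metadata_path, default={} )  # (path, default)
json_write( metadata, metadata_path )              # (data, path)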
@@ -6,7 +6,7 @@ from transformers.configuration_utils import PretrainedConfig
 
 def load_config_from_json(config_file):
     with open(config_file, 'r') as f:
-        config = json.load(f)
+        config = json.loads(f.read())
         config = RetNetConfig.from_dict(config)
         return config
 
@@ -5,6 +5,33 @@ from pathlib import Path
 from safetensors import safe_open as sft_load
 from safetensors.torch import save_file as sft_save
 
+try:
+	use_orjson = True
+	import orjson as json
+except ImportError:
+	use_orjson = False
+	import json
+
+def json_stringify( data ):
+	return json.dumps( data )
+
+def json_parse( string ):
+	return json.loads( string )
+
+def json_read( path, default=None ):
+	path = coerce_path( path )
+
+	if not path.exists():
+		return default
+
+	# orjson parses bytes, the stdlib json parses str
+	with (open( str(path), "rb" ) if use_orjson else open( str(path), "r", encoding="utf-8" )) as f:
+		return json_parse( f.read() )
+
+def json_write( data, path ):
+	path = coerce_path( path )
+
+	# orjson.dumps() returns bytes, stdlib json.dumps() returns str
+	with (open( str(path), "wb" ) if use_orjson else open( str(path), "w", encoding="utf-8" )) as f:
+		f.write( json_stringify( data ) )
+
+def coerce_path( path ):
+	return path if isinstance( path, Path ) else Path(path)
 
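Taken together, the helpers give every call site the same shape regardless of which backend got imported. A minimal round-trip sketch; the vall_e package name is assumed here, since the diff only shows relative imports:

from pathlib import Path
from vall_e.utils.io import json_read, json_write  # package name assumed

metadata = { "utt_0": { "similar": [3, 1, 4] } }

json_write( metadata, "/tmp/metadata.json" )           # str path is coerced to Path
assert json_read( Path("/tmp/metadata.json") ) == metadata

# missing files fall back to the caller-supplied default
assert json_read( "/tmp/does_not_exist.json", default={} ) == {}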