maybe final tweaks, I really needed to unify my JSON read/write, and orjson has proven fast enough for me to rely on it more
commit ebac1db16c (parent 6ceed866b5)
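The gist of the commit: call sites that each rolled their own `json.loads(open(...).read())` now go through `json_read`/`json_write` from `utils.io`, which prefer orjson when it imports and fall back to the stdlib otherwise. A before/after sketch of a typical call site (the path is illustrative and the `vall_e` package name is assumed):

```python
import json
from pathlib import Path

from vall_e.utils.io import json_read, json_write  # package name assumed

metadata_path = Path("./metadata.json")  # illustrative path

# before: every caller repeated the exists/open/read/parse dance
metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read()) if metadata_path.exists() else {}

# after: one helper owns the open mode (orjson parses bytes), the
# missing-file default, and the parse
metadata = json_read(metadata_path, default={})
json_write(metadata, metadata_path)
```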
@@ -579,7 +579,7 @@ class DeepSpeed:
 				del ds_cfg[k]

 		if os.path.exists("./data/ds_config.json"):
-			ds_cfg.update(json.load(open("./data/ds_config.json", "r", encoding="utf-8")))
+			ds_cfg.update(json.loads(open("./data/ds_config.json", "r", encoding="utf-8").read()))
 		else:
 			ds_cfg.update(self.config)
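Reviewer note: the hunk above keeps the existing precedence, merging `./data/ds_config.json` into `ds_cfg` when the file exists and falling back to `self.config` otherwise; in either case, keys from the later `update` win. A minimal sketch of that `dict.update` precedence, with hypothetical keys:

```python
# Sketch of the merge order in the hunk above (keys are hypothetical).
ds_cfg = {"gradient_accumulation_steps": 1, "steps_per_print": 10}
override = {"gradient_accumulation_steps": 4}  # stand-in for ./data/ds_config.json

ds_cfg.update(override)  # later update wins per key
assert ds_cfg["gradient_accumulation_steps"] == 4
assert ds_cfg["steps_per_print"] == 10       # untouched keys survive
```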
@@ -15,7 +15,7 @@ from .emb.qnt import trim, trim_random, repeat_extend_audio, concat_audio, merge
 from .emb.g2p import encode as encode_phns
 from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
 from .utils.distributed import global_rank, local_rank, world_size
-from .utils.io import torch_save, torch_load
+from .utils.io import torch_save, torch_load, json_read, json_write

 from collections import defaultdict
 from functools import cache, cached_property
@@ -472,6 +472,7 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
 	metadata = {}

 	if cfg.dataset.use_metadata and metadata_path.exists():
+		#metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
 		metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())

 	if len(metadata) == 0:
@@ -886,8 +887,9 @@ class Dataset(_Dataset):
 			return None
 		if len(reference_metadata["similar"]) >= offset:
 			offset = -1
-		return reference_metadata["similar"][offset][0]
+		metadata_keys = list(metadata.keys())
+		index = reference_metadata["similar"][offset]
+		return metadata_keys[index]

 	def sample_prompts(self, spkr_name, reference, should_trim=True):
 		if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0:
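Reviewer note: the lookup changes shape here. The old code returned `similar[offset][0]`, the first field of a stored entry; the new code treats each `similar` entry as a bare integer index into the dataset's metadata keys and resolves it through `metadata_keys`. A sketch with made-up entries:

```python
# Made-up shapes to illustrate the new lookup in the hunk above.
metadata = {
    "speaker/utt_000.wav": {"similar": [2, 1]},  # indices, not names
    "speaker/utt_001.wav": {"similar": [0, 2]},
    "speaker/utt_002.wav": {"similar": [1, 0]},
}
metadata_keys = list(metadata.keys())

reference_metadata = metadata["speaker/utt_000.wav"]
offset = 0
index = reference_metadata["similar"][offset]  # plain int under the new scheme
print(metadata_keys[index])                    # speaker/utt_002.wav
```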
@@ -4,7 +4,6 @@
 """

 import os
-import orjson as json
 import argparse
 import torch
 import torchaudio
@@ -23,6 +22,7 @@ import torchaudio.transforms as T

 from ..config import cfg
 from ..utils import truncate_json
+from ..utils.io import json_read, json_write

 from .g2p import encode as phonemize
 from .qnt import encode as quantize, trim, convert_audio
@@ -255,7 +255,8 @@ def main():
 		faiss.write_index(similarities, str(metadata_path.with_suffix(".faiss")))
 		return

-	metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read()) if metadata_path.exists() else {}
+	#metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read()) if metadata_path.exists() else {}
+	metadata = json_read( metadata_path, default={} )
 	metadata_keys = list(metadata.keys()) if metadata else list(similarities.keys())

 	for filename, sim in similarities.items():
@@ -264,9 +265,13 @@ def main():

 		metadata[filename]["similar"] = sim

+	json_write( metadata, metadata_path )
+
+	"""
 	with open(str(metadata_path), "wb") as f:
 		f.write( json.dumps( metadata ) )
 		#f.write( truncate_json( json.dumps( metadata ) ) )
+	"""

 	# training
 	for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):
@@ -6,7 +6,7 @@ from transformers.configuration_utils import PretrainedConfig

 def load_config_from_json(config_file):
 	with open(config_file, 'r') as f:
-		config = json.load(f)
+		config = json.loads(f.read())
 		config = RetNetConfig.from_dict(config)
 		return config
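Reviewer note: `json.load(f)` only exists in the stdlib module; orjson exposes just `loads`/`dumps`, with no file-object `load`/`dump` wrappers, so reading the file and calling `loads` keeps this function working under either import. A minimal sketch:

```python
# orjson has no load()/dump(); loads() accepts str or bytes, so the
# read-then-parse form works with both the stdlib and orjson.
try:
    import orjson as json
except ImportError:
    import json

with open("config.json", "r") as f:  # illustrative file
    config = json.loads(f.read())    # valid under either import
```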
@@ -5,6 +5,34 @@ from pathlib import Path
 from safetensors import safe_open as sft_load
 from safetensors.torch import save_file as sft_save

+try:
+	import orjson as json
+	use_orjson = True
+except ImportError:
+	import json
+	use_orjson = False
+
+def json_stringify( data ):
+	return json.dumps( data )
+
+def json_parse( string ):
+	return json.loads( string )
+
+def json_read( path, default=None ):
+	path = coerce_path( path )
+
+	if not path.exists():
+		return default
+
+	with (open( str(path), "rb" ) if use_orjson else open( str(path), "r", encoding="utf-8" )) as f:
+		return json_parse( f.read() )
+
+def json_write( data, path ):
+	path = coerce_path( path )
+
+	with (open( str(path), "wb" ) if use_orjson else open( str(path), "w", encoding="utf-8" )) as f:
+		f.write( json_stringify( data ) )
+
 def coerce_path( path ):
 	return path if isinstance( path, Path ) else Path(path)
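Reviewer note on the new helpers: `orjson.dumps` returns bytes and `orjson.loads` accepts bytes, which is why `json_read`/`json_write` switch between binary and text file modes on `use_orjson` (the flag must only be set True after the import succeeds, as above). A round-trip sketch, assuming the module resolves as `vall_e.utils.io` (package name assumed):

```python
from vall_e.utils.io import json_read, json_write  # package name assumed

data = {"epoch": 3, "loss": 0.125}      # hypothetical payload
json_write(data, "/tmp/state.json")     # bytes via orjson, str via stdlib
assert json_read("/tmp/state.json") == data

# callers no longer need their own exists() check
assert json_read("/tmp/missing.json", default={}) == {}
```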