diff --git a/vall_e/config.py b/vall_e/config.py
index ec0a7a3..ca570ba 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -22,7 +22,7 @@ from pathlib import Path
 
 from .utils.distributed import world_size
 from .utils.io import torch_load
-from .utils import set_seed, prune_missing
+from .utils import set_seed, prune_missing, md5_hash
 
 @dataclass()
 class BaseConfig:
@@ -200,6 +200,9 @@ class Dataset:
 
 	_frames_per_second: int = 0 # allows setting your own hint
 
+	def hash_key(self, *args):
+		return md5_hash([ self.use_hdf5, self.min_duration, self.max_duration ] + [*args])
+
 	@cached_property
 	def frames_per_second(self):
		if self._frames_per_second > 0:
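For reference: Dataset.hash_key() ties the cache identity to the config fields that decide which paths are valid (use_hdf5, min_duration, max_duration) plus whatever extra arguments the caller folds in, so changing any of them yields a fresh cache directory. A minimal sketch of the derivation, using the md5_hash() helper added in vall_e/utils/utils.py further down; the concrete config values and dataset names here are hypothetical:

import hashlib

def md5_hash( x ):
	# lists are hashed element-wise, then the joined digests are hashed again,
	# so a nested list (like the sorted dataset list) reduces to one stable digest
	if isinstance( x, list ):
		return md5_hash(":".join([ md5_hash( _ ) for _ in x ]))
	return hashlib.md5(str(x).encode("utf-8")).hexdigest()

# equivalent to cfg.dataset.hash_key(sorted(dataset)) with hypothetical values
# use_hdf5=True, min_duration=1.0, max_duration=32.0
key = md5_hash([ True, 1.0, 32.0 ] + [ sorted([ "LibriSpeech/train", "LibriTTS/train-clean-100" ]) ])
print( key ) # a stable 32-character hex digest, used below as a cache directory name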
diff --git a/vall_e/data.py b/vall_e/data.py
index 2d2eba7..eb501f5 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -15,7 +15,7 @@ from .config import cfg
 from .emb.qnt import trim, trim_random, repeat_extend_audio, concat_audio, merge_audio, decode_to_file, decode as decode_qnt, encode as encode_qnt, pad_codes_with_silence
 from .emb.g2p import encode as encode_phns
 from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
-from .utils.distributed import global_rank, local_rank, world_size
+from .utils.distributed import global_rank, local_rank, world_size, is_global_leader
 from .utils.io import torch_save, torch_load, json_read, json_write, json_stringify, json_parse
 from .utils import setup_logging
 
@@ -531,14 +531,36 @@ def _get_phone_path(path):
 	return _replace_file_extension(path, _get_phone_extension())
 
 _durations_map = {}
-# makeshift caching the above to disk
-@cfg.diskcache()
 def _get_duration_map( type="training" ):
 	return _durations_map[type] if type in _durations_map else {}
 
-@cfg.diskcache()
 def _load_paths(dataset, type="training", silent=False):
-	return { cfg.get_spkr( cfg.data_dir / data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}", disable=silent) }
+	cached_dir = cfg.cache_dir / cfg.dataset.hash_key(sorted(dataset))
+
+	cached_durations_path = cached_dir / f"durations[{type}].json"
+	cached_paths_path = cached_dir / f"dataloader[{type}].json"
+
+	# load the duration table first, since it is independent of the loaded paths
+	if cached_durations_path.exists():
+		_durations_map[type] = json_read( cached_durations_path )
+
+	# load the cached valid paths (if we're requesting cache use)
+	if cached_paths_path.exists() and cfg.dataset.cache:
+		# to-do: automatic conversion between HDF5-formatted paths and on-disk paths
+		return json_read( cached_paths_path )
+
+	# deduce valid paths
+	paths = { cfg.get_spkr( cfg.data_dir / data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}", disable=silent) }
+
+	# and write them out only if we're the global leader (avoids multiple processes writing to the same file at once)
+	if is_global_leader():
+		if not cached_dir.exists():
+			cached_dir.mkdir(parents=True, exist_ok=True)
+
+		json_write( _durations_map[type], cached_durations_path, truncate=True )
+		json_write( paths, cached_paths_path, truncate=True )
+
+	return paths
 
 def _load_paths_from_metadata(group_name, type="training", validate=False):
 	data_dir = group_name if cfg.dataset.use_hdf5 else cfg.data_dir / group_name
@@ -685,6 +707,10 @@ class Dataset(_Dataset):
 
 		# dict of paths keyed by speaker names
 		self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type)
+		# initialize these here, since _load_paths() above may have already populated _durations_map from the cache
+		self.duration = 0
+		self.duration_map = _get_duration_map( self.dataset_type )
+		self.duration_buckets = {}
 
 		# cull speakers if they do not have enough utterances
 		if cfg.dataset.min_utterances > 0:
@@ -716,11 +742,6 @@ class Dataset(_Dataset):
 					self.paths_by_spkr_name[name] = []
 				self.paths_by_spkr_name[name].append( path )
 
-		# do it here due to the above
-		self.duration = 0
-		self.duration_map = _get_duration_map( self.dataset_type )
-		self.duration_buckets = {}
-
 		# store in corresponding bucket
 		for path in self.paths:
			duration = self.duration_map[path]
@@ -759,7 +780,9 @@
 
 			# just interleave
 			self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
-
+		# dereference buckets
+		self.duration_map = None
+		self.duration_buckets = None
 
 		# dict of speakers keyed by speaker group
 		self.spkrs_by_spkr_group = {}
@@ -1414,6 +1437,9 @@
 	_logger.info(f"#samples (train): {len(train_dataset)}.")
 	_logger.info(f"#duration (train): {str(train_dataset.duration)}.")
 
+	# clear the module-level duration map (it gets bloated)
+	_durations_map.clear()
+
 	return train_dl
 
 def create_val_dataloader():
@@ -1427,8 +1453,12 @@
 	_logger.info(f"#samples (val): {len(val_dataset)}.")
 	_logger.info(f"#duration (val): {str(val_dataset.duration)}.")
 
+	# clear the module-level duration map (it gets bloated)
+	_durations_map.clear()
+
 	return val_dl
 
+# to-do: use the two helpers above, then create the subtrain dataloader
 def create_train_val_dataloader():
 	train_dataset, val_dataset = create_datasets()
 
@@ -1456,6 +1486,9 @@
 
 	assert isinstance(subtrain_dl.dataset, Dataset)
 
+	# clear the module-level duration map (it gets bloated)
+	_durations_map.clear()
+
 	return train_dl, subtrain_dl, val_dl
 
 # parse metadata from an numpy file (.enc/.dac) and validate it
@@ -1506,7 +1539,7 @@ def remap_speaker_name( name ):
 	return name
 
 # parse dataset into better to sample metadata
-def create_dataset_metadata( skip_existing=True ):
+def create_dataset_metadata( skip_existing=False ):
 	symmap = get_phone_symmap()
 
 	root = str(cfg.data_dir)
@@ -1557,6 +1590,8 @@
 					qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16)
 
 					utterance_metadata = process_artifact_metadata( artifact )
+					# to-do: derive the duration from the codes if the stored duration is malformed (observed with LibriTTS-R)
+					#utterance_metadata["duration"] = qnt.shape[0] / cfg.dataset.frames_per_second
 
 					for k, v in utterance_metadata.items():
						metadata[id][k] = v
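The rewritten _load_paths() above replaces the @cfg.diskcache() decorators with an explicit read-or-rebuild flow: consult the cached JSON first, otherwise rebuild from metadata, and let only one process persist the result. A condensed sketch of that pattern (load_with_cache is a hypothetical helper; json_read, json_write, and is_global_leader are assumed to behave as their counterparts in vall_e.utils.io and vall_e.utils.distributed):

def load_with_cache( build, cache_path ):
	# cache hit: every rank reads the same file
	if cache_path.exists():
		return json_read( cache_path )
	# cache miss: every rank rebuilds in memory...
	result = build()
	# ...but only the global leader writes, so concurrent ranks never
	# clobber the same file; non-leaders just pay the rebuild cost
	if is_global_leader():
		cache_path.parent.mkdir( parents=True, exist_ok=True )
		json_write( result, cache_path, truncate=True )
	return result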
diff --git a/vall_e/utils/__init__.py b/vall_e/utils/__init__.py
index 4c1273b..367e35a 100755
--- a/vall_e/utils/__init__.py
+++ b/vall_e/utils/__init__.py
@@ -13,5 +13,6 @@ from .utils import (
 	truncate_json,
 	timer,
 	prune_missing,
-	clamp
+	clamp,
+	md5_hash
 )
\ No newline at end of file
diff --git a/vall_e/utils/utils.py b/vall_e/utils/utils.py
index db81285..9f85087 100755
--- a/vall_e/utils/utils.py
+++ b/vall_e/utils/utils.py
@@ -15,6 +15,7 @@ import time
 import psutil
 import math
 import logging
+import hashlib
 
 _logger = logging.getLogger(__name__)
 
@@ -31,6 +32,11 @@ from datetime import datetime
 
 T = TypeVar("T")
 
+def md5_hash( x ):
+	if isinstance( x, list ):
+		return md5_hash(":".join([ md5_hash( _ ) for _ in x ]))
+	return hashlib.md5(str(x).encode("utf-8")).hexdigest()
+
 def prune_missing( source, dest, recurse=True, path=[], parent_is_obj=None, return_missing=True ):
 	is_obj = hasattr( source, "__dict__" )
 	if parent_is_obj is None:
@@ -69,12 +75,14 @@ class timer:
 
 		print(f'[{datetime.now().isoformat()}] {msg}')
 
-def truncate_json( str ):
+def truncate_json( x ):
+	if isinstance( x, bytes ):
+		return truncate_json( x.decode('utf-8') ).encode()
 	def fun( match ):
 		return "{:.4f}".format(float(match.group()))
 
-	return re.sub(r"\d+\.\d{8,}", fun, str)
+	return re.sub(r"\d+\.\d{8,}", fun, x)
 
 def do_gc():
 	gc.collect()
 
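Usage sketch for the widened truncate_json() above: the new bytes branch decodes, truncates, then re-encodes, so callers holding raw JSON bytes behave the same as callers holding str. Expected results, given the existing 4-decimal rounding:

truncate_json( '{"duration": 1.234567890123}' )   # -> '{"duration": 1.2346}'
truncate_json( b'{"duration": 1.234567890123}' )  # -> b'{"duration": 1.2346}'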