use homebrewed caching system for dataloader paths / durations (I'm pretty sure I'm now triggering the OOM killer when the entire dataset is used)
parent a748e223ce
commit cf9df71f2c
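In short: the parsed dataloader paths and the per-path duration table are now cached as plain JSON files under a directory named by an md5 hash of the dataset settings plus the sorted group list; only the global leader writes the files, and the in-memory duration map is emptied once the dataloaders are built. The snippet below is an illustrative, self-contained sketch of that pattern, not code from the repo; md5_key, load_or_build_paths, build_paths and their arguments are hypothetical stand-ins.

import hashlib, json
from pathlib import Path

def md5_key(parts):
	# mirrors the commit's md5_hash for a flat list: hash each element, join, hash again
	digests = [hashlib.md5(str(p).encode("utf-8")).hexdigest() for p in parts]
	return hashlib.md5(":".join(digests).encode("utf-8")).hexdigest()

def load_or_build_paths(cache_dir: Path, settings, groups, build_paths, is_leader=True):
	# the cache key folds in every setting that affects which paths are valid
	cached = cache_dir / md5_key(list(settings) + sorted(groups)) / "dataloader[training].json"
	if cached.exists():
		return json.loads(cached.read_text())   # fast path: reuse the cached table
	paths = build_paths(groups)                  # slow path: walk the metadata on disk
	if is_leader:                                # only one process writes the file
		cached.parent.mkdir(parents=True, exist_ok=True)
		cached.write_text(json.dumps(paths))
	return paths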
@@ -22,7 +22,7 @@ from pathlib import Path
 from .utils.distributed import world_size
 from .utils.io import torch_load
-from .utils import set_seed, prune_missing
+from .utils import set_seed, prune_missing, md5_hash
 
 @dataclass()
 class BaseConfig:
@@ -200,6 +200,9 @@ class Dataset:
 
 	_frames_per_second: int = 0 # allows setting your own hint
 
+	def hash_key(self, *args):
+		return md5_hash([ self.use_hdf5, self.min_duration, self.max_duration ] + [*args])
+
 	@cached_property
 	def frames_per_second(self):
 		if self._frames_per_second > 0:
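Not part of the diff — an illustration of what Dataset.hash_key buys: any setting that changes which utterances are valid also changes the cache key, so a stale cache directory is never silently reused. The values below are made up.

# hypothetical settings: (use_hdf5, min_duration, max_duration) plus the sorted group list
key_a = md5_hash([ False, 0.0, 12.0 ] + [ sorted(["group_a", "group_b"]) ])
key_b = md5_hash([ False, 0.0, 32.0 ] + [ sorted(["group_a", "group_b"]) ])
assert key_a != key_b   # different max_duration -> different cache directory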
@@ -15,7 +15,7 @@ from .config import cfg
 from .emb.qnt import trim, trim_random, repeat_extend_audio, concat_audio, merge_audio, decode_to_file, decode as decode_qnt, encode as encode_qnt, pad_codes_with_silence
 from .emb.g2p import encode as encode_phns
 from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
-from .utils.distributed import global_rank, local_rank, world_size
+from .utils.distributed import global_rank, local_rank, world_size, is_global_leader
 from .utils.io import torch_save, torch_load, json_read, json_write, json_stringify, json_parse
 from .utils import setup_logging
 
@@ -531,14 +531,36 @@ def _get_phone_path(path):
 	return _replace_file_extension(path, _get_phone_extension())
 
 _durations_map = {}
 # makeshift caching the above to disk
-@cfg.diskcache()
 def _get_duration_map( type="training" ):
 	return _durations_map[type] if type in _durations_map else {}
 
-@cfg.diskcache()
 def _load_paths(dataset, type="training", silent=False):
-	return { cfg.get_spkr( cfg.data_dir / data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}", disable=silent) }
+	cached_dir = cfg.cache_dir / cfg.dataset.hash_key(sorted(dataset))
+
+	cached_durations_path = cached_dir / f"durations[{type}].json"
+	cached_paths_path = cached_dir / f"dataloader[{type}].json"
+
+	# load the duration table first, since this is independent from the loaded paths
+	if cached_durations_path.exists():
+		_durations_map[type] = json_read( cached_durations_path )
+
+	# load the cached valid paths (if we're requesting cache use)
+	if cached_paths_path.exists() and cfg.dataset.cache:
+		# to-do: automatic conversion between HDF5 formatted paths and on-disk paths
+		return json_read( cached_paths_path )
+
+	# deduce valid paths
+	paths = { cfg.get_spkr( cfg.data_dir / data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}", disable=silent) }
+
+	# and write if global leader (to avoid other processes writing to the same file at once)
+	if is_global_leader():
+		if not cached_dir.exists():
+			cached_dir.mkdir(parents=True, exist_ok=True)
+
+		json_write( _durations_map[type], cached_durations_path, truncate=True )
+		json_write( paths, cached_paths_path, truncate=True )
+
+	return paths
 
 def _load_paths_from_metadata(group_name, type="training", validate=False):
 	data_dir = group_name if cfg.dataset.use_hdf5 else cfg.data_dir / group_name
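Not part of the diff — a hypothetical illustration of the on-disk layout the new _load_paths produces and how it would be read back; .cache stands in for cfg.cache_dir and the digest is made up.

from pathlib import Path
import json

cache_root = Path(".cache")                     # stand-in for cfg.cache_dir
key = "3f2c9d41aa0bdeadbeef"                    # stand-in for cfg.dataset.hash_key(sorted(dataset))
paths_file = cache_root / key / "dataloader[training].json"
durations_file = cache_root / key / "durations[training].json"

if paths_file.exists() and durations_file.exists():
	paths_table = json.loads(paths_file.read_text())    # speaker name -> list of utterance paths
	durations = json.loads(durations_file.read_text())  # utterance path -> duration
	print(sum(len(v) for v in paths_table.values()), "cached utterances")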
@@ -685,6 +707,10 @@ class Dataset(_Dataset):
 
 		# dict of paths keyed by speaker names
 		self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type)
+		# do it here due to the above
+		self.duration = 0
+		self.duration_map = _get_duration_map( self.dataset_type )
+		self.duration_buckets = {}
 
 		# cull speakers if they do not have enough utterances
 		if cfg.dataset.min_utterances > 0:
@@ -716,11 +742,6 @@
 					self.paths_by_spkr_name[name] = []
 				self.paths_by_spkr_name[name].append( path )
 
-		# do it here due to the above
-		self.duration = 0
-		self.duration_map = _get_duration_map( self.dataset_type )
-		self.duration_buckets = {}
-
 		# store in corresponding bucket
 		for path in self.paths:
 			duration = self.duration_map[path]
@@ -759,7 +780,9 @@
 			# just interleave
 			self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
 
-
+		# dereference buckets
+		self.duration_map = None
+		self.duration_buckets = None
 
 		# dict of speakers keyed by speaker group
 		self.spkrs_by_spkr_group = {}
@@ -1414,6 +1437,9 @@ def create_train_dataloader():
 	_logger.info(f"#samples (train): {len(train_dataset)}.")
 	_logger.info(f"#duration (train): {str(train_dataset.duration)}.")
 
+	# remove duration map (it gets bloated)
+	_durations_map = {}
+
 	return train_dl
 
 def create_val_dataloader():
@@ -1427,8 +1453,12 @@ def create_val_dataloader():
 	_logger.info(f"#samples (val): {len(val_dataset)}.")
 	_logger.info(f"#duration (val): {str(val_dataset.duration)}.")
 
+	# remove duration map (it gets bloated)
+	_durations_map = {}
+
 	return val_dl
 
 # to-do, use the above two, then create the subtrain dataset
 def create_train_val_dataloader():
 	train_dataset, val_dataset = create_datasets()
 
@@ -1456,6 +1486,9 @@ def create_train_val_dataloader():
 
 	assert isinstance(subtrain_dl.dataset, Dataset)
 
+	# remove duration map (it gets bloated)
+	_durations_map = {}
+
 	return train_dl, subtrain_dl, val_dl
 
 # parse metadata from an numpy file (.enc/.dac) and validate it
@@ -1506,7 +1539,7 @@ def remap_speaker_name( name ):
 	return name
 
 # parse dataset into better to sample metadata
-def create_dataset_metadata( skip_existing=True ):
+def create_dataset_metadata( skip_existing=False ):
 	symmap = get_phone_symmap()
 
 	root = str(cfg.data_dir)
@@ -1557,6 +1590,8 @@ def create_dataset_metadata( skip_existing=True ):
 				qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16)
 
 			utterance_metadata = process_artifact_metadata( artifact )
+			# to-do: derive duration from codes if duration is malformed because this happened to me with LibriTTS-R
+			#utterance_metadata["duration"] = qnt.shape[0] / cfg.dataset.frames_per_second
 
 			for k, v in utterance_metadata.items():
 				metadata[id][k] = v
@@ -13,5 +13,6 @@ from .utils import (
 	truncate_json,
 	timer,
 	prune_missing,
-	clamp
+	clamp,
+	md5_hash
 )
@@ -15,6 +15,7 @@ import time
 import psutil
 import math
 import logging
+import hashlib
 
 _logger = logging.getLogger(__name__)
 
@@ -31,6 +32,11 @@ from datetime import datetime
 
 T = TypeVar("T")
 
+def md5_hash( x ):
+	if isinstance( x, list ):
+		return md5_hash(":".join([ md5_hash( _ ) for _ in x ]))
+	return hashlib.md5(str(x).encode("utf-8")).hexdigest()
+
 def prune_missing( source, dest, recurse=True, path=[], parent_is_obj=None, return_missing=True ):
 	is_obj = hasattr( source, "__dict__" )
 	if parent_is_obj is None:
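Not part of the diff — a quick illustration of why md5_hash works as a cache key, given the definition above: unlike Python's built-in hash(), it is stable across processes (no hash randomization), and lists are hashed recursively so element order matters, which is why _load_paths sorts the dataset first.

import hashlib

assert md5_hash("foo") == hashlib.md5(b"foo").hexdigest()   # scalars hash via str()
assert md5_hash([1, 2]) == md5_hash([1, 2])                 # deterministic across runs and ranks
assert md5_hash([1, 2]) != md5_hash([2, 1])                 # order-sensitive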
@@ -69,12 +75,14 @@ class timer:
 
 		print(f'[{datetime.now().isoformat()}] {msg}')
 
-def truncate_json( str ):
+def truncate_json( x ):
+	if isinstance( x, bytes ):
+		return truncate_json( x.decode('utf-8') ).encode()
+
 	def fun( match ):
 		return "{:.4f}".format(float(match.group()))
 
-	return re.sub(r"\d+\.\d{8,}", fun, str)
+	return re.sub(r"\d+\.\d{8,}", fun, x)
 
 def do_gc():
 	gc.collect()
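Not part of the diff — a small usage illustration of the updated truncate_json given the definition above: long floats are clipped to four decimal places, and the new branch lets bytes round-trip.

s = '{"duration": 2.123456789012, "speaker": "a"}'
assert truncate_json(s) == '{"duration": 2.1235, "speaker": "a"}'
assert truncate_json(s.encode()) == b'{"duration": 2.1235, "speaker": "a"}'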