use homebrewed caching system for dataloader paths / durations (I'm pretty sure I'm now triggering OOM killers with my entire dataset in use)

mrq 2024-11-11 16:32:08 -06:00
parent a748e223ce
commit cf9df71f2c
4 changed files with 63 additions and 16 deletions
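In short: the `cfg.diskcache()` route apparently keeps the full paths/durations result set around (hence the OOM kills once the entire dataset is in play), so this commit instead keys a cache directory off an md5 of the dataset-affecting settings and persists the path list and duration table as plain JSON. A minimal sketch of the scheme (the `cached` helper and its arguments are illustrative, not the repo's API):

import hashlib
import json
from pathlib import Path

def md5_hash(x):
	# recursively hash lists so nested settings collapse to one stable key
	if isinstance(x, list):
		return md5_hash(":".join(md5_hash(v) for v in x))
	return hashlib.md5(str(x).encode("utf-8")).hexdigest()

def cached(cache_dir: Path, key: str, compute):
	# read-through JSON cache: hit on disk, else compute once and persist
	path = cache_dir / f"{key}.json"
	if path.exists():
		return json.loads(path.read_text())
	result = compute()
	cache_dir.mkdir(parents=True, exist_ok=True)
	path.write_text(json.dumps(result))
	return result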

View File

@@ -22,7 +22,7 @@ from pathlib import Path
from .utils.distributed import world_size
from .utils.io import torch_load
from .utils import set_seed, prune_missing
from .utils import set_seed, prune_missing, md5_hash
@dataclass()
class BaseConfig:
@@ -200,6 +200,9 @@ class Dataset:
	_frames_per_second: int = 0 # allows setting your own hint

	def hash_key(self, *args):
		return md5_hash([ self.use_hdf5, self.min_duration, self.max_duration ] + [*args])

	@cached_property
	def frames_per_second(self):
		if self._frames_per_second > 0:
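`hash_key` folds the settings that decide which utterances are valid (`use_hdf5` plus the duration bounds) together with any extra arguments into a single md5, so changing any of those knobs lands the dataloader in a fresh cache directory instead of reusing stale entries. Hypothetical usage (the dataset names are made up):

# any change to use_hdf5, the duration bounds, or the dataset list
# produces a different key, hence a different cache directory
key = cfg.dataset.hash_key(sorted(["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"]))
cached_dir = cfg.cache_dir / key  # e.g. cache/0b1d... (an md5 hex digest)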

View File

@@ -15,7 +15,7 @@ from .config import cfg
from .emb.qnt import trim, trim_random, repeat_extend_audio, concat_audio, merge_audio, decode_to_file, decode as decode_qnt, encode as encode_qnt, pad_codes_with_silence
from .emb.g2p import encode as encode_phns
from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
from .utils.distributed import global_rank, local_rank, world_size
from .utils.distributed import global_rank, local_rank, world_size, is_global_leader
from .utils.io import torch_save, torch_load, json_read, json_write, json_stringify, json_parse
from .utils import setup_logging
@@ -531,14 +531,36 @@ def _get_phone_path(path):
	return _replace_file_extension(path, _get_phone_extension())

_durations_map = {}
# makeshift caching the above to disk
@cfg.diskcache()
def _get_duration_map( type="training" ):
	return _durations_map[type] if type in _durations_map else {}

@cfg.diskcache()
def _load_paths(dataset, type="training", silent=False):
	return { cfg.get_spkr( cfg.data_dir / data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}", disable=silent) }
	cached_dir = cfg.cache_dir / cfg.dataset.hash_key(sorted(dataset))
	cached_durations_path = cached_dir / f"durations[{type}].json"
	cached_paths_path = cached_dir / f"dataloader[{type}].json"

	# load the duration table first, since this is independent from the loaded paths
	if cached_durations_path.exists():
		_durations_map[type] = json_read( cached_durations_path )

	# load the cached valid paths (if we're requesting cache use)
	if cached_paths_path.exists() and cfg.dataset.cache:
		# to-do: automatic conversion between HDF5 formatted paths and on-disk paths
		return json_read( cached_paths_path )

	# deduce valid paths
	paths = { cfg.get_spkr( cfg.data_dir / data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}", disable=silent) }

	# and write if global leader (to avoid other processes writing to the same file at once)
	if is_global_leader():
		if not cached_dir.exists():
			cached_dir.mkdir(parents=True, exist_ok=True)
		json_write( _durations_map[type], cached_durations_path, truncate=True )
		json_write( paths, cached_paths_path, truncate=True )

	return paths

def _load_paths_from_metadata(group_name, type="training", validate=False):
	data_dir = group_name if cfg.dataset.use_hdf5 else cfg.data_dir / group_name
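Note the write is gated on `is_global_leader()`: under multi-GPU training every rank deduces the same paths, but only the leader persists them, so concurrent processes never clobber the same JSON file (non-leader ranks simply pick the cache up on the next run). The same pattern outside this repo might look like this, using `torch.distributed` directly (`write_cache_once` is a stand-in name):

import json
import torch.distributed as dist
from pathlib import Path

def is_global_leader() -> bool:
	# rank 0, or a plain single-process run, counts as the leader
	return not dist.is_initialized() or dist.get_rank() == 0

def write_cache_once(path: Path, payload) -> None:
	# every rank computes the payload, but only the leader writes it,
	# avoiding two processes truncating the same file at once
	if is_global_leader():
		path.parent.mkdir(parents=True, exist_ok=True)
		path.write_text(json.dumps(payload))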
@@ -685,6 +707,10 @@ class Dataset(_Dataset):
		# dict of paths keyed by speaker names
		self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type)

		# do it here due to the above
		self.duration = 0
		self.duration_map = _get_duration_map( self.dataset_type )
		self.duration_buckets = {}

		# cull speakers if they do not have enough utterances
		if cfg.dataset.min_utterances > 0:
@@ -716,11 +742,6 @@ class Dataset(_Dataset):
					self.paths_by_spkr_name[name] = []
				self.paths_by_spkr_name[name].append( path )

		# do it here due to the above
		self.duration = 0
		self.duration_map = _get_duration_map( self.dataset_type )
		self.duration_buckets = {}

		# store in corresponding bucket
		for path in self.paths:
			duration = self.duration_map[path]
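The buckets group utterances by duration so the batched sampler can draw similar-length clips together and keep padding waste down. A rough sketch of the bucketing step, assuming whole-second buckets (the repo's exact rounding may differ):

from collections import defaultdict

def bucket_by_duration(paths, duration_map):
	# key each utterance by its duration in whole seconds; batching
	# within one bucket keeps sequence lengths close together
	buckets = defaultdict(list)
	for path in paths:
		buckets[int(duration_map[path])].append(path)
	return buckets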
@@ -759,7 +780,9 @@ class Dataset(_Dataset):
			# just interleave
			self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]

		# dereference buckets
		self.duration_map = None
		self.duration_buckets = None

		# dict of speakers keyed by speaker group
		self.spkrs_by_spkr_group = {}
@@ -1414,6 +1437,9 @@ def create_train_dataloader():
	_logger.info(f"#samples (train): {len(train_dataset)}.")
	_logger.info(f"#duration (train): {str(train_dataset.duration)}.")

	# remove duration map (it gets bloated)
	_durations_map = {}

	return train_dl
def create_val_dataloader():
@@ -1427,8 +1453,12 @@ def create_val_dataloader():
	_logger.info(f"#samples (val): {len(val_dataset)}.")
	_logger.info(f"#duration (val): {str(val_dataset.duration)}.")

	# remove duration map (it gets bloated)
	_durations_map = {}

	return val_dl

# to-do, use the above two, then create the subtrain dataset
def create_train_val_dataloader():
	train_dataset, val_dataset = create_datasets()
@@ -1456,6 +1486,9 @@ def create_train_val_dataloader():
	assert isinstance(subtrain_dl.dataset, Dataset)

	# remove duration map (it gets bloated)
	_durations_map = {}

	return train_dl, subtrain_dl, val_dl
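One thing worth flagging: unless a `global _durations_map` declaration exists outside these hunks, `_durations_map = {}` inside each function only binds a local name and the bloated module-level dict survives. A minimal illustration of the difference:

_durations_map = {"training": {}}  # module-level cache

def clear_wrong():
	_durations_map = {}  # creates a local; the module-level dict is untouched

def clear_right():
	global _durations_map
	_durations_map = {}  # actually rebinds (and releases) the module-level dict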
# parse metadata from a numpy file (.enc/.dac) and validate it
@@ -1506,7 +1539,7 @@ def remap_speaker_name( name ):
	return name

# parse dataset into better to sample metadata
def create_dataset_metadata( skip_existing=False ):
	symmap = get_phone_symmap()
	root = str(cfg.data_dir)
@@ -1557,6 +1590,8 @@ def create_dataset_metadata( skip_existing=True ):
			qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16)
			utterance_metadata = process_artifact_metadata( artifact )
			# to-do: derive duration from codes if duration is malformed because this happened to me with LibriTTS-R
			#utterance_metadata["duration"] = qnt.shape[0] / cfg.dataset.frames_per_second

			for k, v in utterance_metadata.items():
				metadata[id][k] = v
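The to-do above is cheap to realize: the number of code frames divided by the codec frame rate recovers the duration, which guards against the malformed metadata the author hit with LibriTTS-R. A hedged sketch of that fallback (`safe_duration` is an illustrative name):

def safe_duration(metadata_duration, qnt, frames_per_second):
	# fall back to deriving duration from the quantized codes when the
	# stored value is missing or nonsensical
	if not metadata_duration or metadata_duration <= 0:
		return qnt.shape[0] / frames_per_second
	return metadata_duration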

View File

@@ -13,5 +13,6 @@ from .utils import (
	truncate_json,
	timer,
	prune_missing,
	clamp
	clamp,
	md5_hash
)

View File

@@ -15,6 +15,7 @@ import time
import psutil
import math
import logging
import hashlib
_logger = logging.getLogger(__name__)
@@ -31,6 +32,11 @@ from datetime import datetime

T = TypeVar("T")

def md5_hash( x ):
	if isinstance( x, list ):
		return md5_hash(":".join([ md5_hash( _ ) for _ in x ]))
	return hashlib.md5(str(x).encode("utf-8")).hexdigest()

def prune_missing( source, dest, recurse=True, path=[], parent_is_obj=None, return_missing=True ):
	is_obj = hasattr( source, "__dict__" )
	if parent_is_obj is None:
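Since list elements are hashed individually and then re-hashed after joining with ':', equal inputs always collapse to the same hex digest, and a list stays distinguishable from the joined string it resembles. For example:

>>> md5_hash("a")
'0cc175b9c0f1b6a831c399e269772661'
>>> md5_hash(["a", "b"]) == md5_hash(["a", "b"])
True
>>> md5_hash(["a", "b"]) == md5_hash("a:b")  # elements are hashed before joining
False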
@@ -69,12 +75,14 @@ class timer:
		print(f'[{datetime.now().isoformat()}] {msg}')

def truncate_json( str ):
def truncate_json( x ):
	if isinstance( x, bytes ):
		return truncate_json( x.decode('utf-8') ).encode()

	def fun( match ):
		return "{:.4f}".format(float(match.group()))

	return re.sub(r"\d+\.\d{8,}", fun, str)
	return re.sub(r"\d+\.\d{8,}", fun, x)

def do_gc():
	gc.collect()
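The rewrite also stops shadowing the `str` builtin and now accepts bytes by round-tripping through UTF-8; behavior on strings is unchanged (any float with 8+ decimal places gets rounded to 4). For example:

>>> truncate_json('{"duration": 1.2345678901}')
'{"duration": 1.2346}'
>>> truncate_json(b'{"duration": 1.2345678901}')
b'{"duration": 1.2346}'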