support for wildcards in the training/validation/noise dataset arrays (to-do: a better way to query between the metadata folder and the data folder)

mrq 2024-09-18 21:34:43 -05:00
parent b5bec0c9ce
commit fe241f6a99
3 changed files with 31 additions and 5 deletions
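
In practice this means an entry in the dataset.training / dataset.validation / dataset.noise arrays may now be a glob pattern instead of an explicit path. A hypothetical illustration, written as the Python structure the loader ends up with (all group and speaker names below are made up):

# hypothetical dataset arrays; wildcard entries get expanded against the
# metadata folder (falling back to the data folder) when the config is formatted
dataset = {
    "training":   [ "LibriTTS/*" ],   # glob: one entry per matched speaker
    "validation": [ "LibriTTS/19" ],  # plain path: passed through unchanged
    "noise":      [ "noise/*" ],
}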

View File

@@ -138,7 +138,6 @@ def process(
speaker_id = metadata["speaker"]
outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}')
if _replace_file_extension(outpath, audio_extension).exists():
    continue

View File

@@ -10,6 +10,7 @@ import argparse
import yaml
import random
import logging
import itertools
import torch
import numpy as np
@@ -802,6 +803,27 @@ class Config(BaseConfig):
_logger.warning(f"Error while opening HDF5 file: {self.rel_path}/{self.dataset.hdf5_name}: {str(e)}")
self.dataset.use_hdf5 = False
# a very icky way to handle wildcard expansions
def expand( self, path ):
    if not isinstance( path, Path ):
        path = Path(path)
    # do not glob if the entry has no wildcard in it
    if "*" not in str(path):
        return [ path ]
    # try the metadata folder first, then fall back to the data folder
    metadata_parent = cfg.metadata_dir / path.parent
    data_parent = cfg.data_dir / path.parent
    if metadata_parent.exists():
        # metadata entries are files, so drop the extension with .stem
        return [ path.parent / child.stem for child in Path(metadata_parent).glob(path.name) ]
    if data_parent.exists():
        # data entries are folders, so keep the child name as-is
        return [ path.parent / child.name for child in Path(data_parent).glob(path.name) ]
    # nothing matched: return a single-element list so the flattening step later still works
    return [ path ]
# to-do: prune unused keys
def format( self, training=True ):
    if isinstance(self.dataset, type):
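
For reference, a minimal self-contained sketch of the expansion rule above; the metadata/ and data/ directory names stand in for cfg.metadata_dir and cfg.data_dir, and the example speakers are made up:

from pathlib import Path

def expand(path, metadata_dir=Path("metadata"), data_dir=Path("data")):
    path = Path(path)
    # plain entries pass through untouched, wrapped in a list
    if "*" not in str(path):
        return [path]
    metadata_parent = metadata_dir / path.parent
    data_parent = data_dir / path.parent
    # metadata entries are files, so the extension is dropped via .stem
    if metadata_parent.exists():
        return [path.parent / child.stem for child in metadata_parent.glob(path.name)]
    # data entries are folders, so the child name is kept as-is
    if data_parent.exists():
        return [path.parent / child.name for child in data_parent.glob(path.name)]
    return [path]

# e.g. with metadata/LibriTTS/19.json and metadata/LibriTTS/26.json on disk:
#   expand("LibriTTS/*")  ->  [Path("LibriTTS/19"), Path("LibriTTS/26")]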
@@ -829,9 +851,14 @@ class Config(BaseConfig):
self.optimizations = dict()
self.dataset = Dataset(**self.dataset)
self.dataset.training = [ Path(dir) for dir in self.dataset.training ]
self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
self.dataset.noise = [ Path(dir) for dir in self.dataset.noise ]
# convert to expanded paths
self.dataset.training = [ self.expand(dir) for dir in self.dataset.training ]
self.dataset.validation = [ self.expand(dir) for dir in self.dataset.validation ]
self.dataset.noise = [ self.expand(dir) for dir in self.dataset.noise ]
# flatten
self.dataset.training = list(itertools.chain.from_iterable(self.dataset.training))
self.dataset.validation = list(itertools.chain.from_iterable(self.dataset.validation))
self.dataset.noise = list(itertools.chain.from_iterable(self.dataset.noise))
# do cleanup
for model in self.models:
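
Because expand() returns a list per entry, each dataset array momentarily becomes a list of lists, which itertools.chain.from_iterable flattens back into a single flat list of paths. A small illustration with made-up values:

import itertools
from pathlib import Path

# each entry expands to a list, so the intermediate result is a list of lists...
expanded = [
    [Path("LibriTTS/19"), Path("LibriTTS/26")],  # from the wildcard entry "LibriTTS/*"
    [Path("noise/ambience")],                    # from a plain, non-glob entry
]

# ...which chain.from_iterable joins into one flat list
flat = list(itertools.chain.from_iterable(expanded))
# -> [Path("LibriTTS/19"), Path("LibriTTS/26"), Path("noise/ambience")]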

View File

@@ -1413,7 +1413,7 @@ def create_dataset_metadata( skip_existing=True ):
wrote = False
for id in tqdm(ids, desc=f"Processing {name}"):
for id in tqdm(ids, desc=f"Processing {name}", disable=True):
    try:
        quant_path = Path(f'{root}/{name}/{id}{_get_quant_extension()}')