support for wildcards in the training/validation/noise dataset arrays (to-do: a better way to query between the metadata folder and the data folder)
This commit is contained in:
parent b5bec0c9ce
commit fe241f6a99
@@ -138,7 +138,6 @@ def process(
     speaker_id = metadata["speaker"]
     outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}')
 
-
     if _replace_file_extension(outpath, audio_extension).exists():
         continue
 
@@ -10,6 +10,7 @@ import argparse
 import yaml
 import random
 import logging
+import itertools
 
 import torch
 import numpy as np
@@ -802,6 +803,27 @@ class Config(BaseConfig):
             _logger.warning(f"Error while opening HDF5 file: {self.rel_path}/{self.dataset.hdf5_name}: {str(e)}")
             self.dataset.use_hdf5 = False
 
+    # a very icky way to handle wildcard expansions
+    def expand( self, path ):
+        if not isinstance( path, Path ):
+            path = Path(path)
+
+        # do not glob
+        if "*" not in str(path):
+            return [ path ]
+
+        metadata_parent = cfg.metadata_dir / path.parent
+        data_parent = cfg.data_dir / path.parent
+
+        if metadata_parent.exists():
+            return [ path.parent / child.stem for child in Path(metadata_parent).glob(path.name) ]
+
+        if data_parent.exists():
+            return [ path.parent / child.name for child in Path(data_parent).glob(path.name) ]
+
+        return path
+
     # to-do: prune unused keys
     def format( self, training=True ):
         if isinstance(self.dataset, type):
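For reference, a minimal standalone sketch of the expansion behaviour the new expand() method implements: plain entries pass through untouched, wildcard entries are globbed against the metadata folder first and the data folder second. The directory constants and the example corpus name below are hypothetical stand-ins for cfg.metadata_dir / cfg.data_dir, not the repository's actual config object; the fall-through also returns a one-element list for consistency, whereas the diff's version returns the bare path.

from pathlib import Path

# hypothetical stand-ins for cfg.metadata_dir / cfg.data_dir
METADATA_DIR = Path("./training/metadata")
DATA_DIR = Path("./training/data")

def expand(path):
    # expand a possibly wildcarded dataset entry into a list of concrete paths
    path = Path(path)

    # plain entries are returned as-is (wrapped in a list so callers can flatten uniformly)
    if "*" not in str(path):
        return [path]

    metadata_parent = METADATA_DIR / path.parent
    if metadata_parent.exists():
        # metadata files are keyed by stem, so the extension is stripped
        return [path.parent / child.stem for child in metadata_parent.glob(path.name)]

    data_parent = DATA_DIR / path.parent
    if data_parent.exists():
        # data folders keep their full name
        return [path.parent / child.name for child in data_parent.glob(path.name)]

    # nothing matched; keep the original entry
    return [path]

# e.g. expand("some_corpus/*") might yield [Path("some_corpus/speaker_a"), Path("some_corpus/speaker_b")]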
@@ -829,9 +851,14 @@ class Config(BaseConfig):
         self.optimizations = dict()
 
         self.dataset = Dataset(**self.dataset)
-        self.dataset.training = [ Path(dir) for dir in self.dataset.training ]
-        self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
-        self.dataset.noise = [ Path(dir) for dir in self.dataset.noise ]
+        # convert to expanded paths
+        self.dataset.training = [ self.expand(dir) for dir in self.dataset.training ]
+        self.dataset.validation = [ self.expand(dir) for dir in self.dataset.validation ]
+        self.dataset.noise = [ self.expand(dir) for dir in self.dataset.noise ]
+        # flatten
+        self.dataset.training = list(itertools.chain.from_iterable(self.dataset.training))
+        self.dataset.validation = list(itertools.chain.from_iterable(self.dataset.validation))
+        self.dataset.noise = list(itertools.chain.from_iterable(self.dataset.noise))
 
         # do cleanup
         for model in self.models:
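Since expand() returns a list per entry, the training/validation/noise arrays become lists of lists after the first pass, and itertools.chain.from_iterable flattens them back into flat path lists. A self-contained sketch of that pattern, with made-up entries rather than the actual config:

import itertools

# each dataset entry expands to a list of paths, so the intermediate result is a list of lists
expanded = [["speaker_a/clean"], ["corpus/001", "corpus/002"]]

# chain.from_iterable concatenates the inner lists into one flat list
flat = list(itertools.chain.from_iterable(expanded))
# flat == ["speaker_a/clean", "corpus/001", "corpus/002"]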
@@ -1413,7 +1413,7 @@ def create_dataset_metadata( skip_existing=True ):
 
         wrote = False
 
-        for id in tqdm(ids, desc=f"Processing {name}"):
+        for id in tqdm(ids, desc=f"Processing {name}", disable=True):
             try:
                 quant_path = Path(f'{root}/{name}/{id}{_get_quant_extension()}')
 