diff --git a/vall_e/config.py b/vall_e/config.py index ca3cafa..9a600d5 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -884,18 +884,29 @@ class Config(BaseConfig): if not isinstance( path, Path ): path = Path(path) - # do not glob + # do not glob if no wildcard to glob if "*" not in str(path): return [ path ] - - metadata_parent = cfg.metadata_dir / path.parent - data_parent = cfg.data_dir / path.parent - - if metadata_parent.exists(): - return [ path.parent / child.stem for child in Path(metadata_parent).glob(path.name) ] + dir = path.parent + name = path.name + + metadata_parent = cfg.metadata_dir / dir + data_parent = cfg.data_dir / dir + + res = [] + # grab any paths from metadata folder (since this is for HDF5) + if metadata_parent.exists(): + res = [ path.parent / child.stem for child in Path(metadata_parent).glob(name) ] + # return if found anything + if res: + return res + # grab anything from the data folder (if no metadata exists) if data_parent.exists(): - return [ path.parent / child.name for child in Path(data_parent).glob(path.name) ] + res = [ path.parent / child.name for child in Path(data_parent).glob(name) ] + # return if found anything + if res: + return res # return an empty list if self.silent_errors: diff --git a/vall_e/data.py b/vall_e/data.py index 08bbf79..79c2355 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -633,8 +633,6 @@ def _load_paths_from_metadata(group_name, type="training", validate=False): phones = entry['phones'] if "phones" in entry else 0 duration = entry['duration'] if "duration" in entry else 0 - #print( id, duration ) - # add to duration bucket k = key(id, entry) if type not in _durations_map: @@ -1079,7 +1077,7 @@ class Dataset(_Dataset): def sample_prompts(self, spkr_name, reference, should_trim=True): # return no prompt if explicitly requested for who knows why # or if there's no other speakers to sample from (Emilia has a lot of singleton speakers, but I still want to make use of them) - if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0 or len(self.paths_by_spkr_name[key]) <= 1: + if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0 or len(self.paths_by_spkr_name[spkr_name]) <= 1: return None prom_list = [] @@ -1686,7 +1684,11 @@ def create_dataset_hdf5( skip_existing=True ): metadata_path = Path(f"{metadata_root}/{speaker_name}.json") metadata_path.parents[0].mkdir(parents=True, exist_ok=True) - metadata = json_read(metadata_path, default={}) + try: + metadata = json_read(metadata_path, default={}) + except Exception as e: + print(metadata_path, e) + return if not os.path.isdir(f'{root}/{name}/'): return diff --git a/vall_e/demo.py b/vall_e/demo.py index 13d4160..38a17b7 100644 --- a/vall_e/demo.py +++ b/vall_e/demo.py @@ -302,8 +302,8 @@ def main(): num = args.dataset_samples if args.dataset_samples else length for i in trange( num, desc="Sampling dataset for samples" ): - index = i if not cfg.dataset.sample_shuffle else random.randint( 0, len( dataloader.dataset ) ) - batch = dataloader.dataset[i] + index = i if not cfg.dataset.sample_shuffle else random.randint( 0, len( dataloader.dataset ) - 1 ) + batch = dataloader.dataset[index] if args.dataset_dir_name_prefix: dir = args.demo_dir / args.dataset_dir_name / f'{args.dataset_dir_name_prefix}_{i}' diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index ec2d611..65b5cd1 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -961,6 +961,11 @@ def example_usage(): learning_rate = 1.0e-4 optimizer = ml.SGD + elif optimizer == "apollo": + if learning_rate is None: + learning_rate = 0.01 + + optimizer = ml.Apollo else: raise ValueError(f"Unrecognized optimizer: {optimizer}")