From 3ff7cf83415f26efd7faa99ddb7903efdc4f5172 Mon Sep 17 00:00:00 2001 From: mrq Date: Thu, 17 Aug 2023 18:56:37 -0500 Subject: [PATCH] maybe fix evaluation dataset not being capped to cfg.evaluation.size --- vall_e/data.py | 32 ++++++++++++++------------------ vall_e/train.py | 5 +++++ 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/vall_e/data.py b/vall_e/data.py index 4338250..eef2009 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -253,6 +253,10 @@ class Dataset(_Dataset): return prom + @cached_property + def tasks(self): + return ["tts"] # "ns", "sr", "tse", "cse", "nse" + def __getitem__(self, index): if hasattr(self, "sample_type") and self.sample_type == "speaker": spkr_name = self.spkrs[index] @@ -274,14 +278,17 @@ class Dataset(_Dataset): text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype) resps = _load_quants(path) - # I could probably do some logic to directly use the resps, but I'm putting my faith in python aliasing - proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps + task = random.choice(self.tasks) + if task == "tts": + # I could probably do some logic to directly use the resps, but I'm putting my faith in python aliasing + proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps return dict( index=index, path=path, spkr_name=spkr_name, spkr_id=spkr_id, + task=task, text=text, proms=proms, resps=resps, @@ -360,26 +367,12 @@ def _load_dataset_paths(): paths[type].extend([ f"{key}/{child.attrs['id']}" for child in cfg.hdf5[key].values() ]) - # files = data_dir.rglob("*.qnt.pt") - #paths[type].extend([ f'/{type}{_get_hdf5_path( str(file).replace(".qnt.pt", "") )}' for file in files ]) - for data_dir in cfg.dataset.training: get_paths( data_dir, "training" ) for data_dir in cfg.dataset.validation: get_paths( data_dir, "validation" ) - """ - def process( entity ): - if "id" in entity.attrs: - paths[entity.attrs['type']].append( f"{entity.attrs['speaker']}{entity.attrs['id']}" ) - return - - for child in entity.values(): - process( child ) - """ - - for _, type in enumerate(paths): dirs = paths[type] @@ -398,6 +391,7 @@ def _load_dataset_paths(): return datasets["training"], datasets["validation"] +# to-do: mirror the hdf5-based load function def _load_train_val_paths(): paths = [] train_paths = [] @@ -432,8 +426,10 @@ def _load_train_val_paths(): train_paths, val_paths = map(sorted, [train_paths, val_paths]) if len(train_paths) == 0: - raise RuntimeError(f"Failed to find any .qnt.pt file in {cfg.dataset.training}.") - # to get it to shut up + raise RuntimeError(f"Failed to find any .qnt.pt file in specified training dataset.") + + # to-do: allow setting aside a fixed portion of the training dataset for validation + # something like the last X percent of each speaker is set aside if len(val_paths) == 0: val_paths = [ train_paths[0] ] diff --git a/vall_e/train.py b/vall_e/train.py index b5e4499..3799b5e 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -95,6 +95,7 @@ def run_eval(engines, eval_name, dl): stats['loss'].append(0) print(str(e)) + processed = 0 for batch in tqdm(dl): batch: dict = to_device(batch, cfg.device) @@ -131,6 +132,10 @@ def run_eval(engines, eval_name, dl): process( name, batch, resps_list ) + processed += len(batch["text"]) + if processed > cfg.evaluation.size: + break + stats = {k: sum(v) / len(v) for k, v in stats.items()} engines_stats.update(flatten_dict({ name: stats }))