From fa93061b3eeb7e77ea660917782a9262ba0b6776 Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 6 Sep 2024 16:59:56 -0500
Subject: [PATCH] more fixes, moved sampler state dict to a better place, eval works again

---
 vall_e/data.py  |  2 +-
 vall_e/train.py | 26 +++++++++++++++++++++++---
 2 files changed, 24 insertions(+), 4 deletions(-)

diff --git a/vall_e/data.py b/vall_e/data.py
index 8680619..6115b7c 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -734,7 +734,7 @@ class Dataset(_Dataset):
 
 	@cached_property
 	def sampler_state_dict_path(self):
-		return cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
+		return cfg.ckpt_dir / cfg.model.full_name / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
 
 	def get_speaker(self, path):
 		if isinstance(path, str):
diff --git a/vall_e/train.py b/vall_e/train.py
index 1f76dea..7c7cd29 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -103,19 +103,25 @@ def run_eval(engines, eval_name, dl):
 		for key in batch.keys():
 			batch[key] = batch[key][:cfg.evaluation.batch_size]
 
-		processed += len(batch["text"])
+		batch_size = len(batch["text"])
+
+		processed += batch_size
 
 		for name in engines:
 			engine = engines[name]
 
 			# to-do: eval for text tasks
-			for i, task in batch["task"]:
+			has_stt = False
+			for i, task in enumerate( batch["task"] ):
+				# easier to just change it to a tts task than drop stt tasks from the batch
 				if task == "stt":
+					has_stt = True
 					batch["task"][i] = "tts"
+					batch["proms"][i] = batch["resps"][i][:75*3, :]
 
 			kwargs = dict(
 				text_list=batch["text"],
-				prom_list=batch["proms"],
+				proms_list=batch["proms"],
 				lang_list=batch["lang"],
 				task_list=batch["task"],
 			)
@@ -137,6 +143,20 @@ def run_eval(engines, eval_name, dl):
 
 			process( name, batch, resps_list )
 
+			# evaluate why it's so slow
+			if has_stt:
+				max_steps = max( [ text.shape[0] for text in batch["text"] ] )
+
+				kwargs["text_list"] = None
+				kwargs["task_list"] = [ "stt" for _ in range(batch_size) ]
+				kwargs["proms_list"] = [ ["stt"] for _ in range(batch_size) ]
+				kwargs["resps_list"] = batch["resps"]
+
+				text_list = engine( **kwargs, max_steps=max_steps, sampling_temperature=0.0)
+				text_list = [ cfg.tokenizer.decode( text ) for i, text in enumerate( text_list ) ]
+
+				_logger.info(f"Validation Metrics (STT): {text_list}")
+
 	stats = {k: sum(v) / len(v) for k, v in stats.items()}
 	engines_stats = {
 		f'{name}.{eval_name}': stats,
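
A minimal, self-contained sketch of the batch rewrite the train.py hunk performs before the audio eval pass. The helper name prepare_eval_batch, the dummy tensors, and the 8-level codebook shape are illustrative assumptions, not repository code; the 75*3 slice comes from the patch itself and presumably clips about three seconds of the target audio (at roughly 75 frames per second) to stand in as the speaker prompt once an "stt" item is rewritten as "tts".

	import torch

	def prepare_eval_batch( batch: dict ) -> tuple[dict, bool]:
		"""Rewrite stt items as tts for the audio eval pass, mirroring the hunk above.

		Returns the mutated batch and whether any stt item was found, so the caller
		knows to run the separate text-decoding pass afterwards.
		"""
		has_stt = False
		for i, task in enumerate( batch["task"] ):
			# easier to just change it to a tts task than drop stt tasks from the batch;
			# the first 75*3 frames of the target audio stand in as the speaker prompt
			if task == "stt":
				has_stt = True
				batch["task"][i] = "tts"
				batch["proms"][i] = batch["resps"][i][:75*3, :]
		return batch, has_stt

	# usage with dummy [frames, levels] code tensors (8 codebook levels assumed)
	batch = {
		"task":  [ "tts", "stt" ],
		"proms": [ torch.zeros( 225, 8, dtype=torch.int64 ), None ],
		"resps": [ torch.zeros( 600, 8, dtype=torch.int64 ), torch.zeros( 450, 8, dtype=torch.int64 ) ],
	}
	batch, has_stt = prepare_eval_batch( batch )
	assert has_stt and batch["task"] == [ "tts", "tts" ]
	assert batch["proms"][1].shape[0] == 75*3

After this rewrite, the patch runs a second pass over the same batch with text_list=None, every task set to "stt", and sampling_temperature=0.0, so the transcriptions are decoded greedily and logged alongside the audio metrics.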