From 0aa2a3cc0770823bf1c22b291fe0eab2c9949de8 Mon Sep 17 00:00:00 2001 From: mrq Date: Sun, 29 Oct 2023 12:00:40 -0500 Subject: [PATCH] evaluation/validation passes language ID during training (oops) --- vall_e/data.py | 2 +- vall_e/train.py | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/vall_e/data.py b/vall_e/data.py index c7d1651..c2210bc 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -224,7 +224,7 @@ class Dataset(_Dataset): self.spkrs_by_spkr_group[spkr_group].append( spkr ) self.spkr_groups = list(self.spkrs_by_spkr_group.keys()) - + self.spkr_samplers = { name: Sampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() } if self.sampler_type == "path": diff --git a/vall_e/train.py b/vall_e/train.py index dda58da..ac57662 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -103,9 +103,9 @@ def run_eval(engines, eval_name, dl): if AR is not None and NAR is not None: name = "+".join(names) - resps_list = AR(text_list=batch["text"], proms_list=batch["proms"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature) + resps_list = AR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature) resps_list = [ r.unsqueeze(-1) for r in resps_list ] - resps_list = NAR(text_list=batch["text"], proms_list=batch["proms"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature) + resps_list = NAR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature) process( name, batch, resps_list ) else: @@ -113,13 +113,14 @@ def run_eval(engines, eval_name, dl): model = engines[name] if name.startswith("ar+nar"): - resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature) + resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature) resps_list = [ r.unsqueeze(-1) for r in resps_list ] - resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature) + resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature) elif name.startswith("ar"): resps_list = model( text_list=batch["text"], proms_list=batch["proms"], + lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature, ) @@ -128,6 +129,7 @@ def run_eval(engines, eval_name, dl): resps_list = model( text_list=batch["text"], proms_list=batch["proms"], + lang_list=batch["lang"], resps_list=[r[..., 0].unsqueeze(-1) for r in batch["resps"]], sampling_temperature=cfg.evaluation.nar_temperature, )