From 99ef55d605ea52d613fdeffed534194b5d185cea Mon Sep 17 00:00:00 2001 From: mrq Date: Mon, 24 Feb 2025 13:47:46 -0600 Subject: [PATCH] override eval for meme model --- vall_e/train.py | 48 ++++++++++++++++++++++++++++-------------------- 1 file changed, 28 insertions(+), 20 deletions(-) diff --git a/vall_e/train.py b/vall_e/train.py index 20cf753..0257c2a 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -176,6 +176,7 @@ def run_eval(engines, eval_name, dl, args=None): batch_size = len(batch["text"]) + """ # to-do: eval for text tasks has_stt = False for i, task in enumerate( batch["task"] ): @@ -192,6 +193,7 @@ def run_eval(engines, eval_name, dl, args=None): for i, _ in enumerate(batch["text"]): batch["text"][i] = get_random_prompt(tokenized=True).to(device=cfg.device) batch["resps"][i] = None + """ processed += batch_size for name in engines: @@ -205,31 +207,37 @@ def run_eval(engines, eval_name, dl, args=None): training=False, ) - if engine.hyper_config.experimental.hf: - resps_list = engine( **base_kwargs ) - elif "len" in engine.hyper_config.capabilities: + if self.version >= 7: kwargs = base_kwargs | cfg.evaluation.kwargs - max_steps = kwargs.pop("max_steps", 500) - - if "denoise_start" in kwargs: - len_list = [ resp.shape[0] for resp in batch["resps"] ] - kwargs["resps_list"] = [ resp[:, :1] for resp in batch["resps"] ] - else: - len_list = engine( max_steps=5, **kwargs ) - len_list = [ min( l, max_steps ) for l in len_list ] - - kwargs = base_kwargs | cfg.evaluation.kwargs - resps_list = engine( **kwargs, len_list=len_list ) + # sample for NAR demask + if random.random() < cfg.model.experimental.masking_train_p: + kwargs["len_list"] = [ resp.shape[0] for resp in batch["resps"] ] + # inference + resps_list = engine( **kwargs ) else: - if "ar" in engine.hyper_config.capabilities: + if "len" in engine.hyper_config.capabilities: kwargs = base_kwargs | cfg.evaluation.kwargs - resps_list = engine( **kwargs ) - else: - resps_list = [ resp[:, 0] for resp in batch["resps"] ] + max_steps = kwargs.pop("max_steps", 500) - if "nar" in engine.hyper_config.capabilities: + if "denoise_start" in kwargs: + len_list = [ resp.shape[0] for resp in batch["resps"] ] + kwargs["resps_list"] = [ resp[:, :1] for resp in batch["resps"] ] + else: + len_list = engine( max_steps=5, **kwargs ) + len_list = [ min( l, max_steps ) for l in len_list ] + kwargs = base_kwargs | cfg.evaluation.kwargs - resps_list = engine( **kwargs, resps_list=resps_list ) + resps_list = engine( **kwargs, len_list=len_list ) + else: + if "ar" in engine.hyper_config.capabilities: + kwargs = base_kwargs | cfg.evaluation.kwargs + resps_list = engine( **kwargs ) + else: + resps_list = [ resp[:, 0] for resp in batch["resps"] ] + + if "nar" in engine.hyper_config.capabilities: + kwargs = base_kwargs | cfg.evaluation.kwargs + resps_list = engine( **kwargs, resps_list=resps_list ) process( name, batch, resps_list )