From d606a693ffe9f7cf0550ae1823306efa9f264728 Mon Sep 17 00:00:00 2001 From: mrq Date: Wed, 6 Nov 2024 23:14:16 -0600 Subject: [PATCH] eval fix for nar-len --- vall_e/train.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vall_e/train.py b/vall_e/train.py index 1573a09..89a7683 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -145,8 +145,11 @@ def run_eval(engines, eval_name, dl, args=None): if engine.hyper_config.experimental.hf: resps_list = engine( **base_kwargs ) elif "len" in engine.hyper_config.capabilities: - len_list = engine( **base_kwargs, max_steps=10 ) # don't need more than that - len_list = [ min( l, cfg.evaluation.steps ) for l in len_list ] + kwargs = base_kwargs | cfg.evaluation.ar_kwargs + max_steps = kwargs.pop("max_steps", 500) + kwargs["max_steps"] = 10 + len_list = engine( **kwargs ) # don't need more than that + len_list = [ min( l, max_steps ) for l in len_list ] kwargs = base_kwargs | cfg.evaluation.nar_kwargs resps_list = engine( **kwargs, len_list=len_list )