diff --git a/vall_e/train.py b/vall_e/train.py index 1573a09..89a7683 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -145,8 +145,11 @@ def run_eval(engines, eval_name, dl, args=None): if engine.hyper_config.experimental.hf: resps_list = engine( **base_kwargs ) elif "len" in engine.hyper_config.capabilities: - len_list = engine( **base_kwargs, max_steps=10 ) # don't need more than that - len_list = [ min( l, cfg.evaluation.steps ) for l in len_list ] + kwargs = base_kwargs | cfg.evaluation.ar_kwargs + max_steps = kwargs.pop("max_steps", 500) + kwargs["max_steps"] = 10 + len_list = engine( **kwargs ) # don't need more than that + len_list = [ min( l, max_steps ) for l in len_list ] kwargs = base_kwargs | cfg.evaluation.nar_kwargs resps_list = engine( **kwargs, len_list=len_list )