diff --git a/vall_e/train.py b/vall_e/train.py index 0257c2a..4c2943d 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -207,10 +207,10 @@ def run_eval(engines, eval_name, dl, args=None): training=False, ) - if self.version >= 7: + if engine.hyper_config.version >= 7: kwargs = base_kwargs | cfg.evaluation.kwargs # sample for NAR demask - if random.random() < cfg.model.experimental.masking_train_p: + if random.random() < engine.hyper_config.experimental.masking_train_p: kwargs["len_list"] = [ resp.shape[0] for resp in batch["resps"] ] # inference resps_list = engine( **kwargs )