diff --git a/vall_e/train.py b/vall_e/train.py
index 4c2943d..e6d5a6e 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -207,40 +207,42 @@ def run_eval(engines, eval_name, dl, args=None):
 				training=False,
 			)
 
-			if engine.hyper_config.version >= 7:
-				kwargs = base_kwargs | cfg.evaluation.kwargs
-				# sample for NAR demask
-				if random.random() < engine.hyper_config.experimental.masking_train_p:
-					kwargs["len_list"] = [ resp.shape[0] for resp in batch["resps"] ]
-				# inference
-				resps_list = engine( **kwargs )
-			else:
-				if "len" in engine.hyper_config.capabilities:
+			with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
+				if engine.hyper_config.version >= 7:
 					kwargs = base_kwargs | cfg.evaluation.kwargs
-					max_steps = kwargs.pop("max_steps", 500)
-
-					if "denoise_start" in kwargs:
-						len_list = [ resp.shape[0] for resp in batch["resps"] ]
-						kwargs["resps_list"] = [ resp[:, :1] for resp in batch["resps"] ]
-					else:
-						len_list = engine( max_steps=5, **kwargs )
-						len_list = [ min( l, max_steps ) for l in len_list ]
-
-					kwargs = base_kwargs | cfg.evaluation.kwargs
-					resps_list = engine( **kwargs, len_list=len_list )
+					# sample for NAR demask
+					if random.random() < engine.hyper_config.experimental.masking_train_p:
+						kwargs["len_list"] = [ resp.shape[0] for resp in batch["resps"] ]
+					# inference
+					resps_list = engine( **kwargs )
 				else:
-					if "ar" in engine.hyper_config.capabilities:
+					if "len" in engine.hyper_config.capabilities:
 						kwargs = base_kwargs | cfg.evaluation.kwargs
-						resps_list = engine( **kwargs )
-					else:
-						resps_list = [ resp[:, 0] for resp in batch["resps"] ]
+						max_steps = kwargs.pop("max_steps", 500)
 
-					if "nar" in engine.hyper_config.capabilities:
+						if "denoise_start" in kwargs:
+							len_list = [ resp.shape[0] for resp in batch["resps"] ]
+							kwargs["resps_list"] = [ resp[:, :1] for resp in batch["resps"] ]
+						else:
+							len_list = engine( max_steps=5, **kwargs )
+							len_list = [ min( l, max_steps ) for l in len_list ]
+
 						kwargs = base_kwargs | cfg.evaluation.kwargs
-						resps_list = engine( **kwargs, resps_list=resps_list )
+						resps_list = engine( **kwargs, len_list=len_list )
+					else:
+						if "ar" in engine.hyper_config.capabilities:
+							kwargs = base_kwargs | cfg.evaluation.kwargs
+							resps_list = engine( **kwargs )
+						else:
+							resps_list = [ resp[:, 0] for resp in batch["resps"] ]
+
+						if "nar" in engine.hyper_config.capabilities:
+							kwargs = base_kwargs | cfg.evaluation.kwargs
+							resps_list = engine( **kwargs, resps_list=resps_list )
 
 			process( name, batch, resps_list )
 
+			"""
 			# evaluate why it's so slow
 			if has_stt:
 				max_steps = max( [ text.shape[0] for text in batch["text"] ] )
@@ -254,6 +256,7 @@ def run_eval(engines, eval_name, dl, args=None):
 				text_list = [ cfg.tokenizer.decode( text ) for i, text in enumerate( text_list ) ]
 
 				_logger.info(f"Validation Metrics (STT): {text_list}")
+			"""
 
 	stats = {k: sum(v) / len(v) for k, v in stats.items() if v}
 	engines_stats = {
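
Not part of the patch: a minimal standalone sketch of the pattern the first hunk applies, namely wrapping the evaluation forward pass in torch.autocast so inference honors the trainer's mixed-precision settings. The evaluate helper and its dtype/amp_enabled parameters are illustrative stand-ins for the repo's cfg.trainer.dtype and cfg.trainer.amp; only torch.autocast and torch.inference_mode are real PyTorch APIs.

import torch

@torch.inference_mode()
def evaluate(model: torch.nn.Module, batch: dict, dtype=torch.float16, amp_enabled=True):
	# Hypothetical helper: run the eval forward pass under autocast, as the patch does,
	# so eligible ops are downcast to `dtype` only when AMP is enabled; with
	# amp_enabled=False the context is a no-op and inference stays in full precision.
	with torch.autocast("cuda", dtype=dtype, enabled=amp_enabled):
		return model(**batch)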