diff --git a/vall_e/train.py b/vall_e/train.py index a93a6aa..7b2ea7c 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -164,6 +164,9 @@ def train(): train_dl, subtrain_dl, val_dl = create_train_val_dataloader() def eval_fn(engines): + do_gc() + engines.eval() + # wrapped in a try block because it's sometimes prone to breaking try: run_eval(engines, "subtrain", subtrain_dl) run_eval(engines, "val", val_dl) @@ -171,6 +174,7 @@ def train(): print("Error occurred while performing eval:", str(e)) print(traceback.format_exc()) + engines.train() qnt.unload_model() do_gc() diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py index 230e49b..fd79dcb 100755 --- a/vall_e/utils/trainer.py +++ b/vall_e/utils/trainer.py @@ -142,9 +142,8 @@ def train( # Pre-loop command command = _non_blocking_input() if command in ["eval", "eval_quit"]: - engines.eval() eval_fn(engines=engines) - engines.train() + if command in ["quit", "eval_quit"]: return @@ -261,12 +260,8 @@ def train( if engines.global_step != last_eval_step: if engines.global_step % cfg.evaluation.frequency == 0 or command in ["eval"]: - do_gc() - - engines.eval() - eval_fn(engines=engines) - engines.train() last_eval_step = engines.global_step + eval_fn(engines=engines) if command in ["quit"]: return