From 85f96847207b62863913d38ba87571f35822593f Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 25 May 2024 17:46:52 -0500 Subject: [PATCH] some cleanup --- vall_e/train.py | 4 ++++ vall_e/utils/trainer.py | 9 ++------- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/vall_e/train.py b/vall_e/train.py index a93a6aa..7b2ea7c 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -164,6 +164,9 @@ def train(): train_dl, subtrain_dl, val_dl = create_train_val_dataloader() def eval_fn(engines): + do_gc() + engines.eval() + # wrapped in a try block because it's sometimes prone to breaking try: run_eval(engines, "subtrain", subtrain_dl) run_eval(engines, "val", val_dl) @@ -171,6 +174,7 @@ def train(): print("Error occurred while performing eval:", str(e)) print(traceback.format_exc()) + engines.train() qnt.unload_model() do_gc() diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py index 230e49b..fd79dcb 100755 --- a/vall_e/utils/trainer.py +++ b/vall_e/utils/trainer.py @@ -142,9 +142,8 @@ def train( # Pre-loop command command = _non_blocking_input() if command in ["eval", "eval_quit"]: - engines.eval() eval_fn(engines=engines) - engines.train() + if command in ["quit", "eval_quit"]: return @@ -261,12 +260,8 @@ def train( if engines.global_step != last_eval_step: if engines.global_step % cfg.evaluation.frequency == 0 or command in ["eval"]: - do_gc() - - engines.eval() - eval_fn(engines=engines) - engines.train() last_eval_step = engines.global_step + eval_fn(engines=engines) if command in ["quit"]: return