some cleanup

This commit is contained in:
mrq 2024-05-25 17:46:52 -05:00
parent d760924719
commit 85f9684720
2 changed files with 6 additions and 7 deletions

View File

@ -164,6 +164,9 @@ def train():
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
def eval_fn(engines):
do_gc()
engines.eval()
# wrapped in a try block because it's sometimes prone to breaking
try:
run_eval(engines, "subtrain", subtrain_dl)
run_eval(engines, "val", val_dl)
@ -171,6 +174,7 @@ def train():
print("Error occurred while performing eval:", str(e))
print(traceback.format_exc())
engines.train()
qnt.unload_model()
do_gc()

View File

@ -142,9 +142,8 @@ def train(
# Pre-loop command
command = _non_blocking_input()
if command in ["eval", "eval_quit"]:
engines.eval()
eval_fn(engines=engines)
engines.train()
if command in ["quit", "eval_quit"]:
return
@ -261,12 +260,8 @@ def train(
if engines.global_step != last_eval_step:
if engines.global_step % cfg.evaluation.frequency == 0 or command in ["eval"]:
do_gc()
engines.eval()
eval_fn(engines=engines)
engines.train()
last_eval_step = engines.global_step
eval_fn(engines=engines)
if command in ["quit"]:
return