some cleanup
This commit is contained in:
parent
d760924719
commit
85f9684720
|
@ -164,6 +164,9 @@ def train():
|
|||
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
||||
|
||||
def eval_fn(engines):
|
||||
do_gc()
|
||||
engines.eval()
|
||||
# wrapped in a try block because it's sometimes prone to breaking
|
||||
try:
|
||||
run_eval(engines, "subtrain", subtrain_dl)
|
||||
run_eval(engines, "val", val_dl)
|
||||
|
@ -171,6 +174,7 @@ def train():
|
|||
print("Error occurred while performing eval:", str(e))
|
||||
print(traceback.format_exc())
|
||||
|
||||
engines.train()
|
||||
qnt.unload_model()
|
||||
do_gc()
|
||||
|
||||
|
|
|
@ -142,9 +142,8 @@ def train(
|
|||
# Pre-loop command
|
||||
command = _non_blocking_input()
|
||||
if command in ["eval", "eval_quit"]:
|
||||
engines.eval()
|
||||
eval_fn(engines=engines)
|
||||
engines.train()
|
||||
|
||||
if command in ["quit", "eval_quit"]:
|
||||
return
|
||||
|
||||
|
@ -261,12 +260,8 @@ def train(
|
|||
|
||||
if engines.global_step != last_eval_step:
|
||||
if engines.global_step % cfg.evaluation.frequency == 0 or command in ["eval"]:
|
||||
do_gc()
|
||||
|
||||
engines.eval()
|
||||
eval_fn(engines=engines)
|
||||
engines.train()
|
||||
last_eval_step = engines.global_step
|
||||
eval_fn(engines=engines)
|
||||
|
||||
if command in ["quit"]:
|
||||
return
|
||||
|
|
Loading…
Reference in New Issue
Block a user