add cap for NAR-len during training, to avoid weird cases early in training where the length predictor messes up and generates overly long lengths

mrq 2024-08-03 21:00:32 -05:00
parent 4d2b88b164
commit ab673e0426
2 changed files with 8 additions and 8 deletions

View File

@@ -112,6 +112,7 @@ def run_eval(engines, eval_name, dl):
 				resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"] )
 			elif "len" in engine.hyper_config.capabilities:
 				len_list = engine(text_list=batch["text"], proms_list=batch["proms"], max_steps=10 ) # don't need more than that
+				len_list = [ min( l, cfg.evaluation.steps ) for l in len_list ]
 				resps_list = engine( text_list=batch["text"], proms_list=batch["proms"], len_list=len_list, max_levels=cfg.evaluation.nar_levels )
 			else:
 				if "ar" in engine.hyper_config.capabilities:

View File

@@ -127,14 +127,13 @@ def train(
 	engines = load_engines()
 	# validate if there's at least one model to train
-	if training:
-		found = False
-		for name, engine in engines.items():
-			if engine.training:
-				found = True
-				break
-		if not found:
-			raise Exception('Training, but no model loaded set to train...')
+	found = False
+	for name, engine in engines.items():
+		if engine.training:
+			found = True
+			break
+	if not found:
+		raise Exception('Training, but no model loaded set to train...')
 	"""
 	if is_local_leader():
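
The second hunk only dedents the block: the "at least one model set to train" check now runs unconditionally rather than only under the old `if training:` guard. A minimal sketch of the equivalent check, assuming `engines` maps names to objects with a boolean `.training` flag as in the diff; `validate_engines` is a hypothetical name:

```python
# A minimal sketch of the dedented check above; validate_engines is a
# hypothetical name, and engine.training mirrors the flag used in the diff.
def validate_engines(engines: dict) -> None:
    # Fail fast when training is requested but no loaded model will update.
    if not any(engine.training for engine in engines.values()):
        raise Exception('Training, but no model loaded set to train...')
```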