|
|
|
@ -103,9 +103,9 @@ def run_eval(engines, eval_name, dl):
|
|
|
|
|
if AR is not None and NAR is not None:
|
|
|
|
|
name = "+".join(names)
|
|
|
|
|
|
|
|
|
|
resps_list = AR(text_list=batch["text"], proms_list=batch["proms"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
|
|
|
|
|
resps_list = AR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
|
|
|
|
|
resps_list = [ r.unsqueeze(-1) for r in resps_list ]
|
|
|
|
|
resps_list = NAR(text_list=batch["text"], proms_list=batch["proms"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
|
|
|
|
|
resps_list = NAR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
|
|
|
|
|
|
|
|
|
|
process( name, batch, resps_list )
|
|
|
|
|
else:
|
|
|
|
@ -113,13 +113,14 @@ def run_eval(engines, eval_name, dl):
|
|
|
|
|
model = engines[name]
|
|
|
|
|
|
|
|
|
|
if name.startswith("ar+nar"):
|
|
|
|
|
resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
|
|
|
|
|
resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
|
|
|
|
|
resps_list = [ r.unsqueeze(-1) for r in resps_list ]
|
|
|
|
|
resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
|
|
|
|
|
resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
|
|
|
|
|
elif name.startswith("ar"):
|
|
|
|
|
resps_list = model(
|
|
|
|
|
text_list=batch["text"],
|
|
|
|
|
proms_list=batch["proms"],
|
|
|
|
|
lang_list=batch["lang"],
|
|
|
|
|
max_steps=cfg.evaluation.steps,
|
|
|
|
|
sampling_temperature=cfg.evaluation.ar_temperature,
|
|
|
|
|
)
|
|
|
|
@ -128,6 +129,7 @@ def run_eval(engines, eval_name, dl):
|
|
|
|
|
resps_list = model(
|
|
|
|
|
text_list=batch["text"],
|
|
|
|
|
proms_list=batch["proms"],
|
|
|
|
|
lang_list=batch["lang"],
|
|
|
|
|
resps_list=[r[..., 0].unsqueeze(-1) for r in batch["resps"]],
|
|
|
|
|
sampling_temperature=cfg.evaluation.nar_temperature,
|
|
|
|
|
)
|
|
|
|
|