override eval for meme model

mrq 2025-02-24 13:47:46 -06:00
parent f93fbf0d99
commit 99ef55d605


@@ -176,6 +176,7 @@ def run_eval(engines, eval_name, dl, args=None):
batch_size = len(batch["text"])
"""
# to-do: eval for text tasks
has_stt = False
for i, task in enumerate( batch["task"] ):
@@ -192,6 +193,7 @@ def run_eval(engines, eval_name, dl, args=None):
for i, _ in enumerate(batch["text"]):
batch["text"][i] = get_random_prompt(tokenized=True).to(device=cfg.device)
batch["resps"][i] = None
"""
processed += batch_size
for name in engines:
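
The two hunks above wrap the text-task / STT evaluation path in a triple-quoted string literal, so that block no longer executes during eval. A rough sketch of how the disabled region reads after this change (indentation is assumed, and the lines between the two hunks are elided):

"""
# to-do: eval for text tasks
has_stt = False
for i, task in enumerate( batch["task"] ):
    ...  # elided: lines between the two hunks
    for i, _ in enumerate(batch["text"]):
        batch["text"][i] = get_random_prompt(tokenized=True).to(device=cfg.device)
        batch["resps"][i] = None
"""
processed += batch_size
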
@@ -205,9 +207,15 @@
training=False,
)
if engine.hyper_config.experimental.hf:
resps_list = engine( **base_kwargs )
elif "len" in engine.hyper_config.capabilities:
if self.version >= 7:
kwargs = base_kwargs | cfg.evaluation.kwargs
# sample for NAR demask
if random.random() < cfg.model.experimental.masking_train_p:
kwargs["len_list"] = [ resp.shape[0] for resp in batch["resps"] ]
# inference
resps_list = engine( **kwargs )
else:
if "len" in engine.hyper_config.capabilities:
kwargs = base_kwargs | cfg.evaluation.kwargs
max_steps = kwargs.pop("max_steps", 500)
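
For the new model, evaluation is overridden: the call merges in cfg.evaluation.kwargs and, with probability cfg.model.experimental.masking_train_p, passes the ground-truth response lengths as len_list so the NAR-demask sampling path gets exercised; otherwise the model infers lengths itself. A minimal sketch of that selection, assuming the version check (written as self.version in the hunk) resolves to the model version, that engine(...) accepts an optional len_list, and that both kwargs objects behave as dicts; the helper name is hypothetical:

import random

def build_eval_kwargs(base_kwargs, batch, cfg, version):
    # models before version 7 keep the existing hf / "len" / AR eval paths
    if version < 7:
        return base_kwargs
    # merge in the evaluation-time overrides
    kwargs = dict(base_kwargs) | dict(cfg.evaluation.kwargs)
    # with probability masking_train_p, evaluate the NAR-demask path by
    # handing the model the ground-truth response lengths
    if random.random() < cfg.model.experimental.masking_train_p:
        kwargs["len_list"] = [ resp.shape[0] for resp in batch["resps"] ]
    return kwargs

# usage (hypothetical): resps_list = engine( **build_eval_kwargs(base_kwargs, batch, cfg, version) )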