override eval for meme model

This commit is contained in:
mrq 2025-02-24 13:47:46 -06:00
parent f93fbf0d99
commit 99ef55d605

View File

@ -176,6 +176,7 @@ def run_eval(engines, eval_name, dl, args=None):
batch_size = len(batch["text"])
"""
# to-do: eval for text tasks
has_stt = False
for i, task in enumerate( batch["task"] ):
@ -192,6 +193,7 @@ def run_eval(engines, eval_name, dl, args=None):
for i, _ in enumerate(batch["text"]):
batch["text"][i] = get_random_prompt(tokenized=True).to(device=cfg.device)
batch["resps"][i] = None
"""
processed += batch_size
for name in engines:
@ -205,31 +207,37 @@ def run_eval(engines, eval_name, dl, args=None):
training=False,
)
if engine.hyper_config.experimental.hf:
resps_list = engine( **base_kwargs )
elif "len" in engine.hyper_config.capabilities:
if self.version >= 7:
kwargs = base_kwargs | cfg.evaluation.kwargs
max_steps = kwargs.pop("max_steps", 500)
if "denoise_start" in kwargs:
len_list = [ resp.shape[0] for resp in batch["resps"] ]
kwargs["resps_list"] = [ resp[:, :1] for resp in batch["resps"] ]
else:
len_list = engine( max_steps=5, **kwargs )
len_list = [ min( l, max_steps ) for l in len_list ]
kwargs = base_kwargs | cfg.evaluation.kwargs
resps_list = engine( **kwargs, len_list=len_list )
# sample for NAR demask
if random.random() < cfg.model.experimental.masking_train_p:
kwargs["len_list"] = [ resp.shape[0] for resp in batch["resps"] ]
# inference
resps_list = engine( **kwargs )
else:
if "ar" in engine.hyper_config.capabilities:
if "len" in engine.hyper_config.capabilities:
kwargs = base_kwargs | cfg.evaluation.kwargs
resps_list = engine( **kwargs )
else:
resps_list = [ resp[:, 0] for resp in batch["resps"] ]
max_steps = kwargs.pop("max_steps", 500)
if "nar" in engine.hyper_config.capabilities:
if "denoise_start" in kwargs:
len_list = [ resp.shape[0] for resp in batch["resps"] ]
kwargs["resps_list"] = [ resp[:, :1] for resp in batch["resps"] ]
else:
len_list = engine( max_steps=5, **kwargs )
len_list = [ min( l, max_steps ) for l in len_list ]
kwargs = base_kwargs | cfg.evaluation.kwargs
resps_list = engine( **kwargs, resps_list=resps_list )
resps_list = engine( **kwargs, len_list=len_list )
else:
if "ar" in engine.hyper_config.capabilities:
kwargs = base_kwargs | cfg.evaluation.kwargs
resps_list = engine( **kwargs )
else:
resps_list = [ resp[:, 0] for resp in batch["resps"] ]
if "nar" in engine.hyper_config.capabilities:
kwargs = base_kwargs | cfg.evaluation.kwargs
resps_list = engine( **kwargs, resps_list=resps_list )
process( name, batch, resps_list )