override eval for meme model
This commit is contained in:
parent
f93fbf0d99
commit
99ef55d605
|
@ -176,6 +176,7 @@ def run_eval(engines, eval_name, dl, args=None):
|
|||
|
||||
batch_size = len(batch["text"])
|
||||
|
||||
"""
|
||||
# to-do: eval for text tasks
|
||||
has_stt = False
|
||||
for i, task in enumerate( batch["task"] ):
|
||||
|
@ -192,6 +193,7 @@ def run_eval(engines, eval_name, dl, args=None):
|
|||
for i, _ in enumerate(batch["text"]):
|
||||
batch["text"][i] = get_random_prompt(tokenized=True).to(device=cfg.device)
|
||||
batch["resps"][i] = None
|
||||
"""
|
||||
|
||||
processed += batch_size
|
||||
for name in engines:
|
||||
|
@ -205,31 +207,37 @@ def run_eval(engines, eval_name, dl, args=None):
|
|||
training=False,
|
||||
)
|
||||
|
||||
if engine.hyper_config.experimental.hf:
|
||||
resps_list = engine( **base_kwargs )
|
||||
elif "len" in engine.hyper_config.capabilities:
|
||||
if self.version >= 7:
|
||||
kwargs = base_kwargs | cfg.evaluation.kwargs
|
||||
max_steps = kwargs.pop("max_steps", 500)
|
||||
|
||||
if "denoise_start" in kwargs:
|
||||
len_list = [ resp.shape[0] for resp in batch["resps"] ]
|
||||
kwargs["resps_list"] = [ resp[:, :1] for resp in batch["resps"] ]
|
||||
else:
|
||||
len_list = engine( max_steps=5, **kwargs )
|
||||
len_list = [ min( l, max_steps ) for l in len_list ]
|
||||
|
||||
kwargs = base_kwargs | cfg.evaluation.kwargs
|
||||
resps_list = engine( **kwargs, len_list=len_list )
|
||||
# sample for NAR demask
|
||||
if random.random() < cfg.model.experimental.masking_train_p:
|
||||
kwargs["len_list"] = [ resp.shape[0] for resp in batch["resps"] ]
|
||||
# inference
|
||||
resps_list = engine( **kwargs )
|
||||
else:
|
||||
if "ar" in engine.hyper_config.capabilities:
|
||||
if "len" in engine.hyper_config.capabilities:
|
||||
kwargs = base_kwargs | cfg.evaluation.kwargs
|
||||
resps_list = engine( **kwargs )
|
||||
else:
|
||||
resps_list = [ resp[:, 0] for resp in batch["resps"] ]
|
||||
max_steps = kwargs.pop("max_steps", 500)
|
||||
|
||||
if "nar" in engine.hyper_config.capabilities:
|
||||
if "denoise_start" in kwargs:
|
||||
len_list = [ resp.shape[0] for resp in batch["resps"] ]
|
||||
kwargs["resps_list"] = [ resp[:, :1] for resp in batch["resps"] ]
|
||||
else:
|
||||
len_list = engine( max_steps=5, **kwargs )
|
||||
len_list = [ min( l, max_steps ) for l in len_list ]
|
||||
|
||||
kwargs = base_kwargs | cfg.evaluation.kwargs
|
||||
resps_list = engine( **kwargs, resps_list=resps_list )
|
||||
resps_list = engine( **kwargs, len_list=len_list )
|
||||
else:
|
||||
if "ar" in engine.hyper_config.capabilities:
|
||||
kwargs = base_kwargs | cfg.evaluation.kwargs
|
||||
resps_list = engine( **kwargs )
|
||||
else:
|
||||
resps_list = [ resp[:, 0] for resp in batch["resps"] ]
|
||||
|
||||
if "nar" in engine.hyper_config.capabilities:
|
||||
kwargs = base_kwargs | cfg.evaluation.kwargs
|
||||
resps_list = engine( **kwargs, resps_list=resps_list )
|
||||
|
||||
process( name, batch, resps_list )
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user