override eval for meme model
This commit is contained in:
parent
f93fbf0d99
commit
99ef55d605
|
@ -176,6 +176,7 @@ def run_eval(engines, eval_name, dl, args=None):
|
||||||
|
|
||||||
batch_size = len(batch["text"])
|
batch_size = len(batch["text"])
|
||||||
|
|
||||||
|
"""
|
||||||
# to-do: eval for text tasks
|
# to-do: eval for text tasks
|
||||||
has_stt = False
|
has_stt = False
|
||||||
for i, task in enumerate( batch["task"] ):
|
for i, task in enumerate( batch["task"] ):
|
||||||
|
@ -192,6 +193,7 @@ def run_eval(engines, eval_name, dl, args=None):
|
||||||
for i, _ in enumerate(batch["text"]):
|
for i, _ in enumerate(batch["text"]):
|
||||||
batch["text"][i] = get_random_prompt(tokenized=True).to(device=cfg.device)
|
batch["text"][i] = get_random_prompt(tokenized=True).to(device=cfg.device)
|
||||||
batch["resps"][i] = None
|
batch["resps"][i] = None
|
||||||
|
"""
|
||||||
|
|
||||||
processed += batch_size
|
processed += batch_size
|
||||||
for name in engines:
|
for name in engines:
|
||||||
|
@ -205,9 +207,15 @@ def run_eval(engines, eval_name, dl, args=None):
|
||||||
training=False,
|
training=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
if engine.hyper_config.experimental.hf:
|
if self.version >= 7:
|
||||||
resps_list = engine( **base_kwargs )
|
kwargs = base_kwargs | cfg.evaluation.kwargs
|
||||||
elif "len" in engine.hyper_config.capabilities:
|
# sample for NAR demask
|
||||||
|
if random.random() < cfg.model.experimental.masking_train_p:
|
||||||
|
kwargs["len_list"] = [ resp.shape[0] for resp in batch["resps"] ]
|
||||||
|
# inference
|
||||||
|
resps_list = engine( **kwargs )
|
||||||
|
else:
|
||||||
|
if "len" in engine.hyper_config.capabilities:
|
||||||
kwargs = base_kwargs | cfg.evaluation.kwargs
|
kwargs = base_kwargs | cfg.evaluation.kwargs
|
||||||
max_steps = kwargs.pop("max_steps", 500)
|
max_steps = kwargs.pop("max_steps", 500)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user