maybe fix NaNs being thrown for immature models at fp16 for training evals

mrq 2025-02-24 18:25:54 -06:00
parent 0f39f4d7a1
commit 3330b5bb00

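The change wraps the evaluation inference paths in torch.autocast, mirroring the AMP policy used during training; under autocast, numerically fragile ops (softmax, norm layers, reductions) are kept in float32 even when weights are half precision, which is presumably what stops an immature (poorly converged) fp16 model from emitting NaNs during eval. A minimal sketch of the pattern, with a hypothetical eval_forward helper whose model/batch/dtype/amp arguments stand in for the repo's engine, base_kwargs, cfg.trainer.dtype and cfg.trainer.amp:

    import torch

    @torch.no_grad()
    def eval_forward(model, batch, dtype=torch.float16, amp=True):
        # Illustrative only: run inference under the same autocast policy as
        # training so PyTorch executes numerically sensitive ops in float32
        # instead of raw fp16, avoiding NaNs from under/overflow.
        with torch.autocast("cuda", dtype=dtype, enabled=amp):
            return model(**batch)

Because the context takes enabled=cfg.trainer.amp, it is effectively a no-op when AMP is disabled, so the full-precision path is left unchanged.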

@@ -207,40 +207,42 @@ def run_eval(engines, eval_name, dl, args=None):
                 training=False,
             )
-            if engine.hyper_config.version >= 7:
-                kwargs = base_kwargs | cfg.evaluation.kwargs
-                # sample for NAR demask
-                if random.random() < engine.hyper_config.experimental.masking_train_p:
-                    kwargs["len_list"] = [ resp.shape[0] for resp in batch["resps"] ]
-                # inference
-                resps_list = engine( **kwargs )
-            else:
-                if "len" in engine.hyper_config.capabilities:
-                    kwargs = base_kwargs | cfg.evaluation.kwargs
-                    max_steps = kwargs.pop("max_steps", 500)
-                    if "denoise_start" in kwargs:
-                        len_list = [ resp.shape[0] for resp in batch["resps"] ]
-                        kwargs["resps_list"] = [ resp[:, :1] for resp in batch["resps"] ]
-                    else:
-                        len_list = engine( max_steps=5, **kwargs )
-                        len_list = [ min( l, max_steps ) for l in len_list ]
-                    kwargs = base_kwargs | cfg.evaluation.kwargs
-                    resps_list = engine( **kwargs, len_list=len_list )
-                else:
-                    if "ar" in engine.hyper_config.capabilities:
-                        kwargs = base_kwargs | cfg.evaluation.kwargs
-                        resps_list = engine( **kwargs )
-                    else:
-                        resps_list = [ resp[:, 0] for resp in batch["resps"] ]
-                    if "nar" in engine.hyper_config.capabilities:
-                        kwargs = base_kwargs | cfg.evaluation.kwargs
-                        resps_list = engine( **kwargs, resps_list=resps_list )
+            with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
+                if engine.hyper_config.version >= 7:
+                    kwargs = base_kwargs | cfg.evaluation.kwargs
+                    # sample for NAR demask
+                    if random.random() < engine.hyper_config.experimental.masking_train_p:
+                        kwargs["len_list"] = [ resp.shape[0] for resp in batch["resps"] ]
+                    # inference
+                    resps_list = engine( **kwargs )
+                else:
+                    if "len" in engine.hyper_config.capabilities:
+                        kwargs = base_kwargs | cfg.evaluation.kwargs
+                        max_steps = kwargs.pop("max_steps", 500)
+                        if "denoise_start" in kwargs:
+                            len_list = [ resp.shape[0] for resp in batch["resps"] ]
+                            kwargs["resps_list"] = [ resp[:, :1] for resp in batch["resps"] ]
+                        else:
+                            len_list = engine( max_steps=5, **kwargs )
+                            len_list = [ min( l, max_steps ) for l in len_list ]
+                        kwargs = base_kwargs | cfg.evaluation.kwargs
+                        resps_list = engine( **kwargs, len_list=len_list )
+                    else:
+                        if "ar" in engine.hyper_config.capabilities:
+                            kwargs = base_kwargs | cfg.evaluation.kwargs
+                            resps_list = engine( **kwargs )
+                        else:
+                            resps_list = [ resp[:, 0] for resp in batch["resps"] ]
+                        if "nar" in engine.hyper_config.capabilities:
+                            kwargs = base_kwargs | cfg.evaluation.kwargs
+                            resps_list = engine( **kwargs, resps_list=resps_list )
             process( name, batch, resps_list )
             """
             # evaluate why it's so slow
             if has_stt:
                 max_steps = max( [ text.shape[0] for text in batch["text"] ] )
@@ -254,6 +256,7 @@ def run_eval(engines, eval_name, dl, args=None):
                 text_list = [ cfg.tokenizer.decode( text ) for i, text in enumerate( text_list ) ]
                 _logger.info(f"Validation Metrics (STT): {text_list}")
             """
     stats = {k: sum(v) / len(v) for k, v in stats.items() if v}
     engines_stats = {