maybe fix NaNs being thrown for immature models at fp16 for training evals

This commit is contained in:
mrq 2025-02-24 18:25:54 -06:00
parent 0f39f4d7a1
commit 3330b5bb00

View File

@ -207,6 +207,7 @@ def run_eval(engines, eval_name, dl, args=None):
training=False,
)
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
if engine.hyper_config.version >= 7:
kwargs = base_kwargs | cfg.evaluation.kwargs
# sample for NAR demask
@ -241,6 +242,7 @@ def run_eval(engines, eval_name, dl, args=None):
process( name, batch, resps_list )
"""
# evaluate why it's so slow
if has_stt:
max_steps = max( [ text.shape[0] for text in batch["text"] ] )
@ -254,6 +256,7 @@ def run_eval(engines, eval_name, dl, args=None):
text_list = [ cfg.tokenizer.decode( text ) for i, text in enumerate( text_list ) ]
_logger.info(f"Validation Metrics (STT): {text_list}")
"""
stats = {k: sum(v) / len(v) for k, v in stats.items() if v}
engines_stats = {