Possibly fix NaNs being thrown for immature models at fp16 during training evals

This commit is contained in:
mrq 2025-02-24 18:25:54 -06:00
parent 0f39f4d7a1
commit 3330b5bb00

View File

@ -207,6 +207,7 @@ def run_eval(engines, eval_name, dl, args=None):
training=False,
)
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
if engine.hyper_config.version >= 7:
kwargs = base_kwargs | cfg.evaluation.kwargs
# sample for NAR demask
@ -241,6 +242,7 @@ def run_eval(engines, eval_name, dl, args=None):
process( name, batch, resps_list )
"""
# evaluate why it's so slow
if has_stt:
max_steps = max( [ text.shape[0] for text in batch["text"] ] )
@ -254,6 +256,7 @@ def run_eval(engines, eval_name, dl, args=None):
text_list = [ cfg.tokenizer.decode( text ) for i, text in enumerate( text_list ) ]
_logger.info(f"Validation Metrics (STT): {text_list}")
"""
stats = {k: sum(v) / len(v) for k, v in stats.items() if v}
engines_stats = {