maybe fix NaNs being thrown for immature models at fp16 for training evals
This commit is contained in:
parent
0f39f4d7a1
commit
3330b5bb00
|
@ -207,6 +207,7 @@ def run_eval(engines, eval_name, dl, args=None):
|
|||
training=False,
|
||||
)
|
||||
|
||||
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
|
||||
if engine.hyper_config.version >= 7:
|
||||
kwargs = base_kwargs | cfg.evaluation.kwargs
|
||||
# sample for NAR demask
|
||||
|
@ -241,6 +242,7 @@ def run_eval(engines, eval_name, dl, args=None):
|
|||
|
||||
process( name, batch, resps_list )
|
||||
|
||||
"""
|
||||
# evaluate why it's so slow
|
||||
if has_stt:
|
||||
max_steps = max( [ text.shape[0] for text in batch["text"] ] )
|
||||
|
@ -254,6 +256,7 @@ def run_eval(engines, eval_name, dl, args=None):
|
|||
text_list = [ cfg.tokenizer.decode( text ) for i, text in enumerate( text_list ) ]
|
||||
|
||||
_logger.info(f"Validation Metrics (STT): {text_list}")
|
||||
"""
|
||||
|
||||
stats = {k: sum(v) / len(v) for k, v in stats.items() if v}
|
||||
engines_stats = {
|
||||
|
|
Loading…
Reference in New Issue
Block a user