maybe fix NaNs being thrown for immature models at fp16 for training evals
This commit is contained in:
parent
0f39f4d7a1
commit
3330b5bb00
|
@ -207,6 +207,7 @@ def run_eval(engines, eval_name, dl, args=None):
|
||||||
training=False,
|
training=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
|
||||||
if engine.hyper_config.version >= 7:
|
if engine.hyper_config.version >= 7:
|
||||||
kwargs = base_kwargs | cfg.evaluation.kwargs
|
kwargs = base_kwargs | cfg.evaluation.kwargs
|
||||||
# sample for NAR demask
|
# sample for NAR demask
|
||||||
|
@ -241,6 +242,7 @@ def run_eval(engines, eval_name, dl, args=None):
|
||||||
|
|
||||||
process( name, batch, resps_list )
|
process( name, batch, resps_list )
|
||||||
|
|
||||||
|
"""
|
||||||
# evaluate why it's so slow
|
# evaluate why it's so slow
|
||||||
if has_stt:
|
if has_stt:
|
||||||
max_steps = max( [ text.shape[0] for text in batch["text"] ] )
|
max_steps = max( [ text.shape[0] for text in batch["text"] ] )
|
||||||
|
@ -254,6 +256,7 @@ def run_eval(engines, eval_name, dl, args=None):
|
||||||
text_list = [ cfg.tokenizer.decode( text ) for i, text in enumerate( text_list ) ]
|
text_list = [ cfg.tokenizer.decode( text ) for i, text in enumerate( text_list ) ]
|
||||||
|
|
||||||
_logger.info(f"Validation Metrics (STT): {text_list}")
|
_logger.info(f"Validation Metrics (STT): {text_list}")
|
||||||
|
"""
|
||||||
|
|
||||||
stats = {k: sum(v) / len(v) for k, v in stats.items() if v}
|
stats = {k: sum(v) / len(v) for k, v in stats.items() if v}
|
||||||
engines_stats = {
|
engines_stats = {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user