maybe fix NaNs being thrown for immature models at fp16 for training evals
This commit is contained in:
parent
0f39f4d7a1
commit
3330b5bb00
|
@ -207,40 +207,42 @@ def run_eval(engines, eval_name, dl, args=None):
|
|||
training=False,
|
||||
)
|
||||
|
||||
if engine.hyper_config.version >= 7:
|
||||
kwargs = base_kwargs | cfg.evaluation.kwargs
|
||||
# sample for NAR demask
|
||||
if random.random() < engine.hyper_config.experimental.masking_train_p:
|
||||
kwargs["len_list"] = [ resp.shape[0] for resp in batch["resps"] ]
|
||||
# inference
|
||||
resps_list = engine( **kwargs )
|
||||
else:
|
||||
if "len" in engine.hyper_config.capabilities:
|
||||
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
|
||||
if engine.hyper_config.version >= 7:
|
||||
kwargs = base_kwargs | cfg.evaluation.kwargs
|
||||
max_steps = kwargs.pop("max_steps", 500)
|
||||
|
||||
if "denoise_start" in kwargs:
|
||||
len_list = [ resp.shape[0] for resp in batch["resps"] ]
|
||||
kwargs["resps_list"] = [ resp[:, :1] for resp in batch["resps"] ]
|
||||
else:
|
||||
len_list = engine( max_steps=5, **kwargs )
|
||||
len_list = [ min( l, max_steps ) for l in len_list ]
|
||||
|
||||
kwargs = base_kwargs | cfg.evaluation.kwargs
|
||||
resps_list = engine( **kwargs, len_list=len_list )
|
||||
# sample for NAR demask
|
||||
if random.random() < engine.hyper_config.experimental.masking_train_p:
|
||||
kwargs["len_list"] = [ resp.shape[0] for resp in batch["resps"] ]
|
||||
# inference
|
||||
resps_list = engine( **kwargs )
|
||||
else:
|
||||
if "ar" in engine.hyper_config.capabilities:
|
||||
if "len" in engine.hyper_config.capabilities:
|
||||
kwargs = base_kwargs | cfg.evaluation.kwargs
|
||||
resps_list = engine( **kwargs )
|
||||
else:
|
||||
resps_list = [ resp[:, 0] for resp in batch["resps"] ]
|
||||
max_steps = kwargs.pop("max_steps", 500)
|
||||
|
||||
if "nar" in engine.hyper_config.capabilities:
|
||||
if "denoise_start" in kwargs:
|
||||
len_list = [ resp.shape[0] for resp in batch["resps"] ]
|
||||
kwargs["resps_list"] = [ resp[:, :1] for resp in batch["resps"] ]
|
||||
else:
|
||||
len_list = engine( max_steps=5, **kwargs )
|
||||
len_list = [ min( l, max_steps ) for l in len_list ]
|
||||
|
||||
kwargs = base_kwargs | cfg.evaluation.kwargs
|
||||
resps_list = engine( **kwargs, resps_list=resps_list )
|
||||
resps_list = engine( **kwargs, len_list=len_list )
|
||||
else:
|
||||
if "ar" in engine.hyper_config.capabilities:
|
||||
kwargs = base_kwargs | cfg.evaluation.kwargs
|
||||
resps_list = engine( **kwargs )
|
||||
else:
|
||||
resps_list = [ resp[:, 0] for resp in batch["resps"] ]
|
||||
|
||||
if "nar" in engine.hyper_config.capabilities:
|
||||
kwargs = base_kwargs | cfg.evaluation.kwargs
|
||||
resps_list = engine( **kwargs, resps_list=resps_list )
|
||||
|
||||
process( name, batch, resps_list )
|
||||
|
||||
"""
|
||||
# evaluate why it's so slow
|
||||
if has_stt:
|
||||
max_steps = max( [ text.shape[0] for text in batch["text"] ] )
|
||||
|
@ -254,6 +256,7 @@ def run_eval(engines, eval_name, dl, args=None):
|
|||
text_list = [ cfg.tokenizer.decode( text ) for i, text in enumerate( text_list ) ]
|
||||
|
||||
_logger.info(f"Validation Metrics (STT): {text_list}")
|
||||
"""
|
||||
|
||||
stats = {k: sum(v) / len(v) for k, v in stats.items() if v}
|
||||
engines_stats = {
|
||||
|
|
Loading…
Reference in New Issue
Block a user