maybe fix NaNs being thrown for immature models at fp16 for training evals
This commit is contained in:
parent
0f39f4d7a1
commit
3330b5bb00
|
@ -207,40 +207,42 @@ def run_eval(engines, eval_name, dl, args=None):
|
||||||
training=False,
|
training=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
if engine.hyper_config.version >= 7:
|
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
|
||||||
kwargs = base_kwargs | cfg.evaluation.kwargs
|
if engine.hyper_config.version >= 7:
|
||||||
# sample for NAR demask
|
|
||||||
if random.random() < engine.hyper_config.experimental.masking_train_p:
|
|
||||||
kwargs["len_list"] = [ resp.shape[0] for resp in batch["resps"] ]
|
|
||||||
# inference
|
|
||||||
resps_list = engine( **kwargs )
|
|
||||||
else:
|
|
||||||
if "len" in engine.hyper_config.capabilities:
|
|
||||||
kwargs = base_kwargs | cfg.evaluation.kwargs
|
kwargs = base_kwargs | cfg.evaluation.kwargs
|
||||||
max_steps = kwargs.pop("max_steps", 500)
|
# sample for NAR demask
|
||||||
|
if random.random() < engine.hyper_config.experimental.masking_train_p:
|
||||||
if "denoise_start" in kwargs:
|
kwargs["len_list"] = [ resp.shape[0] for resp in batch["resps"] ]
|
||||||
len_list = [ resp.shape[0] for resp in batch["resps"] ]
|
# inference
|
||||||
kwargs["resps_list"] = [ resp[:, :1] for resp in batch["resps"] ]
|
resps_list = engine( **kwargs )
|
||||||
else:
|
|
||||||
len_list = engine( max_steps=5, **kwargs )
|
|
||||||
len_list = [ min( l, max_steps ) for l in len_list ]
|
|
||||||
|
|
||||||
kwargs = base_kwargs | cfg.evaluation.kwargs
|
|
||||||
resps_list = engine( **kwargs, len_list=len_list )
|
|
||||||
else:
|
else:
|
||||||
if "ar" in engine.hyper_config.capabilities:
|
if "len" in engine.hyper_config.capabilities:
|
||||||
kwargs = base_kwargs | cfg.evaluation.kwargs
|
kwargs = base_kwargs | cfg.evaluation.kwargs
|
||||||
resps_list = engine( **kwargs )
|
max_steps = kwargs.pop("max_steps", 500)
|
||||||
else:
|
|
||||||
resps_list = [ resp[:, 0] for resp in batch["resps"] ]
|
|
||||||
|
|
||||||
if "nar" in engine.hyper_config.capabilities:
|
if "denoise_start" in kwargs:
|
||||||
|
len_list = [ resp.shape[0] for resp in batch["resps"] ]
|
||||||
|
kwargs["resps_list"] = [ resp[:, :1] for resp in batch["resps"] ]
|
||||||
|
else:
|
||||||
|
len_list = engine( max_steps=5, **kwargs )
|
||||||
|
len_list = [ min( l, max_steps ) for l in len_list ]
|
||||||
|
|
||||||
kwargs = base_kwargs | cfg.evaluation.kwargs
|
kwargs = base_kwargs | cfg.evaluation.kwargs
|
||||||
resps_list = engine( **kwargs, resps_list=resps_list )
|
resps_list = engine( **kwargs, len_list=len_list )
|
||||||
|
else:
|
||||||
|
if "ar" in engine.hyper_config.capabilities:
|
||||||
|
kwargs = base_kwargs | cfg.evaluation.kwargs
|
||||||
|
resps_list = engine( **kwargs )
|
||||||
|
else:
|
||||||
|
resps_list = [ resp[:, 0] for resp in batch["resps"] ]
|
||||||
|
|
||||||
|
if "nar" in engine.hyper_config.capabilities:
|
||||||
|
kwargs = base_kwargs | cfg.evaluation.kwargs
|
||||||
|
resps_list = engine( **kwargs, resps_list=resps_list )
|
||||||
|
|
||||||
process( name, batch, resps_list )
|
process( name, batch, resps_list )
|
||||||
|
|
||||||
|
"""
|
||||||
# evaluate why it's so slow
|
# evaluate why it's so slow
|
||||||
if has_stt:
|
if has_stt:
|
||||||
max_steps = max( [ text.shape[0] for text in batch["text"] ] )
|
max_steps = max( [ text.shape[0] for text in batch["text"] ] )
|
||||||
|
@ -254,6 +256,7 @@ def run_eval(engines, eval_name, dl, args=None):
|
||||||
text_list = [ cfg.tokenizer.decode( text ) for i, text in enumerate( text_list ) ]
|
text_list = [ cfg.tokenizer.decode( text ) for i, text in enumerate( text_list ) ]
|
||||||
|
|
||||||
_logger.info(f"Validation Metrics (STT): {text_list}")
|
_logger.info(f"Validation Metrics (STT): {text_list}")
|
||||||
|
"""
|
||||||
|
|
||||||
stats = {k: sum(v) / len(v) for k, v in stats.items() if v}
|
stats = {k: sum(v) / len(v) for k, v in stats.items() if v}
|
||||||
engines_stats = {
|
engines_stats = {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user