more fixes, moved sampler state dict to a better place, eval works again

This commit is contained in:
mrq 2024-09-06 16:59:56 -05:00
parent 4bd9bb39c8
commit fa93061b3e
2 changed files with 24 additions and 4 deletions

View File

@ -734,7 +734,7 @@ class Dataset(_Dataset):
@cached_property
def sampler_state_dict_path(self):
-        return cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
+        return cfg.ckpt_dir / cfg.model.full_name / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
def get_speaker(self, path):
if isinstance(path, str):

View File

@ -103,19 +103,25 @@ def run_eval(engines, eval_name, dl):
for key in batch.keys():
batch[key] = batch[key][:cfg.evaluation.batch_size]
-        processed += len(batch["text"])
+        batch_size = len(batch["text"])
+        processed += batch_size
for name in engines:
engine = engines[name]
# to-do: eval for text tasks
-            for i, task in batch["task"]:
+            has_stt = False
+            for i, task in enumerate( batch["task"] ):
# easier to just change it to a tts task than drop stt tasks from the batch
if task == "stt":
has_stt = True
batch["task"][i] = "tts"
batch["proms"][i] = batch["resps"][i][:75*3, :]
kwargs = dict(
text_list=batch["text"],
-                prom_list=batch["proms"],
+                proms_list=batch["proms"],
lang_list=batch["lang"],
task_list=batch["task"],
)
@ -137,6 +143,20 @@ def run_eval(engines, eval_name, dl):
process( name, batch, resps_list )
# evaluate why it's so slow
if has_stt:
max_steps = max( [ text.shape[0] for text in batch["text"] ] )
kwargs["text_list"] = None
kwargs["task_list"] = [ "stt" for _ in range(batch_size) ]
kwargs["proms_list"] = [ ["stt"] for _ in range(batch_size) ]
kwargs["resps_list"] = batch["resps"]
text_list = engine( **kwargs, max_steps=max_steps, sampling_temperature=0.0)
text_list = [ cfg.tokenizer.decode( text ) for i, text in enumerate( text_list ) ]
_logger.info(f"Validation Metrics (STT): {text_list}")
stats = {k: sum(v) / len(v) for k, v in stats.items()}
engines_stats = {
f'{name}.{eval_name}': stats,