more fixes, moved sampler state dict to a better place, eval works again
This commit is contained in:
parent
4bd9bb39c8
commit
fa93061b3e
|
@ -734,7 +734,7 @@ class Dataset(_Dataset):
|
|||
|
||||
@cached_property
|
||||
def sampler_state_dict_path(self):
|
||||
return cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
|
||||
return cfg.ckpt_dir / cfg.model.full_name / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
|
||||
|
||||
def get_speaker(self, path):
|
||||
if isinstance(path, str):
|
||||
|
|
|
@ -103,19 +103,25 @@ def run_eval(engines, eval_name, dl):
|
|||
for key in batch.keys():
|
||||
batch[key] = batch[key][:cfg.evaluation.batch_size]
|
||||
|
||||
processed += len(batch["text"])
|
||||
batch_size = len(batch["text"])
|
||||
|
||||
processed += batch_size
|
||||
|
||||
for name in engines:
|
||||
engine = engines[name]
|
||||
|
||||
# to-do: eval for text tasks
|
||||
for i, task in batch["task"]:
|
||||
has_stt = False
|
||||
for i, task in enumerate( batch["task"] ):
|
||||
# easier to just change it to a tts task than drop stt tasks from the batch
|
||||
if task == "stt":
|
||||
has_stt = True
|
||||
batch["task"][i] = "tts"
|
||||
batch["proms"][i] = batch["resps"][i][:75*3, :]
|
||||
|
||||
kwargs = dict(
|
||||
text_list=batch["text"],
|
||||
prom_list=batch["proms"],
|
||||
proms_list=batch["proms"],
|
||||
lang_list=batch["lang"],
|
||||
task_list=batch["task"],
|
||||
)
|
||||
|
@ -137,6 +143,20 @@ def run_eval(engines, eval_name, dl):
|
|||
|
||||
process( name, batch, resps_list )
|
||||
|
||||
# evaluate why it's so slow
|
||||
if has_stt:
|
||||
max_steps = max( [ text.shape[0] for text in batch["text"] ] )
|
||||
|
||||
kwargs["text_list"] = None
|
||||
kwargs["task_list"] = [ "stt" for _ in range(batch_size) ]
|
||||
kwargs["proms_list"] = [ ["stt"] for _ in range(batch_size) ]
|
||||
kwargs["resps_list"] = batch["resps"]
|
||||
|
||||
text_list = engine( **kwargs, max_steps=max_steps, sampling_temperature=0.0)
|
||||
text_list = [ cfg.tokenizer.decode( text ) for i, text in enumerate( text_list ) ]
|
||||
|
||||
_logger.info(f"Validation Metrics (STT): {text_list}")
|
||||
|
||||
stats = {k: sum(v) / len(v) for k, v in stats.items()}
|
||||
engines_stats = {
|
||||
f'{name}.{eval_name}': stats,
|
||||
|
|
Loading…
Reference in New Issue
Block a user