more fixes, moved sampler state dict to a better place, eval works again
This commit is contained in:
parent
4bd9bb39c8
commit
fa93061b3e
|
@ -734,7 +734,7 @@ class Dataset(_Dataset):
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def sampler_state_dict_path(self):
|
def sampler_state_dict_path(self):
|
||||||
return cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
|
return cfg.ckpt_dir / cfg.model.full_name / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
|
||||||
|
|
||||||
def get_speaker(self, path):
|
def get_speaker(self, path):
|
||||||
if isinstance(path, str):
|
if isinstance(path, str):
|
||||||
|
|
|
@ -103,19 +103,25 @@ def run_eval(engines, eval_name, dl):
|
||||||
for key in batch.keys():
|
for key in batch.keys():
|
||||||
batch[key] = batch[key][:cfg.evaluation.batch_size]
|
batch[key] = batch[key][:cfg.evaluation.batch_size]
|
||||||
|
|
||||||
processed += len(batch["text"])
|
batch_size = len(batch["text"])
|
||||||
|
|
||||||
|
processed += batch_size
|
||||||
|
|
||||||
for name in engines:
|
for name in engines:
|
||||||
engine = engines[name]
|
engine = engines[name]
|
||||||
|
|
||||||
# to-do: eval for text tasks
|
# to-do: eval for text tasks
|
||||||
for i, task in batch["task"]:
|
has_stt = False
|
||||||
|
for i, task in enumerate( batch["task"] ):
|
||||||
|
# easier to just change it to a tts task than drop stt tasks from the batch
|
||||||
if task == "stt":
|
if task == "stt":
|
||||||
|
has_stt = True
|
||||||
batch["task"][i] = "tts"
|
batch["task"][i] = "tts"
|
||||||
|
batch["proms"][i] = batch["resps"][i][:75*3, :]
|
||||||
|
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
text_list=batch["text"],
|
text_list=batch["text"],
|
||||||
prom_list=batch["proms"],
|
proms_list=batch["proms"],
|
||||||
lang_list=batch["lang"],
|
lang_list=batch["lang"],
|
||||||
task_list=batch["task"],
|
task_list=batch["task"],
|
||||||
)
|
)
|
||||||
|
@ -137,6 +143,20 @@ def run_eval(engines, eval_name, dl):
|
||||||
|
|
||||||
process( name, batch, resps_list )
|
process( name, batch, resps_list )
|
||||||
|
|
||||||
|
# evaluate why it's so slow
|
||||||
|
if has_stt:
|
||||||
|
max_steps = max( [ text.shape[0] for text in batch["text"] ] )
|
||||||
|
|
||||||
|
kwargs["text_list"] = None
|
||||||
|
kwargs["task_list"] = [ "stt" for _ in range(batch_size) ]
|
||||||
|
kwargs["proms_list"] = [ ["stt"] for _ in range(batch_size) ]
|
||||||
|
kwargs["resps_list"] = batch["resps"]
|
||||||
|
|
||||||
|
text_list = engine( **kwargs, max_steps=max_steps, sampling_temperature=0.0)
|
||||||
|
text_list = [ cfg.tokenizer.decode( text ) for i, text in enumerate( text_list ) ]
|
||||||
|
|
||||||
|
_logger.info(f"Validation Metrics (STT): {text_list}")
|
||||||
|
|
||||||
stats = {k: sum(v) / len(v) for k, v in stats.items()}
|
stats = {k: sum(v) / len(v) for k, v in stats.items()}
|
||||||
engines_stats = {
|
engines_stats = {
|
||||||
f'{name}.{eval_name}': stats,
|
f'{name}.{eval_name}': stats,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user