evaluation/validation passes language ID during training (oops)

master
mrq 2023-10-29 12:00:40 +07:00
parent ed54f4ebec
commit 0aa2a3cc07
2 changed files with 7 additions and 5 deletions

@ -224,7 +224,7 @@ class Dataset(_Dataset):
self.spkrs_by_spkr_group[spkr_group].append( spkr )
self.spkr_groups = list(self.spkrs_by_spkr_group.keys())
self.spkr_samplers = { name: Sampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() }
if self.sampler_type == "path":

@ -103,9 +103,9 @@ def run_eval(engines, eval_name, dl):
if AR is not None and NAR is not None:
name = "+".join(names)
resps_list = AR(text_list=batch["text"], proms_list=batch["proms"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
resps_list = AR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
resps_list = [ r.unsqueeze(-1) for r in resps_list ]
resps_list = NAR(text_list=batch["text"], proms_list=batch["proms"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
resps_list = NAR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
process( name, batch, resps_list )
else:
@ -113,13 +113,14 @@ def run_eval(engines, eval_name, dl):
model = engines[name]
if name.startswith("ar+nar"):
resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
resps_list = [ r.unsqueeze(-1) for r in resps_list ]
resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
elif name.startswith("ar"):
resps_list = model(
text_list=batch["text"],
proms_list=batch["proms"],
lang_list=batch["lang"],
max_steps=cfg.evaluation.steps,
sampling_temperature=cfg.evaluation.ar_temperature,
)
@ -128,6 +129,7 @@ def run_eval(engines, eval_name, dl):
resps_list = model(
text_list=batch["text"],
proms_list=batch["proms"],
lang_list=batch["lang"],
resps_list=[r[..., 0].unsqueeze(-1) for r in batch["resps"]],
sampling_temperature=cfg.evaluation.nar_temperature,
)