evaluation/validation passes language ID during training (oops)
This commit is contained in:
parent
ed54f4ebec
commit
0aa2a3cc07
|
@ -224,7 +224,7 @@ class Dataset(_Dataset):
|
|||
self.spkrs_by_spkr_group[spkr_group].append( spkr )
|
||||
|
||||
self.spkr_groups = list(self.spkrs_by_spkr_group.keys())
|
||||
|
||||
|
||||
self.spkr_samplers = { name: Sampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() }
|
||||
|
||||
if self.sampler_type == "path":
|
||||
|
|
|
@ -103,9 +103,9 @@ def run_eval(engines, eval_name, dl):
|
|||
if AR is not None and NAR is not None:
|
||||
name = "+".join(names)
|
||||
|
||||
resps_list = AR(text_list=batch["text"], proms_list=batch["proms"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
|
||||
resps_list = AR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
|
||||
resps_list = [ r.unsqueeze(-1) for r in resps_list ]
|
||||
resps_list = NAR(text_list=batch["text"], proms_list=batch["proms"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
|
||||
resps_list = NAR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
|
||||
|
||||
process( name, batch, resps_list )
|
||||
else:
|
||||
|
@ -113,13 +113,14 @@ def run_eval(engines, eval_name, dl):
|
|||
model = engines[name]
|
||||
|
||||
if name.startswith("ar+nar"):
|
||||
resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
|
||||
resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
|
||||
resps_list = [ r.unsqueeze(-1) for r in resps_list ]
|
||||
resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
|
||||
resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
|
||||
elif name.startswith("ar"):
|
||||
resps_list = model(
|
||||
text_list=batch["text"],
|
||||
proms_list=batch["proms"],
|
||||
lang_list=batch["lang"],
|
||||
max_steps=cfg.evaluation.steps,
|
||||
sampling_temperature=cfg.evaluation.ar_temperature,
|
||||
)
|
||||
|
@ -128,6 +129,7 @@ def run_eval(engines, eval_name, dl):
|
|||
resps_list = model(
|
||||
text_list=batch["text"],
|
||||
proms_list=batch["proms"],
|
||||
lang_list=batch["lang"],
|
||||
resps_list=[r[..., 0].unsqueeze(-1) for r in batch["resps"]],
|
||||
sampling_temperature=cfg.evaluation.nar_temperature,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue
Block a user