evaluation/validation passes language ID during training (oops)
This commit is contained in:
parent
ed54f4ebec
commit
0aa2a3cc07
|
@ -224,7 +224,7 @@ class Dataset(_Dataset):
|
||||||
self.spkrs_by_spkr_group[spkr_group].append( spkr )
|
self.spkrs_by_spkr_group[spkr_group].append( spkr )
|
||||||
|
|
||||||
self.spkr_groups = list(self.spkrs_by_spkr_group.keys())
|
self.spkr_groups = list(self.spkrs_by_spkr_group.keys())
|
||||||
|
|
||||||
self.spkr_samplers = { name: Sampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() }
|
self.spkr_samplers = { name: Sampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() }
|
||||||
|
|
||||||
if self.sampler_type == "path":
|
if self.sampler_type == "path":
|
||||||
|
|
|
@ -103,9 +103,9 @@ def run_eval(engines, eval_name, dl):
|
||||||
if AR is not None and NAR is not None:
|
if AR is not None and NAR is not None:
|
||||||
name = "+".join(names)
|
name = "+".join(names)
|
||||||
|
|
||||||
resps_list = AR(text_list=batch["text"], proms_list=batch["proms"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
|
resps_list = AR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
|
||||||
resps_list = [ r.unsqueeze(-1) for r in resps_list ]
|
resps_list = [ r.unsqueeze(-1) for r in resps_list ]
|
||||||
resps_list = NAR(text_list=batch["text"], proms_list=batch["proms"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
|
resps_list = NAR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
|
||||||
|
|
||||||
process( name, batch, resps_list )
|
process( name, batch, resps_list )
|
||||||
else:
|
else:
|
||||||
|
@ -113,13 +113,14 @@ def run_eval(engines, eval_name, dl):
|
||||||
model = engines[name]
|
model = engines[name]
|
||||||
|
|
||||||
if name.startswith("ar+nar"):
|
if name.startswith("ar+nar"):
|
||||||
resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
|
resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
|
||||||
resps_list = [ r.unsqueeze(-1) for r in resps_list ]
|
resps_list = [ r.unsqueeze(-1) for r in resps_list ]
|
||||||
resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
|
resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
|
||||||
elif name.startswith("ar"):
|
elif name.startswith("ar"):
|
||||||
resps_list = model(
|
resps_list = model(
|
||||||
text_list=batch["text"],
|
text_list=batch["text"],
|
||||||
proms_list=batch["proms"],
|
proms_list=batch["proms"],
|
||||||
|
lang_list=batch["lang"],
|
||||||
max_steps=cfg.evaluation.steps,
|
max_steps=cfg.evaluation.steps,
|
||||||
sampling_temperature=cfg.evaluation.ar_temperature,
|
sampling_temperature=cfg.evaluation.ar_temperature,
|
||||||
)
|
)
|
||||||
|
@ -128,6 +129,7 @@ def run_eval(engines, eval_name, dl):
|
||||||
resps_list = model(
|
resps_list = model(
|
||||||
text_list=batch["text"],
|
text_list=batch["text"],
|
||||||
proms_list=batch["proms"],
|
proms_list=batch["proms"],
|
||||||
|
lang_list=batch["lang"],
|
||||||
resps_list=[r[..., 0].unsqueeze(-1) for r in batch["resps"]],
|
resps_list=[r[..., 0].unsqueeze(-1) for r in batch["resps"]],
|
||||||
sampling_temperature=cfg.evaluation.nar_temperature,
|
sampling_temperature=cfg.evaluation.nar_temperature,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user