This commit is contained in:
mrq 2024-06-05 10:30:04 -05:00
parent 48cd1054f9
commit 3cfc8a96bb
3 changed files with 12 additions and 7 deletions

View File

@@ -912,13 +912,13 @@ class Base(nn.Module):
embedding = None
if name == "text":
embedding = self.text_emb( input )
elif name == "quant_level":
elif name == "quant_level" and self.rvq_level_emb is not None:
embedding = self.rvq_level_emb( input )
elif name == "lang":
elif name == "lang" and self.langs_emb is not None:
embedding = self.langs_emb( input )
elif name == "prom":
embedding = self.proms_emb( input )
elif name == "tone":
elif name == "tone" and self.tones_emb is not None:
embedding = self.tones_emb( input )
elif name == "resp":
embedding = self.resps_emb( input, quant_level )

View File

@@ -168,9 +168,14 @@ def run_eval(engines, eval_name, dl):
for i, resp in enumerate( resps_list ):
resps_list[i] = torch.stack( resp ).t()
else:
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
resps_list = [ r.unsqueeze(-1) for r in resps_list ]
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
if "ar" in engine.hyper_config.capabilities:
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
else:
resps_list = [ resp[:, 0] for resp in batch["resps"] ]
if "nar" in engine.hyper_config.capabilities:
resps_list = [ r.unsqueeze(-1) for r in resps_list ]
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
process( name, batch, resps_list )

View File

@@ -162,7 +162,7 @@ def train(
#batch = to_device(batch, torch.cuda.current_device())
stats = engines.step(batch=batch, feeder=train_feeder)
stats['epoch'] = engines.global_samples / len(train_dl.dataset.paths) * world_size()
stats['epoch'] = engines.global_samples / (len(train_dl.dataset.paths) * world_size())
"""
stats['batch'] = {