diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 278a9e3..933dc37 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -912,13 +912,13 @@ class Base(nn.Module):
 			embedding = None
 			if name == "text":
 				embedding = self.text_emb( input )
-			elif name == "quant_level":
+			elif name == "quant_level" and self.rvq_level_emb is not None:
 				embedding = self.rvq_level_emb( input )
-			elif name == "lang":
+			elif name == "lang" and self.langs_emb is not None:
 				embedding = self.langs_emb( input )
 			elif name == "prom":
 				embedding = self.proms_emb( input )
-			elif name == "tone":
+			elif name == "tone" and self.tones_emb is not None:
 				embedding = self.tones_emb( input )
 			elif name == "resp":
 				embedding = self.resps_emb( input, quant_level )
diff --git a/vall_e/train.py b/vall_e/train.py
index ba2e3e1..d8bcb48 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -168,9 +168,14 @@ def run_eval(engines, eval_name, dl):
 			for i, resp in enumerate( resps_list ):
 				resps_list[i] = torch.stack( resp ).t()
 		else:
-			resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
-			resps_list = [ r.unsqueeze(-1) for r in resps_list ]
-			resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
+			if "ar" in engine.hyper_config.capabilities:
+				resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
+			else:
+				resps_list = [ resp[:, 0] for resp in batch["resps"] ]
+
+			if "nar" in engine.hyper_config.capabilities:
+				resps_list = [ r.unsqueeze(-1) for r in resps_list ]
+				resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
 
 		process( name, batch, resps_list )
 
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index 685a8f0..6c3d0b9 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -162,7 +162,7 @@ def train(
 			#batch = to_device(batch, torch.cuda.current_device())
 			stats = engines.step(batch=batch, feeder=train_feeder)
 
-			stats['epoch'] = engines.global_samples / len(train_dl.dataset.paths) * world_size()
+			stats['epoch'] = engines.global_samples / (len(train_dl.dataset.paths) * world_size())
 
 			"""
 			stats['batch'] = {
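
For illustration only (not part of the patch): a minimal, self-contained sketch of the capability-gated evaluation flow that the `vall_e/train.py` hunk introduces. The `DummyEngine` class, its `ar`/`nar` methods, and all tensor shapes are assumptions made for this example; only the `capabilities` check, the `resp[:, 0]` fallback to the ground-truth first RVQ level, and the `unsqueeze(-1)` seeding of the NAR pass mirror the diff above.

```python
# Sketch of the AR/NAR capability gating from run_eval(); the engine,
# its methods, and the tensor shapes below are illustrative assumptions.
import torch

class DummyEngine:
    """Stand-in engine exposing only what the sketch needs."""
    def __init__(self, capabilities):
        self.capabilities = capabilities  # e.g. ["ar", "nar"] or ["nar"]

    def ar(self, text_list, proms_list):
        # Pretend AR pass: one code per output frame (first RVQ level only).
        return [torch.randint(0, 1024, (75,)) for _ in text_list]

    def nar(self, text_list, proms_list, resps_list):
        # Pretend NAR pass: widen each (frames, 1) response to 8 RVQ levels.
        return [r.expand(-1, 8).clone() for r in resps_list]

def evaluate(engine, batch):
    if "ar" in engine.capabilities:
        resps_list = engine.ar(batch["text"], batch["proms"])
    else:
        # No AR: seed with the ground-truth first quantizer level instead.
        resps_list = [resp[:, 0] for resp in batch["resps"]]

    if "nar" in engine.capabilities:
        resps_list = [r.unsqueeze(-1) for r in resps_list]
        resps_list = engine.nar(batch["text"], batch["proms"], resps_list)

    return resps_list

if __name__ == "__main__":
    batch = {
        "text": [torch.randint(0, 256, (20,))],
        "proms": [torch.randint(0, 1024, (150, 8))],
        "resps": [torch.randint(0, 1024, (75, 8))],
    }
    for caps in (["ar", "nar"], ["nar"]):
        out = evaluate(DummyEngine(caps), batch)
        print(caps, out[0].shape)  # both paths end with a (frames, levels) tensor
```

Running the sketch with `["ar", "nar"]` and with `["nar"]` shows both paths ending in the same `(frames, levels)` response shape, which is why a NAR-only model can still be evaluated without an AR pass.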