diff --git a/vall_e/data.py b/vall_e/data.py index cb96e2f..c3584f4 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -1024,6 +1024,8 @@ class Dataset(_Dataset): @cached_property def tasks(self): + if not self.training: + return ["tts"] return cfg.dataset.tasks_list # ["tts", "tts", "ns", "sr", "tse", "tts", "tts"] # , "cse", "nse" def save_state_dict(self, path = None): diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 2d15623..4c206e9 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -719,8 +719,6 @@ class AR_NAR(Base): text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ] resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ] - print( text_list, raw_text_list ) - quant_levels = [ 0 for _ in range( max( batch_size, beam_width ) ) ] inputs = self.inputs( diff --git a/vall_e/train.py b/vall_e/train.py index 0a28ebe..20cf753 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -29,7 +29,7 @@ mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu") def train_feeder(engine, batch, teacher=None): engine.tokens_processed += sum([ text.shape[0] for text in batch["text"] ]) engine.tokens_processed += sum([ resps.shape[0] for resps in batch["resps"] ]) - + with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp): batch_size = len(batch["text"]) engine.current_batch_size = batch_size @@ -184,6 +184,8 @@ def run_eval(engines, eval_name, dl, args=None): # has_stt = True batch["task"][i] = "tts" batch["proms"][i] = batch["resps"][i][:75*3, :] + elif task != "tts": + batch["task"][i] = "tts" # random prompts requested if args and args.eval_random_text_prompts and eval_name == "subtrain":