mrq 2025-01-05 23:53:17 -06:00
parent b445f4abb6
commit 3ab11bdc7b
3 changed files with 5 additions and 3 deletions

@@ -1024,6 +1024,8 @@ class Dataset(_Dataset):
	@cached_property
	def tasks(self):
		if not self.training:
			return ["tts"]
		return cfg.dataset.tasks_list # ["tts", "tts", "ns", "sr", "tse", "tts", "tts"] # , "cse", "nse"

	def save_state_dict(self, path = None):
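
For context, a minimal standalone sketch (hypothetical TaskPool class, made-up task pool) of the behavior this guarded property gives: training-mode sampling draws from the configured task list, while anything else only ever sees "tts".

import random

class TaskPool:
	def __init__(self, training, tasks_list):
		self.training = training
		self.tasks_list = tasks_list

	@property
	def tasks(self):
		# outside of training, only the plain "tts" task is offered
		if not self.training:
			return ["tts"]
		return self.tasks_list

pool = ["tts", "tts", "ns", "sr", "tse"]
print(random.choice(TaskPool(True, pool).tasks))   # any configured task
print(random.choice(TaskPool(False, pool).tasks))  # always "tts"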

@@ -719,8 +719,6 @@ class AR_NAR(Base):
		text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
		resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
		print( text_list, raw_text_list )
		quant_levels = [ 0 for _ in range( max( batch_size, beam_width ) ) ]
		inputs = self.inputs(
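
A toy illustration (made-up data; text_task assumed here to be something like {"stt"}) of the routing in the two comprehensions above: for text-style tasks the sampled sequence replaces the text entry, otherwise it replaces the resps entry.

text_task = {"stt"}  # assumption for the sketch
task_list = ["tts", "stt", "tts"]
sequence_list = ["seq0", "seq1", "seq2"]
text_list  = ["text0", "text1", "text2"]
resps_list = ["resp0", "resp1", "resp2"]

text_list  = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]

print(text_list)   # ['text0', 'seq1', 'text2']
print(resps_list)  # ['seq0', 'resp1', 'seq2']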

@@ -29,7 +29,7 @@ mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu")
def train_feeder(engine, batch, teacher=None):
	engine.tokens_processed += sum([ text.shape[0] for text in batch["text"] ])
	engine.tokens_processed += sum([ resps.shape[0] for resps in batch["resps"] ])

	with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
		batch_size = len(batch["text"])
		engine.current_batch_size = batch_size
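
Rough sketch (toy tensors, assumed shapes) of the bookkeeping above: tokens_processed just accumulates the per-sample lengths of the text and audio-code sequences before the autocast'd forward pass.

import torch

batch = {
	"text":  [torch.zeros(12, dtype=torch.long), torch.zeros(9, dtype=torch.long)],
	"resps": [torch.zeros(300, 8, dtype=torch.long), torch.zeros(450, 8, dtype=torch.long)],
}
tokens_processed = 0
tokens_processed += sum([ text.shape[0] for text in batch["text"] ])
tokens_processed += sum([ resps.shape[0] for resps in batch["resps"] ])
print(tokens_processed)  # 12 + 9 + 300 + 450 = 771
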
@@ -184,6 +184,8 @@ def run_eval(engines, eval_name, dl, args=None):
			# has_stt = True
			batch["task"][i] = "tts"
			batch["proms"][i] = batch["resps"][i][:75*3, :]
		elif task != "tts":
			batch["task"][i] = "tts"

	# random prompts requested
	if args and args.eval_random_text_prompts and eval_name == "subtrain":
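
A toy version (made-up batch; the first branch shown above is presumably the "stt" case) of the eval-side coercion: every non-"tts" task is evaluated as plain "tts", and the stt-style items additionally get their prompt replaced by the first 75*3 response frames, roughly three seconds at 75 codec frames per second.

import torch

batch = {
	"task":  ["tts", "stt", "ns"],
	"proms": [torch.zeros(100, 8) for _ in range(3)],
	"resps": [torch.zeros(500, 8) for _ in range(3)],
}
for i, task in enumerate(batch["task"]):
	if task == "stt":
		batch["task"][i] = "tts"
		batch["proms"][i] = batch["resps"][i][:75*3, :]
	elif task != "tts":
		batch["task"][i] = "tts"

print(batch["task"])            # ['tts', 'tts', 'tts']
print(batch["proms"][1].shape)  # torch.Size([225, 8])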