oops
This commit is contained in:
parent
b445f4abb6
commit
3ab11bdc7b
|
@ -1024,6 +1024,8 @@ class Dataset(_Dataset):
|
|||
|
||||
@cached_property
|
||||
def tasks(self):
|
||||
if not self.training:
|
||||
return ["tts"]
|
||||
return cfg.dataset.tasks_list # ["tts", "tts", "ns", "sr", "tse", "tts", "tts"] # , "cse", "nse"
|
||||
|
||||
def save_state_dict(self, path = None):
|
||||
|
|
|
@ -719,8 +719,6 @@ class AR_NAR(Base):
|
|||
text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
|
||||
resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
|
||||
|
||||
print( text_list, raw_text_list )
|
||||
|
||||
quant_levels = [ 0 for _ in range( max( batch_size, beam_width ) ) ]
|
||||
|
||||
inputs = self.inputs(
|
||||
|
|
|
@ -29,7 +29,7 @@ mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu")
|
|||
def train_feeder(engine, batch, teacher=None):
|
||||
engine.tokens_processed += sum([ text.shape[0] for text in batch["text"] ])
|
||||
engine.tokens_processed += sum([ resps.shape[0] for resps in batch["resps"] ])
|
||||
|
||||
|
||||
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
|
||||
batch_size = len(batch["text"])
|
||||
engine.current_batch_size = batch_size
|
||||
|
@ -184,6 +184,8 @@ def run_eval(engines, eval_name, dl, args=None):
|
|||
# has_stt = True
|
||||
batch["task"][i] = "tts"
|
||||
batch["proms"][i] = batch["resps"][i][:75*3, :]
|
||||
elif task != "tts":
|
||||
batch["task"][i] = "tts"
|
||||
|
||||
# random prompts requested
|
||||
if args and args.eval_random_text_prompts and eval_name == "subtrain":
|
||||
|
|
Loading…
Reference in New Issue
Block a user