From d33a9061198cba43e1a16e54f957a8713421ff85 Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 6 Sep 2024 14:30:12 -0500
Subject: [PATCH] cleanup for AR_NAR inferencing to allow both TTS and STT
 tasks simultaneously (need to have training eval do this too, though)
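
The gist of it:
- forward() only takes the training/NAR path when every item in the batch
  has both its text and resps, so a batch can now mix TTS and STT items
- the AR loop tracks separate audio/text stop tokens and routes each
  partial sequence into text_list or resps_list per its task; stop token
  pruning and <bos> stripping are likewise per-item
- example_usage() now runs one batched TTS+STT pass instead of a pass per
  task
- the eval loop in train.py gathers shared kwargs and coerces "stt" back
  to "tts" for now, until eval handles text tasks

Roughly, a mixed batch now looks like this (illustrative, mirroring
example_usage(); not a stable API):

    texts = [ text, None ]       # no target text for the stt item
    proms = [ prom, [ "stt" ] ]  # stt takes the task string as its prompt
    resps = [ None, resp ]       # stt conditions on the audio instead
    tasks = [ "tts", "stt" ]
    output = engine( texts, proms, resps, task_list=tasks, max_steps=steps )
    # output[i] is audio codes for tts items, text tokens for stt items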
---
 vall_e/models/ar_nar.py | 109 ++++++++++++++++++++++++++++------------
 vall_e/train.py         |  23 ++++++---
 2 files changed, 93 insertions(+), 39 deletions(-)

diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 014e677..c985e85 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -76,8 +76,15 @@ class AR_NAR(Base):
 		if task_list is None:
 			task_list = [ default_task for _ in range(batch_size) ]
 
+		has_none = resps_list is None or text_list is None
+		if not has_none:
+			for i, task in enumerate( task_list ):
+				if resps_list[i] is None or text_list[i] is None:
+					has_none = True
+					break
+
 		# is training or NAR
-		if resps_list is not None and text_list is not None:
+		if not has_none:
 			n_levels_set = {r.shape[-1] for r in resps_list}
 			n_levels = next(iter(n_levels_set))
 
@@ -241,7 +248,8 @@ class AR_NAR(Base):
 		sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
 		stopped = torch.zeros(batch_size, device=device).bool()
 
-		stop_token = self.stop_token if task_list[0] != "stt" else 2 # to-do: derive from tokenizer
+		audio_stop_token = self.stop_token
+		text_stop_token = 2
 
 		state = None
 		mirostat = [
@@ -257,10 +265,15 @@ class AR_NAR(Base):
 
 		# get next in sequence
 		for n in trange(max_steps // max(1, self.causal_size), desc="AR", disable=disable_tqdm):
-			if task_list[0] in text_task:
-				text_list = [x for x in sequence_list]
-			else:
-				resps_list = [x.unsqueeze(dim=-1) for x in sequence_list]
+			# route each partial sequence into text_list (text tasks) or resps_list (audio tasks)
+			text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
+			resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
+
+			"""
+			print( "task_list:", task_list )
+			print( "text_list:", text_list )
+			print( "resps_list:", resps_list )
+			"""
 
 			inputs = self.inputs(
 				text_list=text_list,
@@ -286,7 +299,7 @@ class AR_NAR(Base):
 
 			r = super().sample(
 				logits=logits,
-				prev_list=resps_list,
+				prev_list=[ resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate( task_list ) ],
 
 				temperature=sampling_temperature,
 				min_temperature=sampling_min_temperature,
@@ -325,12 +338,14 @@ class AR_NAR(Base):
 
 			# append tokens
 			for i, ri in enumerate(r):
+				task = task_list[i]
+				stop_token = audio_stop_token if task not in text_task else text_stop_token
 				if stop_token in ri:
 					stopped[i] = True
 				sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)])
 
 			# stop token found
-			stopped |= r == stop_token
+			# stopped |= r == stop_token
 			if stopped.all().item():
 				break
 
@@ -339,7 +354,10 @@ class AR_NAR(Base):
 		if sampling_beam_width:
 			sequence_list = [ sequence_list[0] ]
 
-		sequence_list = [self._prune(r, stop_token) for r in sequence_list]
+		# remove stop token
+		sequence_list = [self._prune(r, audio_stop_token if task_list[i] not in text_task else text_stop_token) for i, r in enumerate(sequence_list)]
+		# remove <bos>
+		sequence_list = [ sequence_list[i] if task not in text_task else sequence_list[i][1:] for i, task in enumerate( task_list ) ]
 		return sequence_list
 
@@ -426,7 +444,8 @@ def example_usage():
 	"""
 
 	bos_id, space_id, eos_id = cfg.tokenizer.encode( " " )
-	available_tasks = cfg.dataset.tasks_list
+	#available_tasks = cfg.dataset.tasks_list
+	available_tasks = ["tts", "stt"]
 
 	model = AR_NAR(**kwargs).to(device)
 	steps = 150 * len(available_tasks) # * cfg.model.experimental.causal_size
@@ -515,6 +534,14 @@ def example_usage():
 
 	@torch.no_grad()
 	def sample_data(t=None):
+		if isinstance(t, list):
+			tasks = t
+			texts = [ text_list[0].to(device) if task != "stt" else None for i, task in enumerate( tasks ) ]
+			proms = [ proms_list[0].to(device) if task != "stt" else [ "stt" ] for i, task in enumerate( tasks ) ]
+			resps = [ None if task != "stt" else resps_list[0].to(device) for i, task in enumerate( tasks ) ]
+
+			return texts, proms, resps, tasks
+
 		texts = []
 		proms = []
 		resps = []
@@ -523,25 +550,32 @@ def example_usage():
 		for i in range(batch_size):
 			task = random.choice(available_tasks) if t is None else t
 
-			text = text_list[i]
-			prom = proms_list[i]
-			resp = resps_list[i]
+			text = text_list[i].to(device)
+			prom = proms_list[i].to(device)
+			resp = resps_list[i].to(device)
 
 			# do nothing
 			if task == "tts":
 				...
 			elif task == "stt":
-				...
+				prom = [
+					task
+				]
+			# to-do: reimplement this from data.py
+			"""
 			elif task == "tts-c":
 				trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
 
 				prom = resp[:trim_length]
 				resp = resp[trim_length:]
+
+				prom = prom.to(device)
 			elif task == "ns" or task == "sr":
 				# extend the noise to fill the target audio
 				noise_ext = repeat_extend_audio( noise, resp.shape[0] )
 				# create the input prompt by merging the target audio with the noise
 				prom = merge_audio( resp.cpu(), noise_ext, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device )
+				prom = prom.to(device)
 				# set the target to just be the noise if <sr>
 				if task == "sr":
 					resp = noise_ext
@@ -550,9 +584,15 @@ def example_usage():
 				if random.random() < 0.5:
 					text = torch.tensor([bos_id, eos_id], device=device, dtype=torch.uint8)
 
-			texts.append( text.to(device) )
-			proms.append( prom.to(device) )
-			resps.append( resp.to(device) )
+				prom = [
+					task,
+					prom,
+				]
+			"""
+
+			texts.append( text )
+			proms.append( prom )
+			resps.append( resp )
 			tasks.append( task )
 
 		return texts, proms, resps, tasks
@@ -563,25 +603,25 @@ def example_usage():
 
 		texts, proms, resps, tasks = sample_data( task )
 
-		if tasks[0] == "stt":
-			text = engine( None, proms, resps, task_list=tasks, max_steps=steps, sampling_temperature=0.95 )
-			"""
-			# to-do: STT for NAR
-			text = engine( text, proms, resps, task_list=tasks, max_steps=steps, sampling_temperature=0.95 )
-			"""
-			text = [ cfg.tokenizer.decode( t ) for t in text ]
+		if "ar" in cfg.model.capabilities:
+			output = engine( texts, proms, resps, task_list=tasks, max_steps=steps, sampling_temperature=0.95 )
 
-			print( text )
+			text = [ cfg.tokenizer.decode( output[i] ) for i, task in enumerate( tasks ) if task == "stt" ]
+
+			texts = [ texts[i] for i, task in enumerate( tasks ) if task != "stt" ]
+			proms = [ proms[i] for i, task in enumerate( tasks ) if task != "stt" ]
+			resps = [ output[i] for i, task in enumerate( tasks ) if task != "stt" ]
+			tasks = [ tasks[i] for i, task in enumerate( tasks ) if task != "stt" ]
+
+			print( "STT:", text )
 		else:
-			if "ar" in cfg.model.capabilities:
-				resps = engine( texts, proms, max_steps=steps, sampling_temperature=0.95 )
-			else:
-				resps = [ resp[:, 0] for resp in resps ]
+			resps = [ resp[:, 0] for resp in resps ]
 
-			if "nar" in cfg.model.capabilities:
-				resps = engine( texts, proms, resps, sampling_temperature=0.2 )
-		for i, o in enumerate(resps):
-			_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{task}.{name}.wav", device=device)
+		if "nar" in cfg.model.capabilities:
+			resps = engine( texts, proms, resps, task_list=tasks, sampling_temperature=0.2 )
+
+		for i, o in enumerate(resps):
+			_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{task}.{name}.wav", device=device)
 
 	unload_model()
 
@@ -611,8 +651,11 @@ def example_usage():
 		model = ml.compile_model(model, backend=cfg.optimizations.compile)
 	"""
 
+	"""
 	for task in available_tasks:
 		sample("final", task=task)
+	"""
+	sample("final", task=available_tasks)
 
 	engines.quit()
diff --git a/vall_e/train.py b/vall_e/train.py
index 0a60613..1f76dea 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -108,24 +108,35 @@ def run_eval(engines, eval_name, dl):
 		for name in engines:
 			engine = engines[name]
 
+			# to-do: eval for text tasks
+			for i, task in enumerate( batch["task"] ):
+				if task == "stt":
+					batch["task"][i] = "tts"
+
+			kwargs = dict(
+				text_list=batch["text"],
+				proms_list=batch["proms"],
+				lang_list=batch["lang"],
+				task_list=batch["task"],
+			)
+
 			if engine.hyper_config.experimental.hf:
-				resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"] )
+				resps_list = engine( **kwargs )
 			elif "len" in engine.hyper_config.capabilities:
-				len_list = engine(text_list=batch["text"], proms_list=batch["proms"], max_steps=10 ) # don't need more than that
+				len_list = engine( **kwargs, max_steps=10 ) # don't need more than that
 				len_list = [ min( l, cfg.evaluation.steps ) for l in len_list ]
-				resps_list = engine( text_list=batch["text"], proms_list=batch["proms"], len_list=len_list, max_levels=cfg.evaluation.nar_levels )
+				resps_list = engine( **kwargs, len_list=len_list, max_levels=cfg.evaluation.nar_levels )
 			else:
 				if "ar" in engine.hyper_config.capabilities:
-					resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
+					resps_list = engine( **kwargs, max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
 				else:
 					resps_list = [ resp[:, 0] for resp in batch["resps"] ]
 
 				if "nar" in engine.hyper_config.capabilities:
-					resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature, max_levels=cfg.evaluation.nar_levels )
+					resps_list = engine( **kwargs, resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature, max_levels=cfg.evaluation.nar_levels )
 
 			process( name, batch, resps_list )
 
-	stats = {k: sum(v) / len(v) for k, v in stats.items()}
 	engines_stats = {
 		f'{name}.{eval_name}': stats,