From 67a6009555d9f9cf56f102f7dc1f1dfef1a802f8 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sun, 23 Feb 2025 08:31:03 -0600
Subject: [PATCH] (finally) added parallel AR for cfg.model.version >= 7
 (nvidia/audio-codec-44khz is being a pain and it might require training
 purely AR first......)

---
 vall_e/data.py          |  27 +++---
 vall_e/models/ar_nar.py | 176 +++++++++++++++++++++++++++++++++++++++-
 2 files changed, 189 insertions(+), 14 deletions(-)
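The data.py hunks below invert the old validation order: the strict existence check (an HDF5 key lookup or a filesystem stat, which can be slow across a large metadata set) used to run for every entry before the duration filter; it now only runs for entries whose duration is already in bounds. A minimal sketch of the resulting short-circuit, with hypothetical names (validate, exists) standing in for _validate/_exists:

	def validate(entry, exists, min_duration, max_duration):
		# cheap bounds check first; the slow lookup only runs when the entry could still be used
		in_bounds = min_duration <= entry["duration"] <= max_duration
		return in_bounds and exists(entry)

	# e.g. validate({"duration": 4.2}, exists=lambda e: True, min_duration=1.0, max_duration=30.0)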
sampling_kwargs.get("entropix_sampling", False) + refine_on_stop = sampling_kwargs.get("refine_on_stop", False) + input_prompt_prefix = sampling_kwargs.get("input_prompt_prefix", False) + layer_skip = sampling_kwargs.get("layer_skip", False) + prefix_silence = sampling_kwargs.get("prefix_silence", 0.0) + mirostat_tau = sampling_kwargs.get("mirostat_tau", 0.0) + mirostat_eta = sampling_kwargs.get("mirostat_eta", 0.0) + + start_slice = [ 0 for _ in range(batch_size) ] + sequence_list = [ torch.zeros((0, 8), device=device).to(torch.int16) for _ in range(batch_size) ] + stopped = torch.zeros(batch_size, device=device).bool() + + audio_stop_token = self.stop_token + text_stop_token = 2 + + state = None + mirostat = [ + {"n": 1024, "tau": mirostat_tau, "eta": mirostat_eta, "max_surprise": mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0} + ] * batch_size if mirostat_tau > 0.0 else None + + scores = [ 1.0 ] * beam_width + metrics = [] + + null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ] + null_prom = [ None for _ in range(batch_size) ] + + # get next in sequence + iterator = trange(max_duration // max(1, self.causal_size), desc="AR", disable=disable_tqdm) + for n in iterator: + if raw_text_list is not None: + raw_text_list = [ sequence_list[i] if task in text_task else raw_text_list[i] for i, task in enumerate(task_list) ] + else: + text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ] + resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ] + + quant_levels = [ 0 for _ in range( max( batch_size, beam_width ) ) ] + + inputs = self.inputs( + task_list=task_list, + + text_list=text_list, + proms_list=proms_list, + resps_list=resps_list, + + lang_list=lang_list, + tone_list=tone_list, + len_list=len_list, + raw_text_list=raw_text_list, + + quant_levels=quant_levels, + ) + + # to-do: find an elegant way to write this + output = super().forward( + inputs=inputs, + state=state, + #layer_skip_variables=sampling_layer_skip_variables, + output_attentions=entropix_sampling, + ) + + if cfg_strength > 0: + null_inputs = super().inputs( + text_list=null_text, + proms_list=null_prom, + resps_list=resps_list, + lang_list=lang_list, + tone_list=tone_list, + quant_levels=quant_levels, + ) + null_output = super().forward( + inputs=null_inputs, + quant_levels=quant_levels, + #layer_skip_variables=sampling_layer_skip_variables, + ) + logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=[ resp.shape[0] + 1 for resp in resps_list ] ) + + logits, state = output.logits, output.state + + l_resps_list = [ [] for _ in range(batch_size) ] + for l in range(self.n_resp_levels): + sampled = super().sample( + logits=[ logit[l] for logit in logits ], + prev_list=[ resp[..., l] for resp in resps_list ], + **(sampling_kwargs | {"attentions": output.attentions if entropix_sampling else None}), + ) + + ids = sampled.ids + + # append tokens + for i, token in enumerate(ids): + if audio_stop_token in token: + stopped[i] = True + l_resps_list[i].append(token.to(device)) + + for i, l in enumerate(l_resps_list): + sequence_list[i] = torch.cat([sequence_list[i], torch.stack(l, dim=-1)]) + + # stop token found + # stopped |= r == stop_token + if stopped.all().item(): + iterator.close() + break + + for i, l in enumerate( sequence_list ): + index = (l == audio_stop_token).nonzero()[:, 0].min() + 
@@ -1169,6 +1321,25 @@ class AR_NAR(Base):
 				**sampling_kwargs,
 			)
 
+		if self.version >= 7:
+			if task_list is None or task_list[0] != "len":
+				return self.forward_ar_parallel(
+					task_list=task_list,
+
+					text_list=text_list,
+					proms_list=proms_list,
+					resps_list=resps_list,
+
+					lang_list=lang_list,
+					tone_list=tone_list,
+					len_list=len_list,
+					raw_text_list=raw_text_list,
+
+					disable_tqdm=disable_tqdm,
+					use_lora=use_lora,
+					**sampling_kwargs,
+				)
+
 		# is AR
 		return self.forward_ar(
 			task_list=task_list,
@@ -1407,7 +1578,8 @@ def example_usage():
 			resps_list = engine( text_list=text_list, proms_list=proms_list, len_list=len_list )
 		else:
 			resps_list = engine( text_list=text_list, proms_list=proms_list, task_list=["tts"], max_duration=steps, temperature=1.0 )
-			resps_list = engine( text_list=text_list, proms_list=proms_list, resps_list=resps_list, temperature=0.0 )
+			if resps_list[0].dim() == 1 or resps_list[0].shape[-1] == 1:
+				resps_list = engine( text_list=text_list, proms_list=proms_list, resps_list=resps_list, temperature=0.0 )
 
 		for i, o in enumerate(resps_list):
 			print( o.shape, o )
@@ -1444,7 +1616,7 @@ def example_usage():
 	"""
 
 	for task in available_tasks:
-		sample("final", task="tts-nar")
+		sample("final", task=task)
 
 	engines.quit()
 
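With the dispatch hunk above, any model with version >= 7 routes every task except "len" through forward_ar_parallel, and example_usage only chains the temperature-0 NAR refinement pass when the AR output is still single-level, since the parallel AR already emits all codebook levels at once. A minimal sketch of that guard, with needs_nar_refinement as a hypothetical helper name:

	import torch

	def needs_nar_refinement(resp: torch.Tensor) -> bool:
		# a 1-D sequence or a single-codebook column came from the sequential AR;
		# the parallel AR already returns shape (timesteps, n_levels)
		return resp.dim() == 1 or resp.shape[-1] == 1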