diff --git a/docs/models.md b/docs/models.md
index 4869b74..79ad179 100644
--- a/docs/models.md
+++ b/docs/models.md
@@ -233,34 +233,13 @@ This script aims to implement everything as required per VALL-E agnostically, to
 ## `models/ar_nar.py`
 
-This script implements VALL-E as a unified autoregressive and non-autoregressive model, where RVQ-level 0 is inferenced autoregressively, the remaining levels are infereneced non-autoregressively.
-
-By default, this is the default model, but is used through `cfg.model.capabilities = ["ar", "nar"]`.
+This script implements VALL-E as a unified autoregressive and non-autoregressive model, where RVQ-level 0 is inferenced autoregressively, and the remaining levels are inferenced non-autoregressively, if requested.
+* Since one model can be trained both AR-ly and NAR-ly, RVQ-level 0 can also be trained non-autoregressively with diffusion-like masking.
 
 For training, this model handles preparing the batch provided through the dataloader according to a randomly sampled targeted RVQ level.
 
 For inferencing, this will dynamically inference depending on the arguments provided.
 
-## `models/ar.py`
-
-This script implements VALL-E as a pure autoregressive (AR) model.
-
-If `cfg.model.experimental.interleave=True`, this makes use of interleaving its audio codes, instead of inferencing per-codebook level. If not, this simply attends to RVQ level 0.
-
-This model serves as an experiment that failed, and might be revisited in the future.
-
-Use of this is governed through `cfg.model.capabilities = ["ar"]`
-
-## `models/nar.py`
-
-This script implements VALL-E as a mostly-pure non-autoregresive model, where it infers the duration autoregressively (if `"len" in cfg.model.capabilities`). If not, this simply attends to RVQ levels 1+.
-
-This makes use of training an additional `len` task that can infer the duration of a requested input, as well as (maybe) using special tokens as the initial input for RVQ-level 0 (the level the AR attends to).
-
-This model serves as an experiment that failed, and might be revisited in the future.
-
-Use of this is governed through `cfg.model.capabilities = ["nar"]`
-
 ## `models/experimental.py`
 
 This script implements VALL-E as a mostly-HuggingFace compatible model, where it handles processing tokens as a uniform sequence of IDs.
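To make the `models/ar_nar.py` training scheduler concrete, below is a minimal sketch of how a batch could be assigned target RVQ levels and mask-training timesteps. It assumes the `masking_train_p` / `masking_train_rvq_levels` settings introduced in `ModelExperimentalSettings` and the "equal" RVQ-level distribution; `sample_training_levels` itself is a hypothetical helper, not a function in the codebase.

```python
import random

def sample_training_levels(
    batch_size: int,
    n_resp_levels: int = 8,
    masking_train_p: float = 0.5,         # odds of mask-training an eligible level
    masking_train_rvq_levels=(0, 0),      # levels eligible for diffusion-like masking
):
    # pick a target RVQ level per batch item ("equal" distribution assumed)
    quant_levels = [ random.randint(0, n_resp_levels - 1) for _ in range(batch_size) ]
    # a timestep is only assigned when the level is trained NAR-ly with masking
    timesteps = [ None ] * batch_size
    lo, hi = masking_train_rvq_levels
    for i, level in enumerate(quant_levels):
        if lo <= level <= hi and random.random() < masking_train_p:
            timesteps[i] = random.random()  # noise level for the masking schedule
    return quant_levels, timesteps
```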
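Inference for a NAR-len model follows a MaskGIT-style "demasking" loop: start from a fully masked sequence, and at each step re-mask the lowest-confidence tokens according to a cosine schedule before resampling them. A condensed sketch of that loop, assuming greedy confidence scoring in place of the full sampler stack (`model` is a stand-in for the conditioned forward pass, returning `(seq_len, vocab)` logits):

```python
import math
import torch

def demask_sampling(model, seq_len: int, mask_token: int, max_steps: int = 10):
    # start from a fully masked sequence
    input_ids = torch.full((seq_len,), mask_token, dtype=torch.int16)
    # per-token scores, conjugated so that topk() selects the worst tokens
    scores = torch.zeros(seq_len)

    for timestep in torch.linspace(0.0, 1.0, max_steps):
        # cosine schedule: fraction of the sequence still treated as noise
        noise_p = math.cos(timestep.item() * math.pi * 0.5)
        masked_n = max(int(noise_p * seq_len), 1)
        # re-mask the lowest-confidence tokens
        masked_indices = scores.topk(masked_n, dim=-1).indices
        input_ids = input_ids.scatter(0, masked_indices, mask_token)
        is_masked = input_ids == mask_token

        logits = model(input_ids, timestep)          # (seq_len, vocab)
        sampled = logits.argmax(dim=-1).to(input_ids.dtype)
        confidence = logits.softmax(dim=-1).max(dim=-1).values

        # only masked positions get replaced; kept tokens stay put
        input_ids = torch.where(is_masked, sampled, input_ids)
        scores = 1.0 - confidence
    return input_ids
```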
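The demasking loop is also where classifier-free guidance gets applied: the model is run a second time with a "null" text and prompt, and the conditioned logits are pushed away from the unconditioned ones. A sketch of the combination used in `forward_nar`:

```python
import torch

def apply_cfg(cond_logits: torch.Tensor, null_logits: torch.Tensor, cfg_strength: float = 1.0) -> torch.Tensor:
    # cond + (cond - null) * strength, per the demasking loop in forward_nar
    return cond_logits + (cond_logits - null_logits) * cfg_strength
```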
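Duration prediction itself stays autoregressive: the `len` task emits the target length as decimal digit tokens (0-9) terminated by a stop token (10), which are then concatenated into an integer. A small sketch of that decoding convention (`decode_len` is a hypothetical helper; the `or "0"` fallback is an added guard for an empty sequence):

```python
import torch

def decode_len(sequences: list[torch.Tensor], stop_token: int = 10) -> list[int]:
    # e.g. tensor([2, 4, 0, 10]) decodes to a duration of 240 frames
    return [
        int("".join(str(t.item()) for t in seq if t != stop_token) or "0")
        for seq in sequences
    ]
```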
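The config changes that follow also replace several silent fallbacks with explicit errors, gated behind a new `silent_errors` flag. The path-unglobbing logic, for instance, behaves like this sketch (simplified from the diff; `unglob` as a standalone helper is hypothetical):

```python
from pathlib import Path

def unglob(path: Path, silent_errors: bool = False) -> list[Path]:
    # expand "some/dir/*.yaml"-style paths against the parent directory
    if path.parent.exists():
        return [ path.parent / child.name for child in path.parent.glob(path.name) ]
    # swallow the failure only when explicitly asked to
    if silent_errors:
        return []
    raise Exception(f"Cannot unglob requested path: {path}")
```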
diff --git a/vall_e/config.py b/vall_e/config.py index f99f093..513af97 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -255,13 +255,13 @@ class ModelExperimentalSettings: # it just seems like a bitch to try and train something worthwhile with it, since there's crackles every other token # RetNet's chunked inferencing might be a better place for this - len_train_p: float = 0.05 # odds of injecting a "len" task within the model for NAR-len - # to-to: just incorporate this as a task instead + masking_train_p: float = 0.0 # odds of training with masking + masking_train_rvq_levels: list = field(default_factory=lambda: [0,0]) # determines which levels to do mask training on # classifier-free guidance shit - cfg_cond_dropout_p: float = 0.2 # probability to drop out text and audio during training - cfg_text_dropout_p: float = 0.0 # probability to drop out input audio prompt during training - cfg_prom_dropout_p: float = 0.3 # probability to drop out input audio prompt during training + cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training + cfg_text_dropout_p: float = 0.0 # 0.0 # probability to drop out input audio prompt during training + cfg_prom_dropout_p: float = 0.0 # 0.3 # probability to drop out input audio prompt during training layerskip: bool = False # layerskip compatible model (or training for) #layerskip_rvq_levels: list = field(default_factory=lambda: []) # RVQ levels to train / inference layerskip for (to-do: implement, see if it matters) @@ -757,6 +757,7 @@ class Config(BaseConfig): device: str = "cuda" # target device mode: str = "training" # "inferencing" experimental: bool = False # debug flag + silent_errors: bool = False # if False, raise exceptions on errors that could silently lead to problems, if True ignore them dataset: Dataset = field(default_factory=lambda: Dataset) models: dict | list | None = field(default_factory=lambda: []) @@ -879,7 +880,12 @@ class Config(BaseConfig): if data_parent.exists(): return [ path.parent / child.name for child in Path(data_parent).glob(path.name) ] - return path + # return an empty list + if self.silent_errors: + return [] + + # raise an error to avoid headaches + raise Exception(f'Cannot unglob requested path: {path}') def format( self, training=True ): @@ -957,10 +963,6 @@ class Config(BaseConfig): model["experimental"]["rvq_levels_p"] = model["experimental"]["p_rvq_levels"] del model["experimental"]["p_rvq_levels"] - if "p_len_train" in model["experimental"]: - model["experimental"]["len_train_p"] = model["experimental"]["p_len_train"] - del model["experimental"]["p_len_train"] - self.models = [ Model(**model) if isinstance(model, dict) else model for model in self.models ] self.loras = [ LoRA(**lora) if isinstance(lora, dict) else lora for lora in self.loras ] @@ -999,22 +1001,17 @@ class Config(BaseConfig): if self.tokenizer == "naive": self.tokenizer = NaiveTokenizer() else: - # ick... 
- try: - from transformers import PreTrainedTokenizerFast + from transformers import PreTrainedTokenizerFast - tokenizer_path = self.rel_path / self.tokenizer_path - if tokenizer_path and not tokenizer_path.exists(): - tokenizer_path = Path("./data/") / self.tokenizer_path - - if tokenizer_path and tokenizer_path.exists(): - self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path)) - else: - self.tokenizer = NaiveTokenizer() - except Exception as e: - self.tokenizer = NaiveTokenizer() - _logger.warning(f"Error while parsing tokenizer: {str(e)}") - pass + tokenizer_path = self.rel_path / self.tokenizer_path + # deduce path if a local copy is not provided + if not tokenizer_path.exists(): + tokenizer_path = Path("./data/") / self.tokenizer_path + + if not self.silent_errors and not tokenizer_path.exists(): + raise Exception(f'Tokenizer path not found: {tokenizer_path}') + + self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path)) # Preserves the old behavior @@ -1071,8 +1068,9 @@ cfg = Config.from_cli() try: cfg.format() except Exception as e: + if not cfg.silent_errors: + raise e # throw an error because I'm tired of silent errors messing things up for me _logger.error(f"Error while parsing config YAML: {str(e)}") - raise e # throw an error because I'm tired of silent errors messing things up for me if __name__ == "__main__": print(cfg) diff --git a/vall_e/data.py b/vall_e/data.py index 81ccd8e..094b366 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -1199,6 +1199,10 @@ class Dataset(_Dataset): task ] + # Duration prediction ( => len()) + elif task == "len": + proms = self.sample_prompts(spkr_name, reference=path) + # noise suppression (? => ) # speech removal (? => ) elif task == "ns" or task == "sr": diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index 737a2a8..0a041e4 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -193,7 +193,7 @@ def load_engines(training=True, **model_kwargs): ("text_emb.weight", model.config.text_tokens ), ("tasks_emb.weight", model.config.tasks ), ("langs_emb.weight", model.config.langs ), - ("rvq_l_emb.weight", model.config.resp_levels + (1 if "len" in model.config.capabilities else 0) ), + ("rvq_l_emb.weight", model.config.resp_levels ), ("resps_emb.embeddings.0.weight", model.config.audio_tokens + uses_stop_token ), ("model.embed_tokens.weight", model.config.audio_tokens + uses_stop_token ), ("classifiers.proj.0.weight" if model.config.experimental.split_classifiers else 'classifier.weight', model.config.audio_tokens + uses_stop_token ), diff --git a/vall_e/inference.py b/vall_e/inference.py index 8effdb0..8a39d8a 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -49,11 +49,8 @@ class TTS(): else: raise Exception(f"Unknown config passed: {config}") - try: - cfg.format( training=False ) - cfg.dataset.use_hdf5 = False # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing - except Exception as e: - raise e # throw an error because I'm tired of silent errors messing things up for me + cfg.format( training=False ) + cfg.dataset.use_hdf5 = False # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing if amp is None: amp = cfg.inference.amp @@ -268,7 +265,7 @@ class TTS(): with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp): if model_ar is not None: text_list = model_ar( - text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], max_steps=max_ar_steps, + text_list=None, 
proms_list=[resp], lang_list=[lang], resps_list=[resp], max_steps=max_ar_steps, task_list=["stt"], sampling_temperature=ar_temp, sampling_min_temperature=min_ar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_min_p=min_p, @@ -318,7 +315,7 @@ class TTS(): with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp): if model_ar is not None: resps_list = model_ar( - text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, + text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, task_list=["tts"], input_prompt_prefix=input_prompt_prefix, prefix_silence=prefix_silence, sampling_temperature=ar_temp, @@ -343,7 +340,7 @@ class TTS(): use_lora=use_lora, ) resps_list = model_nar( - text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list, + text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list, task_list=["tts"], input_prompt_prefix=input_prompt_prefix, max_levels=max_nar_levels, sampling_temperature=nar_temp, @@ -359,8 +356,8 @@ class TTS(): use_lora=use_lora, ) elif model_len is not None: - len_list = model_len( text_list=[phns], proms_list=[prom], max_steps=5, disable_tqdm=not tqdm ) # don't need more than that - len_list = [ clamp(1, max_ar_steps, l) for l in len_list ] + len_list = model_len( text_list=[phns], proms_list=[prom], task_list=["len"], max_steps=5, disable_tqdm=not tqdm ) # don't need more than that + len_list = [ clamp(l, 1, max_ar_steps) for l in len_list ] kwargs = {} @@ -375,7 +372,7 @@ class TTS(): kwargs["resps_list"] = [ resp[:, :1] ] - resps_list = model_nar( text_list=[phns], proms_list=[prom], len_list=len_list, + resps_list = model_nar( text_list=[phns], proms_list=[prom], len_list=len_list, task_list=["tts"], max_steps=max_ar_steps, max_levels=max_nar_levels, sampling_temperature=nar_temp, diff --git a/vall_e/models/__init__.py b/vall_e/models/__init__.py index 6c5c63d..9c7af7d 100755 --- a/vall_e/models/__init__.py +++ b/vall_e/models/__init__.py @@ -60,25 +60,7 @@ def download_model( save_path=DEFAULT_MODEL_PATH, chunkSize = 1024 ): def get_model(config, training=True, **model_kwargs): name = config.name - if "len" in config.capabilities: - from .nar import NAR - model = NAR( - n_text_tokens=config.text_tokens, - n_audio_tokens=config.audio_tokens, - d_model=config.dim, - n_heads=config.heads, - n_layers=config.layers, - n_experts=config.experts, - - p_dropout=config.dropout, - - l_padding = config.input_alignment, - - training = training, - config = config, - **model_kwargs - ) - elif config.experimental.hf: + if config.experimental.hf: from .experimental import Model as Experimental model = Experimental( n_text_tokens=config.text_tokens, diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py deleted file mode 100644 index 3e7be5f..0000000 --- a/vall_e/models/ar.py +++ /dev/null @@ -1,638 +0,0 @@ -""" -# an AR model that (should) handle: -* handling all RVQ levels, but does it in an autoregressive manner - -It's in a mess of a state, because I want this to be an interleaved model, but it just seems better to use the vall_e.models.experimental model. 
-""" -from .base import Base, list_to_tensor, Categorical -from ..config import cfg - -import torch -from torch.nn.utils.rnn import pad_sequence - -import random -import math -from einops import rearrange -from torch import Tensor -from tqdm import trange -import logging - -_logger = logging.getLogger(__name__) - -from ..utils import clamp -from ..emb.qnt import trim, encode_as_embedding -from .lora import enable_lora - -class AR(Base): - def forward( - self, - text_list: list[Tensor], - proms_list: list[Tensor], - resps_list: list[Tensor] | None = None, - - task_list: list[Tensor] | None = None, - lang_list: list[Tensor] | None = None, - tone_list: list[Tensor] | None = None, - len_list: list[Tensor] | None = None, - - training: bool | int | None = None, - - max_steps: int = 1000, - max_levels: int = 0, - - input_prompt_prefix: bool = False, - prefix_silence: float = 1.0, - - sampling_temperature: float = 1.0, - sampling_min_temperature: float = -1.0, - sampling_top_k: int = -100, - sampling_top_p: float = 1.0, - sampling_min_p: float = 0.0, - sampling_repetition_penalty: float = 1.0, - sampling_repetition_penalty_decay: float = 0.0, - sampling_length_penalty: float = 0.0, - sampling_beam_width: int = 0, - sampling_mirostat_tau: float = 0.0, - sampling_mirostat_eta: float = 0.1, - sampling_dry_multiplier=0.0, - sampling_dry_base=1.75, - sampling_dry_allowed_length=2, - sampling_entropix=False, - - sampling_layer_skip: bool = False, - sampling_layer_skip_exit_layer: int = -1, - sampling_layer_skip_entropy_threshold: float = -1, - sampling_layer_skip_varentropy_threshold: float = -1, - - sampling_refine_on_stop: bool = False, - - disable_tqdm=False, - use_lora=None, - ): - text_task = [ "stt" ] - - if text_list is not None: - default_task = "tts" - device = text_list[0].device - batch_size = len(text_list) - else: - default_task = "stt" - device = resps_list[0].device - batch_size = len(resps_list) - - # generate task list if not provided - if task_list is None: - task_list = [ default_task for _ in range(batch_size) ] - - has_none = resps_list is None or text_list is None - if not has_none: - for i, task in enumerate( task_list ): - if resps_list[i] is None or text_list[i] is None: - has_none = True - break - - # is training or NAR - if not has_none: - n_levels_set = {r.shape[-1] for r in resps_list} - n_levels = next(iter(n_levels_set)) - - # implicit - if training is None: - training = 0 if n_levels == self.n_resp_levels else None - - # is training - if training is not None: - # specifies how to sample probabilities of which RVQ levels to train against - rvq_levels_p = self.config.experimental.rvq_levels_p if self.config is not None else "equal" - # determines which RVQ level to target per batch - quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels - 1 ] - # rate to perform token dropout errors - token_dropout_error = self.config.experimental.token_dropout_error - # RVQ levels to apply token dropout on - token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels - # implicitly set it to all levels - if not token_dropout_rvq_levels: - token_dropout_rvq_levels = [0, self.resp_levels - 1] - # allow passing a specific distribution of RVQ levels - rvq_levels_p = rvq_levels_p if isinstance(rvq_levels_p, list) else [] - if not rvq_levels_p: - lo, hi = quant_level_range[0], quant_level_range[1] + 1 - # randomly select a target RVQ-bin level (0 
being AR, 1+ being NAR) - if rvq_levels_p == "equal": - rvq_levels_p = [ i for i in range( lo, hi ) ] - else: - # yuck - rvq_levels_p = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], []) - - # input RVQ levels - quant_levels = [ random.choice( rvq_levels_p ) for i in range(batch_size) ] - for i, task in enumerate( task_list ): - if task in text_task: - quant_levels[i] = 0 # self.n_resp_levels - 1 - - # trim resps to only contain all levels below the target level - resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)] - - # tensor to cat for RVQ level 0 - text_stop_sequence = torch.tensor([[2] * 1], device=device, dtype=torch.int16) - audio_stop_sequence = torch.tensor([[self.stop_token] * 1], device=device, dtype=torch.int16) - # I hate python's value/reference semantics so much - for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list): - # cap quant_level if it exceeds its corresponding resp/prom - if quant_level >= resps.shape[-1]: - quant_levels[i] = resps.shape[-1] - 1 - - # proms could be a Tensor, list[Tensor], or None - if isinstance( proms, torch.Tensor ): - if quant_level >= proms.shape[-1]: - quant_levels[i] = proms.shape[-1] - 1 - - elif isinstance( proms, list ): - for j, prom in enumerate( proms ): - if not isinstance( prom, torch.Tensor ): - continue - if quant_level >= prom.shape[-1]: - quant_levels[i] = prom.shape[-1] - 1 - - # apply token dropout error compensation - if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]): - steps = resps.shape[0] - for l in range( quant_level ): - for t in range( steps ): - token = resps[t, l].item() - - if random.random() < token_dropout_error: - offset = 1 * ( 1 if random.random() < 0.5 else -1 ) - resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1 - - # only apply stop token for RVQ level 0 - if quant_level <= 0: - # append stop tokens for AR - if task in text_task: - #text_list[i] = torch.cat([ resps, text_stop_sequence ]) - ... 
- else: - resps_list[i] = torch.cat([ resps, audio_stop_sequence ]) - - inputs = self.inputs( - text_list=text_list, - proms_list=proms_list, - resps_list=resps_list, - lang_list=lang_list, - tone_list=tone_list, - task_list=task_list, - - quant_levels=quant_levels, - ) - - return super().forward( - inputs=inputs, - quant_levels=quant_levels, # could technically just grab this from the above inputs since they're included as an RVQ level token - ) - - # is AR - if cfg.lora is not None: - enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora ) - - # STT - start_slice = [ 0 for _ in range(batch_size) ] - sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ] - stopped = torch.zeros(batch_size, device=device).bool() - - audio_stop_token = self.stop_token - text_stop_token = 2 - - state = None - mirostat = [ - {"n": 1024, "tau": sampling_mirostat_tau, "eta": sampling_mirostat_eta, "max_surprise": sampling_mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0} - ] * batch_size if sampling_mirostat_tau > 0.0 else None - - scores = [ 1.0 ] * sampling_beam_width - metrics = [] - - # ick - """ - low_temperature = False # sampling_temperature < 0.6 # sampling_repetition_penalty == 1.0 and sampling_temperature == 0.0 # - low_temperature_range = cfg.dataset.frames_per_second * 5 - - original_sampling_temperature = sampling_temperature - original_sampling_repetition_penalty = sampling_repetition_penalty - original_sampling_repetition_penalty_decay = sampling_repetition_penalty_decay - """ - - sampling_layer_skip_variables = {} if sampling_layer_skip else None - - if sampling_layer_skip: - if sampling_layer_skip_entropy_threshold >= 0: - sampling_layer_skip_variables["entropy_threshold"] = sampling_layer_skip_entropy_threshold - if sampling_layer_skip_varentropy_threshold >= 0: - sampling_layer_skip_variables["varentropy_threshold"] = sampling_layer_skip_varentropy_threshold - if sampling_layer_skip_exit_layer >= 0: - sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer - - for i, sequence in enumerate( sequence_list ): - # add to text for STT - if task_list[i] in text_task: - start_slice[i] = 1 - sequence_list[i] = torch.cat([sequence_list[i], torch.tensor([1], dtype=torch.int16, device=device)]) - # treat input prompt as initial resp (by prefixing with the prompt instead) - elif input_prompt_prefix: - start_slice[i] = proms_list[i].shape[0] - sequence_list[i], proms_list[i] = proms_list[i][:, 0], sequence_list[i] - elif prefix_silence > 0: - sequence_list[i] = get_silence(prefix_silence, device=sequence_list[i].device) - sequence_list[i] = sequence_list[i][:, 0] - # start_slice[i] = sequence_list[i].shape[0] - - # get next in sequence - for n in trange(max_steps // max(1, self.causal_size), desc="AR", disable=disable_tqdm): - # it would technically be faster to just append the new token's embedding to the inputs, but there's a VERY small performance gain from doing it, so it's not worth it - text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ] - resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ] - - # greedy sampling in the AR *does* work, but requires some quasi-exotic sampling to work around the initial burst of garbage from polluting the rest of the sequence - # naturally, rep pen wrangles this initial burst of noise, but naively relying on rep_pen is no good, as it fails after ~6 
seconds of audio - # however, switching to a default sampling temperature with "clean greedy sampled codes" will make the rest of sequence sound as if it were greedy sampled - # to-do: tune these values, maybe have it factor based on confidence scores or something - """ - if low_temperature: - enabled = n < low_temperature_range - sampling_repetition_penalty = 1.125 if enabled else 1.25 - #sampling_repetition_penalty_decay = 0.0 if enabled else original_sampling_repetition_penalty_decay - #sampling_temperature = original_sampling_temperature if enabled else 1.0 - """ - - inputs = self.inputs( - text_list=text_list, - proms_list=proms_list, - resps_list=resps_list, - lang_list=lang_list, - tone_list=tone_list, - len_list=len_list, - task_list=task_list, - quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ] - ) - - # to-do: find an elegant way to write this - output = super().forward( - inputs=inputs, - state=state, - - layer_skip_variables=sampling_layer_skip_variables, - - output_attentions=sampling_entropix, - ) - logits, state = output.logits, output.state - - sampled = super().sample( - logits=logits, - prev_list=None if sampling_repetition_penalty == 1.0 and sampling_length_penalty == 0.0 else [ resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate( task_list ) ], - - temperature=sampling_temperature, - min_temperature=sampling_min_temperature, - top_p=sampling_top_p, - top_k=sampling_top_k, - min_p=sampling_min_p, - repetition_penalty=sampling_repetition_penalty, - repetition_penalty_decay=sampling_repetition_penalty_decay, - length_penalty=sampling_length_penalty, - beam_width=sampling_beam_width, - - mirostat=mirostat, - - dry_multiplier=sampling_dry_multiplier, - dry_base=sampling_dry_base, - dry_allowed_length=sampling_dry_allowed_length, - - attentions=output.attentions if sampling_entropix else None, - ) - - r = sampled[0] - - if cfg.experimental: - if sampled.entropy: - metrics.append( sampled.entropy ) - elif sampled.scores: - metrics.append( [ { "p": p[0], "exited_layer": output.exited_layer } for p in sampled.scores ] ) - - if mirostat is not None: - mirostat = sampled.scores - elif sampling_beam_width > 0: - # expand tuple - s = sampled.scores - # first step, expand batch - if batch_size == 1: - batch_size = sampling_beam_width - text_list = text_list * sampling_beam_width - proms_list = proms_list * sampling_beam_width - sequence_list = sequence_list * sampling_beam_width - task_list = task_list * sampling_beam_width - start_slice = start_slice * sampling_beam_width - stopped = torch.zeros(batch_size, device=device).bool() - - scores = [ scores[i] + score for i, score in enumerate(s) ] - - # append tokens - for i, ri in enumerate(r): - task = task_list[i] - stop_token = audio_stop_token if task not in text_task else text_stop_token - if stop_token in ri: - stopped[i] = True - sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)]) - - # stop token found - # stopped |= r == stop_token - if stopped.all().item(): - break - - # to-do for layerskip / speculative sampling: rerun the last sequence again at max depth - - if metrics: - from ..plot import plot_sample_metrics - filename = "metrics" - if sampling_entropix: - filename += f'[entropix]' - if sampling_layer_skip_exit_layer >= 0: - filename += f'[{sampling_layer_skip_exit_layer+1}]' - - plot_sample_metrics( metrics, filename=f'{filename}.png' ) - - # pick the best scoring candidate - # desu this is always going to be candidate 0 - if sampling_beam_width: - 
sequence_list = sequence_list[:1] - task_list = task_list[:1] - - # remove stop token - sequence_list = [self._prune(r, audio_stop_token if task_list[i] not in text_task else text_stop_token) for i, r in enumerate(sequence_list)] - # remove - sequence_list = [ sequence_list[i][start_slice[i]:] for i, task in enumerate( task_list ) ] - - if sampling_refine_on_stop: - # get how much we need to slice from the end - slice_lengths = [ sequence.shape[-1] for sequence in sequence_list ] - # -1 for the stop token - logits = [ logit[-length-1:-1] for logit, length in zip(logits, slice_lengths) ] - # greedy sample from the sequence - refined_list = [ logit.argmax(dim=-1) for logit in logits ] - # to-do: compare scores - # set the "refined" list as the output - sequence_list = refined_list - - return sequence_list - - -def example_usage(): - cfg.trainer.backend = "local" - cfg.hyperparameters.gradient_accumulation_steps = 1 - if cfg.audio_backend == "dac": - cfg.sample_rate = 44_100 - - from functools import partial - from einops import repeat - from tqdm import tqdm - - from ..emb.qnt import decode_to_file, unload_model, trim_random, repeat_extend_audio, concat_audio, merge_audio - from ..engines import Engine, Engines - from ..utils import wrapper as ml - - import numpy as np - import re - - device = "cuda" - - # mamba seems to ONLY be used as an AR (any NAR attempts lobotomizes it) - """ - if "mamba" in cfg.model.arch_type: - cfg.model.resp_levels = 1 - """ - # cfg.model.loss_factors = {} - - def tokenize(content): - return torch.tensor( cfg.tokenizer.encode(content) ) - - def _load_quants(path) -> Tensor: - qnt = np.load(path, allow_pickle=True)[()] - return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.resp_levels, :].t().to(torch.int16) - - qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}") - noise = _load_quants(f"./data/noise.{'dac' if cfg.audio_backend == 'dac' else 'enc'}") - - text_list = [ - tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device), - #tokenize("ˈaɪ wɪl nˌɑːt ˈæsk").to(device), - ] - proms_list = [ - qnt[:cfg.dataset.frames_per_second, :].to(device), - #qnt[:cfg.dataset.frames_per_second, :].to(device), - ] - resps_list = [ - qnt[:, :].to(device), - #qnt[:cfg.dataset.frames_per_second, :].to(device), - ] - - text_list = text_list[:1] - proms_list = proms_list[:1] - resps_list = resps_list[:1] - - batch_size = len(text_list) - - # rentet-full is the only configuration with BitNet's BitLinear that converges despite the grad_norm saying otherwise - kwargs = { - 'n_text_tokens': 256, - 'n_audio_tokens': 1024, - - 'd_model': 1024, # 256, # 1024, # 1536 - 'n_heads': 16, # 4, # 16, # 24 - 'n_layers': 12, # 32 - 'n_experts': 1, - - 'p_dropout': 0.1, - - 'l_padding': 8 if cfg.optimizations.fp8 else 0, - - 'config': cfg.model - } - - """ - try: - kwargs['config'] = cfg.model - except Exception as e: - pass - """ - - bos_id, space_id, eos_id = cfg.tokenizer.encode( " " ) - tasks = cfg.dataset.tasks_list - - model = AR(**kwargs).to(device) - steps = 75 * len(tasks) * cfg.model.experimental.causal_size - - optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy" - scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else "" - learning_rate = cfg.hyperparameters.learning_rate if cfg.yaml_path is not None else None - - if cfg.optimizations.dadaptation: - # do not combine the two - if scheduler == "schedulefree": - scheduler = "" - - learning_rate = 1.0 - - if optimizer == 
"prodigy": - if learning_rate is None: - learning_rate = 1.0 - - optimizer = ml.Prodigy - elif optimizer == "adagrad": - if learning_rate is None: - learning_rate = 1.0e-2 - - optimizer = ml.Adagrad - elif optimizer == "adamw": - if learning_rate is None: - learning_rate = 1.0e-4 - - optimizer = ml.AdamW - elif optimizer == "sdg": - if learning_rate is None: - learning_rate = 1.0e-4 - - optimizer = ml.SGD - else: - raise ValueError(f"Unrecognized optimizer: {optimizer}") - - _logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}") - - optimizer = optimizer(model.parameters(), lr=learning_rate) - - if scheduler == "schedulefree": - if isinstance(optimizer, ml.AdamW): - scheduler = ml.schedulefree.AdamWScheduleFree - elif isinstance(optimizer, ml.SGD): - scheduler = ml.schedulefree.SGDScheduleFree - else: - scheduler = None - - if scheduler is not None: - _logger.info(f"Scheduler: {scheduler}") - optimizer = scheduler( model.parameters(), lr = learning_rate ) - - if cfg.optimizations.replace and cfg.optimizations.linear: - model = ml.replace_linear( model ) - - if cfg.optimizations.replace and cfg.optimizations.embedding: - model = ml.replace_embedding( model ) - - """ - cfg.optimizations.model_offloading = { - "devices": ["cuda:0", "cpu"], - # "limits": [ 0.9, -1 ], - "assign": [[ f'layers.{i}.' for i in range(0,10) ], [ f'layers.{i}.' for i in range(11,12) ] + [ "model.norm" ]], - # "limits": [ 256 * (1024 ** 2), -1 ] - } - """ - - engine = Engine(model=model, optimizer=optimizer) - engines = Engines({"ar": engine}) - engines.setup() - - """ - if cfg.optimizations.model_offloading: - model = ml.offload_model( model, policy=cfg.optimizations.model_offloading ) - """ - - """ - torch.save( { - 'module': model.state_dict() - }, f"./data/{cfg.model.arch_type}.pth" ) - """ - - _logger.info(f"AR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") - - @torch.no_grad() - def sample_data(task=None): - texts = [] - proms = [] - resps = [] - - for i in range(batch_size): - if task is None: - task = random.choice(tasks) - - text = text_list[i] - prom = proms_list[i] - resp = resps_list[i] - - # do nothing - if task == "tts": - ... 
- elif task == "tts-c": - trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second) - - prom = resp[:trim_length] - resp = resp[trim_length:] - elif task == "ns" or task == "sr": - # extend the noise to fill the target audio - noise_ext = repeat_extend_audio( noise, resp.shape[0] ) - # create the input prompt by merging the target audio with the noise - prom = merge_audio( resp.cpu(), noise_ext, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device ) - # set the target to just be the noise if - if task == "sr": - resp = noise_ext - - # set the text prompt to empty to train without a guided text prompt - if random.random() < 0.5: - text = torch.tensor([bos_id, eos_id], device=device, dtype=torch.uint8) - - texts.append( text.to(device) ) - proms.append( prom.to(device) ) - resps.append( resp.to(device) ) - - return texts, proms, resps - - @torch.inference_mode() - def sample( name, steps=1000, task=None ): - engine.eval() - - texts, proms, resps = sample_data( task ) - - resps = engine( texts, proms, max_steps=steps, sampling_temperature=0.95 ) - - for i, o in enumerate(resps): - _ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{task}.{name}.wav", device=device) - - unload_model() - - def train(): - engine.train() - t = trange(steps) - for i in t: - texts, proms, resps = sample_data() - - stats = {"step": i} - stats |= engine.traverse(text_list=texts, proms_list=proms, resps_list=resps) - stats |= {"grad_norm": engine.get_global_grad_norm()} - - tqdm.write(f"{stats}") - - """ - torch.save( { - 'module': model.state_dict() - }, f"./data/{cfg.model.arch_type}.pth" ) - """ - - #sample("init", 5) - train() - - """ - if cfg.optimizations.compile: - model = ml.compile_model(model, backend=cfg.optimizations.compile) - """ - - for task in tasks: - sample("final", task=task) - - engines.quit() - -if __name__ == "__main__": - example_usage() \ No newline at end of file diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 8d1cc2e..7282d79 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -28,8 +28,163 @@ from ..utils import get_devices, setup_logging, timer, clamp from .lora import enable_lora +text_task = [ "stt" ] + class AR_NAR(Base): - def forward( + def forward_train( + self, + text_list: list[Tensor], + proms_list: list[Tensor], + resps_list: list[Tensor], + + task_list: list[Tensor] | None = None, + lang_list: list[Tensor] | None = None, + tone_list: list[Tensor] | None = None, + len_list: list[Tensor] | None = None, + ): + # deduce batch_size + if text_list is not None: + default_task = "tts" + device = text_list[0].device + batch_size = len(text_list) + else: + default_task = "stt" + device = resps_list[0].device + batch_size = len(resps_list) + + # specifies how to sample probabilities of which RVQ levels to train against + rvq_levels_p = self.config.experimental.rvq_levels_p if self.config is not None else "equal" + # determines which RVQ level to target per batch + quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels - 1 ] + # rate to perform token dropout errors + token_dropout_error = self.config.experimental.token_dropout_error + # RVQ levels to apply token dropout on + token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels + # RVQ levels to apply masking 
training on
+        masking_train_rvq_levels = self.config.experimental.masking_train_rvq_levels
+
+        # CFG
+        cfg_text_dropout_p = self.config.experimental.cfg_text_dropout_p if self.config is not None else 0.0
+        cfg_cond_dropout_p = self.config.experimental.cfg_cond_dropout_p if self.config is not None else 0.0
+        cfg_prom_dropout_p = self.config.experimental.cfg_prom_dropout_p if self.config is not None else 0.0
+        # rate to train RVQ level AR-ly or NAR-ly
+        masking_train_p = self.config.experimental.masking_train_p if self.config is not None else 0.5
+
+        # force the mask-training probability if the model only supports one of the two modes
+        if "len" not in self.capabilities:
+            masking_train_p = 0.0
+        elif "ar" not in self.capabilities:
+            masking_train_p = 1.0
+
+        # implicitly set it to all levels
+        if not token_dropout_rvq_levels:
+            token_dropout_rvq_levels = [0, self.resp_levels - 1]
+        # implicitly set it to just the first level
+        if not masking_train_rvq_levels:
+            masking_train_rvq_levels = [0, 0]
+
+        # allow passing a specific distribution of RVQ levels
+        if not isinstance(rvq_levels_p, list) or not rvq_levels_p:
+            lo, hi = quant_level_range[0], quant_level_range[1] + 1
+            # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
+            if rvq_levels_p == "equal":
+                rvq_levels_p = [ i for i in range( lo, hi ) ]
+            else:
+                # yuck
+                rvq_levels_p = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])
+
+        # input RVQ levels
+        quant_levels = [ random.choice( rvq_levels_p ) for i in range(batch_size) ]
+        # timestep levels (for TTS NAR)
+        timesteps = [ None for _ in range(batch_size) ]
+
+        for i, task in enumerate( task_list ):
+            lo, hi = masking_train_rvq_levels[0], masking_train_rvq_levels[1]
+            if task in text_task:
+                quant_levels[i] = 0 # self.n_resp_levels - 1
+            elif lo <= quant_levels[i] and quant_levels[i] <= hi and random.random() < masking_train_p:
+                timesteps[i] = random.random()
+
+        # trim resps to only contain all levels below the target level
+        resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
+
+        # tensor to cat for RVQ level 0
+        text_stop_sequence = torch.tensor([2], device=device, dtype=torch.int16)
+        text_start_stop_sequence = torch.tensor([1, 2], device=device, dtype=torch.int16)
+        audio_stop_sequence = torch.tensor([[self.stop_token]], device=device, dtype=torch.int16)
+        # I hate python's value/reference semantics so much
+        for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list):
+            # cap quant_level if it exceeds its corresponding resp/prom
+            if quant_level >= resps.shape[-1]:
+                quant_levels[i] = resps.shape[-1] - 1
+
+            # proms could be a Tensor, list[Tensor], or None
+            if isinstance( proms, torch.Tensor ):
+                if quant_level >= proms.shape[-1]:
+                    quant_levels[i] = proms.shape[-1] - 1
+
+            elif isinstance( proms, list ):
+                for j, prom in enumerate( proms ):
+                    if not isinstance( prom, torch.Tensor ):
+                        continue
+                    if quant_level >= prom.shape[-1]:
+                        quant_levels[i] = prom.shape[-1] - 1
+
+            # apply token dropout error compensation
+            if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
+                steps = resps.shape[0]
+                for l in range( quant_level ):
+                    for t in range( steps ):
+                        token = resps[t, l].item()
+
+                        if random.random() < token_dropout_error:
+                            offset = 1 * ( 1 if random.random() < 0.5 else -1 )
+                            resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
+
+            # only apply stop token for RVQ level 0
+            if quant_level <= 0:
+                # append stop tokens
for AR + if task in text_task: + #text_list[i] = torch.cat([ resps, text_stop_sequence ]) + ... + else: + resps_list[i] = torch.cat([ resps, audio_stop_sequence ]) + + # apply CFG (should probably only apply to NAR quant level 0) + if task not in text_task + ["len"]: + drop_text = False + drop_audio = False + + if random.random() < cfg_prom_dropout_p: + drop_audio = True + + if random.random() < cfg_cond_dropout_p: + drop_audio = True + drop_text = True + + if drop_text: + text_list[i] = text_start_stop_sequence + + if drop_audio: + proms_list[i] = None + + inputs = self.inputs( + text_list=text_list, + proms_list=proms_list, + resps_list=resps_list, + lang_list=lang_list, + tone_list=tone_list, + task_list=task_list, + time_list=timesteps, + + quant_levels=quant_levels, + ) + + return super().forward( + inputs=inputs, + quant_levels=quant_levels, + ) + + def forward_nar( self, text_list: list[Tensor], proms_list: list[Tensor], @@ -47,6 +202,7 @@ class AR_NAR(Base): input_prompt_prefix: bool = False, prefix_silence: float = 1.0, + denoise_start: float = 0.0, sampling_temperature: float = 1.0, sampling_min_temperature: float = -1.0, @@ -74,8 +230,7 @@ class AR_NAR(Base): disable_tqdm=False, use_lora=None, ): - text_task = [ "stt" ] - + # deduce batch_size if text_list is not None: default_task = "tts" device = text_list[0].device @@ -85,99 +240,297 @@ class AR_NAR(Base): device = resps_list[0].device batch_size = len(resps_list) - # generate task list if not provided - if task_list is None: - task_list = [ default_task for _ in range(batch_size) ] + if max_levels == 0: + max_levels = self.n_max_levels - 1 - has_none = resps_list is None or text_list is None - if not has_none: - for i, task in enumerate( task_list ): - if resps_list[i] is None or text_list[i] is None: - has_none = True - break + sampling_layer_skip_variables = {} if sampling_layer_skip else None - # is training or NAR - if not has_none: - n_levels_set = {r.shape[-1] for r in resps_list} - n_levels = next(iter(n_levels_set)) + if sampling_layer_skip: + if sampling_layer_skip_entropy_threshold >= 0: + sampling_layer_skip_variables["entropy_threshold"] = sampling_layer_skip_entropy_threshold + if sampling_layer_skip_varentropy_threshold >= 0: + sampling_layer_skip_variables["varentropy_threshold"] = sampling_layer_skip_varentropy_threshold + if sampling_layer_skip_exit_layer >= 0: + sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer - # implicit - if training is None: - training = 0 if n_levels == self.n_resp_levels else None + # inference NAR level 0 + if len_list is not None: + mask_token = torch.tensor([self.stop_token], dtype=torch.int16, device=device) + prev_list = [ torch.concat([ mask_token for _ in range( resp_len ) ]) for resp_len in len_list ] - # is training - if training is not None: - # specifies how to sample probabilities of which RVQ levels to train against - rvq_levels_p = self.config.experimental.rvq_levels_p if self.config is not None else "equal" - # determines which RVQ level to target per batch - quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels - 1 ] - # rate to perform token dropout errors - token_dropout_error = self.config.experimental.token_dropout_error - # RVQ levels to apply token dropout on - token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels - # implicitly set it to all levels - if not token_dropout_rvq_levels: - 
token_dropout_rvq_levels = [0, self.resp_levels - 1] - # allow passing a specific distribution of RVQ levels - rvq_levels_p = rvq_levels_p if isinstance(rvq_levels_p, list) else [] - if not rvq_levels_p: - lo, hi = quant_level_range[0], quant_level_range[1] + 1 - # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR) - if rvq_levels_p == "equal": - rvq_levels_p = [ i for i in range( lo, hi ) ] - else: - # yuck - rvq_levels_p = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], []) + # special "scheduling" to inference RVQ-level 0 + level = 0 + if cfg.lora is not None: + enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora ) - # input RVQ levels - quant_levels = [ random.choice( rvq_levels_p ) for i in range(batch_size) ] - for i, task in enumerate( task_list ): - if task in text_task: - quant_levels[i] = 0 # self.n_resp_levels - 1 + def log(x, eps = 1e-20): + return torch.log(x.clamp(min = eps)) + + def gumbel_sample(x, temperature = 1., dim = -1): + return ((x / max(temperature, 1e-10)) + -log(-log(torch.zeros_like(x).uniform_(0, 1)))).argmax(dim = dim) + + _super = super() + def demask_sampling( batch_index, seq_len ): + # overrides + max_steps = 10 + temperature = 0.3 + cfg_strength = 1.0 + sampling_repetition_penalty = 1.0 # force rep pen off, because this caused false positives due to how rep pen was being naively applied...... + sampling_top_p = 0.9 # a lot of demasking samplers use a top-k of seq_len * 0.9 + + # if we're denoising from an existing sequence + if denoise_start > 0.0 and resps_list is not None: + start_noise = denoise_start + noise_p = math.cos( start_noise * math.pi * 0.5 ) + mask = torch.tensor( [ random.random() < noise_p for _ in range( seq_len ) ], dtype=torch.bool, device=device ) + input_ids = torch.where( mask, self.stop_token, resps_list[batch_index][:, 0] ) + else: + input_ids = torch.ones((seq_len,), dtype=torch.int16, device=device) * self.stop_token - # trim resps to only contain all levels below the target level - resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)] + scores = torch.zeros((seq_len,), dtype=torch.float32, device=device) - # tensor to cat for RVQ level 0 - text_stop_sequence = torch.tensor([[2] * 1], device=device, dtype=torch.int16) - audio_stop_sequence = torch.tensor([[self.stop_token] * 1], device=device, dtype=torch.int16) - # I hate python's value/reference semantics so much - for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list): - # cap quant_level if it exceeds its corresponding resp/prom - if quant_level >= resps.shape[-1]: - quant_levels[i] = resps.shape[-1] - 1 + quant_levels = [ level for _ in range(batch_size) ] + prev_list = [ input_ids ] - # proms could be a Tensor, list[Tensor], or None - if isinstance( proms, torch.Tensor ): - if quant_level >= proms.shape[-1]: - quant_levels[i] = proms.shape[-1] - 1 + start_temperature = temperature + start_noise = 0.0 + end_noise = 1.0 - elif isinstance( proms, list ): - for j, prom in enumerate( proms ): - if not isinstance( prom, torch.Tensor ): - continue - if quant_level >= prom.shape[-1]: - quant_levels[i] = prom.shape[-1] - 1 + null_text = torch.tensor([1, 2], device=device, dtype=torch.int16) + null_prom = None - # apply token dropout error compensation - if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]): - steps = resps.shape[0] - for l in 
range( quant_level ): - for t in range( steps ): - token = resps[t, l].item() + for timestep, steps_until_x0 in zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))): + # anneal temperature + temperature = start_temperature * (steps_until_x0 / max_steps) + # get noise level, per cosine scheduling + noise_p = math.cos( timestep * math.pi * 0.5 ) + # number of tokens to mask off to "noise" the input sequence + masked_tokens_n = max(int( noise_p * seq_len ), 1) + # pick the worst scoring tokens to mask off + masked_indices = scores.topk( masked_tokens_n, dim=-1 ).indices + # mask off inputs + input_ids = input_ids.scatter(0, masked_indices, self.stop_token) + # boolean mask + is_masked = input_ids == self.stop_token + # setup inputs - if random.random() < token_dropout_error: - offset = 1 * ( 1 if random.random() < 0.5 else -1 ) - resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1 + inputs = _super.inputs( + text_list=text_list, + proms_list=proms_list, + resps_list=[ input_ids ], + lang_list=lang_list, + tone_list=tone_list, + time_list=[ timestep ], + quant_levels=quant_levels, + ) + output = _super.forward( + inputs=inputs, + quant_levels=quant_levels, + #layer_skip_variables=sampling_layer_skip_variables, + ) - # only apply stop token for RVQ level 0 - if quant_level <= 0: - # append stop tokens for AR - if task in text_task: - #text_list[i] = torch.cat([ resps, text_stop_sequence ]) - ... - else: - resps_list[i] = torch.cat([ resps, audio_stop_sequence ]) + logits = output.logits + + if cfg_strength > 0: + null_inputs = _super.inputs( + text_list=[ null_text ], + proms_list=[ null_prom ], + resps_list=[ input_ids ], + lang_list=lang_list, + tone_list=tone_list, + time_list=[ timestep ], + quant_levels=quant_levels, + ) + null_output = _super.forward( + inputs=null_inputs, + quant_levels=quant_levels, + #layer_skip_variables=sampling_layer_skip_variables, + ) + for logit, null_logits in zip(output.logits, null_output.logits): + logit[-seq_len:] = logit[-seq_len:] + ( logit[-seq_len:] - null_logits[-seq_len:] ) * cfg_strength + + # sample with sampler settings + filtered_sampled = _super.sample( + logits=logits, + prev_list=prev_list, + quant_levels=quant_levels, + + temperature=temperature, + min_temperature=sampling_min_temperature, + top_p=sampling_top_p, + top_k=sampling_top_k, + min_p=sampling_min_p, + repetition_penalty=sampling_repetition_penalty, + repetition_penalty_decay=sampling_repetition_penalty_decay, + length_penalty=sampling_length_penalty, + ) + + # retrieves unfiltered logits + unfiltered_sampled = _super.sample( + logits=logits, + prev_list=prev_list, + quant_levels=quant_levels, + temperature=0.0, + ) + # update previous list of tokens + prev_list = [ input_ids ] + + # extract logits + filtered_logits = filtered_sampled.logits[0] + unfiltered_logits = unfiltered_sampled.logits[0] + + # extract scores + filtered_scores = filtered_sampled.scores[0] + unfiltered_scores = unfiltered_sampled.scores[0] + + # extract sampled tokens + filtered_tokens = filtered_sampled[0][0] + unfiltered_tokens = unfiltered_sampled[0][0] + + # sample with gumbelnoise + # I actually feel like this doesn't matter? 
it's hard to judge with a partially trained NAR-len model + sampled_ids = gumbel_sample( filtered_logits, temperature=temperature, dim=-1 ) + #sampled_ids = filtered_tokens + + # keep unmasked tokens + input_ids = torch.where( is_masked, sampled_ids, input_ids ) + # update scores (conjugated to put the worst scores at the top) + scores = 1.0 - torch.tensor([score for score in unfiltered_scores], device=device) + + if cfg.experimental: + print( timestep, steps_until_x0, noise_p, masked_tokens_n, input_ids, scores ) + + return input_ids + + # perform demasked sampling (mock diffusion) + resps_list = [ demask_sampling( batch_index=i, seq_len=l ) for i, l in enumerate( len_list ) ] + + # expand if given a raw 1D tensor + for i, resp in enumerate(resps_list): + if resp.dim() == 1: + resps_list[i] = resp.unsqueeze(-1) + + prev_list = resps_list + + for n in trange( max_levels, desc="NAR", disable=disable_tqdm ): + level = prev_list[0].shape[-1] + if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels + break + + if cfg.lora is not None: + enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora ) + + quant_levels = [ level for _ in range(batch_size) ] # torch.full((len(text_list),), level) + + inputs = self.inputs( + text_list=text_list, + proms_list=proms_list, + resps_list=prev_list, + lang_list=lang_list, + tone_list=tone_list, + quant_levels=quant_levels, + ) + + output = super().forward( + inputs=inputs, + quant_levels=quant_levels, + #layer_skip_variables=sampling_layer_skip_variables, + ) + logits, state = output.logits, output.state + + sampled = super().sample( + logits=logits, + prev_list=prev_list, + quant_levels=quant_levels, + + temperature=sampling_temperature, + #min_temperature=sampling_min_temperature, + #top_p=sampling_top_p, + #top_k=sampling_top_k, + #min_p=sampling_min_p, + #repetition_penalty=sampling_repetition_penalty, + #repetition_penalty_decay=sampling_repetition_penalty_decay, + #length_penalty=sampling_length_penalty, + #beam_width=sampling_beam_width, + #mirostat=mirostat, + ) + + resps_list = sampled[0] + prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device=device)], dim=-1) for rs, r in zip(prev_list, resps_list) ] + + return prev_list + + def forward_ar( + self, + + text_list: list[Tensor], + proms_list: list[Tensor], + resps_list: list[Tensor] | None = None, + + task_list: list[Tensor] | None = None, + lang_list: list[Tensor] | None = None, + tone_list: list[Tensor] | None = None, + len_list: list[Tensor] | None = None, + + training: bool | int | None = None, + + max_steps: int = 1000, + max_levels: int = 0, + + input_prompt_prefix: bool = False, + prefix_silence: float = 1.0, + denoise_start: float = 0.0, + + sampling_temperature: float = 1.0, + sampling_min_temperature: float = -1.0, + sampling_top_k: int = -100, + sampling_top_p: float = 1.0, + sampling_min_p: float = 0.0, + sampling_repetition_penalty: float = 1.0, + sampling_repetition_penalty_decay: float = 0.0, + sampling_length_penalty: float = 0.0, + sampling_beam_width: int = 0, + sampling_mirostat_tau: float = 0.0, + sampling_mirostat_eta: float = 0.1, + sampling_dry_multiplier=0.0, + sampling_dry_base=1.75, + sampling_dry_allowed_length=2, + sampling_entropix=False, + + sampling_layer_skip: bool = False, + sampling_layer_skip_exit_layer: int = -1, + sampling_layer_skip_entropy_threshold: float = -1, + sampling_layer_skip_varentropy_threshold: float = -1, + + sampling_refine_on_stop: bool = False, 
+ + disable_tqdm=False, + use_lora=None, + ): + # deduce batch_size + if text_list is not None: + default_task = "tts" + device = text_list[0].device + batch_size = len(text_list) + else: + default_task = "stt" + device = resps_list[0].device + batch_size = len(resps_list) + + if cfg.lora is not None: + enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora ) + + # inference len + if task_list is not None and task_list[0] == "len": + sequence_list = [ torch.tensor([0], device=device,dtype=torch.int16) for _ in range(batch_size) ] + stopped = torch.zeros(batch_size, device=device).bool() + + stop_token = 10 + task_list = [ "len" for _ in range(batch_size) ] + quant_levels = [ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ] + + for n in trange(10, desc="AR", disable=disable_tqdm): + len_list = sequence_list inputs = self.inputs( text_list=text_list, @@ -185,89 +538,36 @@ class AR_NAR(Base): resps_list=resps_list, lang_list=lang_list, tone_list=tone_list, + len_list=len_list, task_list=task_list, - - quant_levels=quant_levels, - ) - - return super().forward( - inputs=inputs, - quant_levels=quant_levels, # could technically just grab this from the above inputs since they're included as an RVQ level token - ) - - # is NAR - if max_levels == 0: - max_levels = self.n_max_levels - 1 - - # expand if given a raw 1D tensor - for i, resp in enumerate(resps_list): - if resp.dim() == 1: - resps_list[i] = resp.unsqueeze(-1) - - prev_list = resps_list - - sampling_layer_skip_variables = {} if sampling_layer_skip else None - - if sampling_layer_skip: - if sampling_layer_skip_entropy_threshold >= 0: - sampling_layer_skip_variables["entropy_threshold"] = sampling_layer_skip_entropy_threshold - if sampling_layer_skip_varentropy_threshold >= 0: - sampling_layer_skip_variables["varentropy_threshold"] = sampling_layer_skip_varentropy_threshold - if sampling_layer_skip_exit_layer >= 0: - sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer - - for n in trange( max_levels, desc="NAR", disable=disable_tqdm ): - level = prev_list[0].shape[-1] - if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels - break - - if cfg.lora is not None: - enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora ) - - quant_levels = [ level for _ in range(batch_size) ] # torch.full((len(text_list),), level) - - inputs = self.inputs( - text_list=text_list, - proms_list=proms_list, - resps_list=prev_list, - lang_list=lang_list, - tone_list=tone_list, quant_levels=quant_levels, ) output = super().forward( inputs=inputs, quant_levels=quant_levels, - - layer_skip_variables=sampling_layer_skip_variables, ) - logits, state = output.logits, output.state + logits = output.logits - sampled = super().sample( - logits=logits, - prev_list=prev_list, - quant_levels=quant_levels, + r = [ logit[-1:].argmax(dim=1) for logit in logits ] + # sanitize + for i, token in enumerate(r): + if token > 10: + r[i][0] = stop_token - temperature=sampling_temperature, - #min_temperature=sampling_min_temperature, - #top_p=sampling_top_p, - #top_k=sampling_top_k, - #min_p=sampling_min_p, - #repetition_penalty=sampling_repetition_penalty, - #repetition_penalty_decay=sampling_repetition_penalty_decay, - #length_penalty=sampling_length_penalty, - #beam_width=sampling_beam_width, - #mirostat=mirostat, - ) + # append tokens + for i, ri in enumerate(r): + if stop_token in ri: + stopped[i] = True + 
sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)]) - resps_list = sampled[0] - prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device=device)], dim=-1) for rs, r in zip(prev_list, resps_list) ] + # stop token found + stopped |= r == stop_token + if stopped.all().item(): + break - return prev_list - - # is AR - if cfg.lora is not None: - enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora ) + # convert tokens into int + return [ int("".join([ str(token.item()) for token in r if token != stop_token ])) for r in sequence_list ] # STT start_slice = [ 0 for _ in range(batch_size) ] @@ -352,9 +652,7 @@ class AR_NAR(Base): output = super().forward( inputs=inputs, state=state, - - layer_skip_variables=sampling_layer_skip_variables, - + #layer_skip_variables=sampling_layer_skip_variables, output_attentions=sampling_entropix, ) logits, state = output.logits, output.state @@ -457,10 +755,144 @@ class AR_NAR(Base): return sequence_list + def forward( + self, + text_list: list[Tensor], + proms_list: list[Tensor], + resps_list: list[Tensor] | None = None, + + task_list: list[Tensor] | None = None, + lang_list: list[Tensor] | None = None, + tone_list: list[Tensor] | None = None, + len_list: list[Tensor] | None = None, + + training: bool | int | None = None, + + max_steps: int = 1000, + max_levels: int = 0, + + input_prompt_prefix: bool = False, + prefix_silence: float = 1.0, + denoise_start: float = 0.0, + + sampling_temperature: float = 1.0, + sampling_min_temperature: float = -1.0, + sampling_top_k: int = -100, + sampling_top_p: float = 1.0, + sampling_min_p: float = 0.0, + sampling_repetition_penalty: float = 1.0, + sampling_repetition_penalty_decay: float = 0.0, + sampling_length_penalty: float = 0.0, + sampling_beam_width: int = 0, + sampling_mirostat_tau: float = 0.0, + sampling_mirostat_eta: float = 0.1, + sampling_dry_multiplier=0.0, + sampling_dry_base=1.75, + sampling_dry_allowed_length=2, + sampling_entropix=False, + + sampling_layer_skip: bool = False, + sampling_layer_skip_exit_layer: int = -1, + sampling_layer_skip_entropy_threshold: float = -1, + sampling_layer_skip_varentropy_threshold: float = -1, + + sampling_refine_on_stop: bool = False, + + disable_tqdm=False, + use_lora=None, + ): + kwargs = dict( + max_steps=max_steps, + max_levels=max_levels, + input_prompt_prefix=input_prompt_prefix, + prefix_silence=prefix_silence, + denoise_start=denoise_start, + sampling_temperature=sampling_temperature, + sampling_min_temperature=sampling_min_temperature, + sampling_top_k=sampling_top_k, + sampling_top_p=sampling_top_p, + sampling_min_p=sampling_min_p, + sampling_repetition_penalty=sampling_repetition_penalty, + sampling_repetition_penalty_decay=sampling_repetition_penalty_decay, + sampling_length_penalty=sampling_length_penalty, + sampling_beam_width=sampling_beam_width, + sampling_mirostat_tau=sampling_mirostat_tau, + sampling_mirostat_eta=sampling_mirostat_eta, + sampling_dry_multiplier=sampling_dry_multiplier, + sampling_dry_base=sampling_dry_base, + sampling_dry_allowed_length=sampling_dry_allowed_length, + sampling_entropix=sampling_entropix, + sampling_layer_skip=sampling_layer_skip, + sampling_layer_skip_exit_layer=sampling_layer_skip_exit_layer, + sampling_layer_skip_entropy_threshold=sampling_layer_skip_entropy_threshold, + sampling_layer_skip_varentropy_threshold=sampling_layer_skip_varentropy_threshold, + sampling_refine_on_stop=sampling_refine_on_stop, + disable_tqdm=disable_tqdm, + use_lora=use_lora, + ) + + # deduce batch_size + if 
text_list is not None: + default_task = "tts" + device = text_list[0].device + batch_size = len(text_list) + else: + default_task = "stt" + device = resps_list[0].device + batch_size = len(resps_list) + + # generate task list if not provided + if task_list is None: + task_list = [ default_task for _ in range(batch_size) ] + + # implicitly set for training + if training is None and text_list is not None and resps_list is not None: + n_levels_set = {r.shape[-1] for r in resps_list} + n_levels = next(iter(n_levels_set)) + + training = n_levels == self.n_resp_levels + + # is training + if training: + return self.forward_train( + text_list=text_list, + proms_list=proms_list, + resps_list=resps_list, + task_list=task_list, + lang_list=lang_list, + tone_list=tone_list, + len_list=len_list, + ) + + # is NAR + if (len_list is not None or resps_list is not None) and text_list is not None: + return self.forward_nar( + text_list=text_list, + proms_list=proms_list, + resps_list=resps_list, + task_list=task_list, + lang_list=lang_list, + tone_list=tone_list, + len_list=len_list, + **kwargs, + ) + + # is AR + return self.forward_ar( + text_list=text_list, + proms_list=proms_list, + resps_list=resps_list, + task_list=task_list, + lang_list=lang_list, + tone_list=tone_list, + len_list=len_list, + **kwargs, + ) + def example_usage(): + cfg.device = "cuda" cfg.trainer.backend = "local" - cfg.hyperparameters.gradient_accumulation_steps = 1 if cfg.audio_backend == "dac": cfg.sample_rate = 44_100 @@ -477,33 +909,23 @@ def example_usage(): import re setup_logging() - device = "cuda" - - - # mamba seems to ONLY be used as an AR (any NAR attempts lobotomizes it) - """ - if "mamba" in cfg.model.arch_type: - cfg.model.resp_levels = 1 - """ - # cfg.model.loss_factors = {} def load_artifact( path ): artifact = np.load(path, allow_pickle=True)[()] - text = torch.tensor( cfg.tokenizer.encode( artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=device) - audio = torch.from_numpy(artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=device) + text = torch.tensor( cfg.tokenizer.encode( artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=cfg.device) + audio = torch.from_numpy(artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=cfg.device) return text, audio text, audio = load_artifact(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}") + batch_size = cfg.hyperparameters.batch_size + cfg.model.experimental.masking_train_p = 0.5 - text_list = [ text ] - proms_list = [ audio[:cfg.dataset.frames_per_second, :] ] - resps_list = [ audio ] + text_list = [ text ] * batch_size + proms_list = [ audio[:cfg.dataset.frames_per_second, :] ] * batch_size + resps_list = [ audio ] * batch_size - batch_size = len(text_list) - - # rentet-full is the only configuration with BitNet's BitLinear that converges despite the grad_norm saying otherwise kwargs = { 'n_text_tokens': 256, 'n_audio_tokens': 1024, @@ -519,20 +941,12 @@ def example_usage(): 'config': cfg.model } - - """ - try: - kwargs['config'] = cfg.model - except Exception as e: - pass - """ bos_id, space_id, eos_id = cfg.tokenizer.encode( " " ) - #available_tasks = cfg.dataset.tasks_list - available_tasks = ["tts"] # , "stt"] + available_tasks = ["tts-ar", "tts-nar"] - model = AR_NAR(**kwargs).to(device) - steps = 500 # 150 * len(available_tasks) # * cfg.model.experimental.causal_size + model = AR_NAR(**kwargs).to(cfg.device) + steps = 500 // batch_size optimizer = 
cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy" scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else "" @@ -620,9 +1034,9 @@ def example_usage(): def sample_data(t=None): if isinstance(t, list): tasks = t - texts = [ text_list[0].to(device) if task != "stt" else None for i, task in enumerate( tasks ) ] - proms = [ proms_list[0].to(device) if task != "stt" else [ "stt" ] for i, task in enumerate( tasks ) ] - resps = [ None if task != "stt" else resps_list[0].to(device) for i, task in enumerate( tasks ) ] + texts = [ text_list[0].to(cfg.device) if task not in text_task else None for i, task in enumerate( tasks ) ] + proms = [ proms_list[0].to(cfg.device) if task not in text_task else [ "stt" ] for i, task in enumerate( tasks ) ] + resps = [ None if task not in text_task else resps_list[0].to(cfg.device) for i, task in enumerate( tasks ) ] return texts, proms, resps, tasks @@ -634,45 +1048,15 @@ def example_usage(): for i in range(batch_size): task = random.choice(available_tasks) if t is None else t - text = text_list[i].to(device) - prom = proms_list[i].to(device) - resp = resps_list[i].to(device) + text = text_list[i].to(cfg.device) + prom = proms_list[i].to(cfg.device) + resp = resps_list[i].to(cfg.device) # do nothing - if task == "tts": - ... - elif task == "stt": - prom = [ - task - ] - # to-do: reimplement this from data.py - """ - elif task == "tts-c": - trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second) - - prom = resp[:trim_length] - resp = resp[trim_length:] - - prom = prom.to(device) - elif task == "ns" or task == "sr": - # extend the noise to fill the target audio - noise_ext = repeat_extend_audio( noise, resp.shape[0] ) - # create the input prompt by merging the target audio with the noise - prom = merge_audio( resp.cpu(), noise_ext, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device ) - prom = prom.to(device) - # set the target to just be the noise if - if task == "sr": - resp = noise_ext - - # set the text prompt to empty to train without a guided text prompt - if random.random() < 0.5: - text = torch.tensor([bos_id, eos_id], device=device, dtype=torch.uint8) - - prom = [ - task, - prom, - ] - """ + if task == "stt": + prom = [ task ] + else: + task = "tts" texts.append( text ) proms.append( prom ) @@ -685,27 +1069,18 @@ def example_usage(): def sample( name, steps=500, task=None ): engine.eval() - texts, proms, resps, tasks = sample_data( task ) + text_list, proms_list, resp_list, task_list = sample_data( task ) - if "ar" in cfg.model.capabilities: - output = engine( texts, proms, resps, task_list=tasks, max_steps=steps, sampling_temperature=0.95 ) - - text = [ cfg.tokenizer.decode( output[i] ) for i, task in enumerate( tasks ) if task == "stt" ] - - texts = [ texts[i] for i, task in enumerate( tasks ) if task != "stt" ] - proms = [ proms[i] for i, task in enumerate( tasks ) if task != "stt" ] - resps = [ output[i] for i, task in enumerate( tasks ) if task != "stt" ] - tasks = [ tasks[i] for i, task in enumerate( tasks ) if task != "stt" ] - - print( "STT:", text ) + if task == "tts-nar": + len_list = engine(text_list, proms_list, task_list=["len"], max_steps=5, sampling_temperature=0.0 ) + len_list = [ resp_list[0].shape[0] for l in len_list ] + resps_list = engine( text_list, proms_list, len_list=len_list, sampling_temperature=0.0 ) else: - resps = [ resp[:, 0] for resp in resps ] + resps_list = 
engine( text_list, proms_list, task_list=["tts"], max_steps=steps, sampling_temperature=1.0 ) + resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.0 ) - if "nar" in cfg.model.capabilities: - resps = engine( texts, proms, resps, task_list=tasks, sampling_temperature=0.2 ) - - for i, o in enumerate(resps): - _ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device) + for i, o in enumerate(resps_list): + _ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.{task}.wav", device=cfg.device) unload_model() @@ -716,7 +1091,7 @@ def example_usage(): texts, proms, resps, tasks = sample_data() stats = {"step": i} - stats |= engine.traverse(text_list=texts, proms_list=proms, resps_list=resps, task_list=tasks) + stats |= engine.traverse(text_list=texts, proms_list=proms, resps_list=resps, task_list=tasks, training=True) stats |= {"grad_norm": engine.get_global_grad_norm()} tqdm.write(f"{stats}") @@ -735,11 +1110,8 @@ def example_usage(): model = ml.compile_model(model, backend=cfg.optimizations.compile) """ - """ for task in available_tasks: sample("final", task=task) - """ - sample("final", task=available_tasks) engines.quit() diff --git a/vall_e/models/base.py b/vall_e/models/base.py index d37c22e..a68eae7 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -246,9 +246,6 @@ class AudioEmbedding(nn.Module): # prom if self.capabilities is None: offset = 0 - # resp - #elif "len" in self.capabilities: - # offset = 1 elif "nar" not in self.capabilities: offset = 0 elif quant_level > 0: @@ -492,16 +489,6 @@ class Base(nn.Module): # +1 to include the stop or mask token n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 ) l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) - """ - elif "len" not in self.capabilities: - # +1 to include the stop token - n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 ) - l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) - # NAR-len model - else: - n_resp_tokens = n_audio_tokens - l_tokens = [n_resp_tokens] * (self.n_resp_levels) - """ self.unified_position_ids = unified_position_ids self.interleave = interleave @@ -561,11 +548,11 @@ class Base(nn.Module): # this ***might*** let me also unify the proms_emb and resps_embedding if self.version >= 5: # "len" RVQ level-0 gets an additional token - self.rvq_l_emb = Embedding(self.n_resp_levels + (1 if "len" in self.capabilities else 0), d_model) + self.rvq_l_emb = Embedding(self.n_resp_levels, d_model) # experimental NAR-only mode - self.len_emb = Embedding(11, d_model) if "len" in self.capabilities else None - self.time_emb = TimeEmbedding(d_model) if "len" in self.capabilities else None + self.len_emb = Embedding(11, d_model) + self.time_emb = TimeEmbedding(d_model) if attention_backend == "auto": attention_backend = "sdpa" @@ -645,7 +632,7 @@ class Base(nn.Module): use_reentrant=False )) elif self.arch_type == "llama": - LlamaClass = LlamaModel_Adapted if (self.layerskip or "len" in self.capabilities) else LlamaModel + LlamaClass = LlamaModel_Adapted # if (self.layerskip or "len" in self.capabilities) else LlamaModel if n_experts <= 1: self.model = LlamaClass(LlamaConfig( @@ -668,12 +655,6 @@ class Base(nn.Module): # replace with desired attention if attention_backend not in HF_ATTENTIONS: self.model = ml.replace_attention( self.model, klass=LlamaAttention_Adapted, 
target=LlamaAttention, mode=attention_backend ) - - # replace with modified Llama - """ - if "len" in self.capabilities: - self.model = ml.replace_attention( self.model, klass=LlamaDecoderLayer_Adapted, target=LlamaDecoderLayer, mode=attention_backend ) - """ else: self.model = MixtralModel(MixtralConfig( vocab_size =n_resp_tokens, @@ -1012,6 +993,7 @@ class Base(nn.Module): for i in range(batch_size): quant_level = quant_levels[i] if quant_levels is not None else 0 task_type = task_list[i] if task_list is not None else "tts" + timestep = time_list[i] if time_list is not None else None # insert task type as a string inputs[i].append( ( "task", task_type ) ) @@ -1023,12 +1005,6 @@ class Base(nn.Module): # Sequence: # prom /may/ include tokens inside to help guide things, per SpeechX if f'<{task_type}>' in get_task_symmap() and task_type not in self.special_tasks: - # pick a random timestep - if "len" in self.capabilities and quant_level == 0: - timestep = random.random() - else: - timestep = 1.0 - # insert the text prompt if text_list is not None and text_list[i] is not None: inputs[i].append( ( "text", text_list[i] ) ) @@ -1045,7 +1021,7 @@ class Base(nn.Module): if "tone" in self.capabilities and tone_list is not None and tone_list[i] is not None: inputs[i].append( ( "tone", tone_list[i] ) ) # insert timestep token - if "len" in self.capabilities and quant_level == 0: + if timestep is not None: # store timestep information inputs[i].append( ("timestep", torch.tensor([timestep], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) ) # insert the current output response @@ -1053,7 +1029,7 @@ class Base(nn.Module): inputs[i].append( ( "resp", resps_list[i] ) ) # store dropout mask - if "len" in self.capabilities and quant_level == 0: + if timestep is not None: dropout_mask = _dropout_mask( resps_list[i], p=math.cos(timestep * math.pi * 0.5) ) inputs[i].append( ("dropout_mask", dropout_mask ) ) @@ -1072,9 +1048,7 @@ class Base(nn.Module): inputs[i].append( ( "lang", lang_list[i] ) ) # technically will always be level 0 but for the sake of keeing the input formatting coherent... 
if self.rvq_l_emb is not None: - # override to 0 (I don't know if this change propagates, I'm not familiar with when python passes by (copied) value or reference) - quant_levels[i] = 0 - inputs[i].append( ( "quant_level", torch.tensor([ self.n_resp_levels ], device=device, dtype=torch.int16) ) ) + inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) ) # insert input audio prompt if proms_list is not None and proms_list[i] is not None: inputs[i].append( ( "prom", proms_list[i] ) ) @@ -1195,7 +1169,7 @@ class Base(nn.Module): embedding = _interleave_sequence_reshape( embeddings ) # if training NAR-len RVQ level 0 - elif "len" in self.capabilities and quant_level == 0 and dropout_mask is not None: + elif dropout_mask is not None: embedding = self.resps_emb( # if masked use masked token, else original token torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, 0] ), @@ -1220,10 +1194,6 @@ class Base(nn.Module): ) else: offset = 0 - """ - if "len" in self.capabilities: - offset = 1 - """ if "nar" not in self.capabilities: offset = 0 elif quant_level > 0: @@ -1264,14 +1234,21 @@ class Base(nn.Module): name, at=None, ): + find_all = at is None + res = [] if at is None else None + for batch_index, batch_input in enumerate(inputs): - if at is not None and batch_index != at: + if not find_all and batch_index != at: continue for n, input in batch_input: - if n == name: + if n != name: + continue + if not find_all: return input - return None + res.append( input ) + + return res # creates position ids from a given input list # if not unified_position_ids, then each input segment will have its own sequence @@ -1401,15 +1378,7 @@ class Base(nn.Module): for i in range(batch_size): quant_level = quant_levels[i] task_name = task_list[i] - - causal = False - - if "len" in self.capabilities: - causal = task_name == "len" - if quant_level >= self.n_resp_levels: - quant_level = 0 - else: - causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities) + causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities) or (task_name in ["len", "stt"]) if causal: l = self.causal_size @@ -1487,14 +1456,8 @@ class Base(nn.Module): logit = logits[i][it:it+seq_len] it += seq_len + 1 # +1 to incorporate the separator - - causal = False - if "len" in self.capabilities: - causal = task_name == "len" - if quant_level >= self.n_resp_levels: - quant_level = 0 - else: - causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities) + + causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities) or (task_name in ["len", "stt"]) # for the AR, shift sequence so that it predicts the next token # (the NAR predicts the next token in place, so it's not necessary to do any modifications for it) @@ -1854,15 +1817,9 @@ class Base(nn.Module): res = [ Categorical(logits=logit).sample() for logit in logits ] # calculate token probabilities - if "len" in self.capabilities: - scores = [ - [ F.softmax(logit[i, :], dim=-1)[token].item() for i, token in enumerate(tokens) ] - for logit, tokens in zip(logits, res) - ] - else: - scores = [ - [ F.softmax(logit[-1, :], dim=-1)[token].item() for token in tokens ] - for logit, tokens in zip(logits, res) - ] + scores = [ + [ F.softmax(logit[i, :], dim=-1)[token].item() for i, token in enumerate(tokens) ] + for logit, tokens in zip(logits, res) + ] return Sampled(res, logits, scores, entropy) \ No 
newline at end of file diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py deleted file mode 100644 index a327c26..0000000 --- a/vall_e/models/nar.py +++ /dev/null @@ -1,672 +0,0 @@ -""" -A (mostly) NAR model that handles inferencing all RVQ levels in parallel (NAR). -I believe Meta's Voicebox does this too (predict the utterance length, then decode in parallel) -It *does* have to inference the initial length in an autoregresssive-ish manner (it can technically also be done in parallel) - -Initial experiments show this only really "works" for the a few brief seconds before going to silence. I imagine I need to read more papers or just need to train longer. -""" - - -import random -import math -import numpy as np -import logging -import torch -from torch.nn.utils.rnn import pad_sequence - -from einops import rearrange -from torch import Tensor -from tqdm import trange - -from .base import Base, list_to_tensor, Categorical, _dropout_mask -from ..config import cfg -from ..emb.qnt import trim, repeat_extend_audio -from ..utils import clamp - -_logger = logging.getLogger(__name__) - -class NAR(Base): - def forward( - self, - text_list: list[Tensor], - proms_list: list[Tensor], - resps_list: list[Tensor] | None = None, - - task_list: list[Tensor] | None = None, - lang_list: list[Tensor] | None = None, - tone_list: list[Tensor] | None = None, - len_list: list[Tensor] | None = None, - - training: bool | int | None = None, - - max_steps: int = 1000, - max_levels: int = 0, - - input_prompt_prefix: bool = False, - prefix_silence: float = 1.0, - denoise_start: float = 0.0, - - sampling_temperature: float = 1.0, - sampling_min_temperature: float = -1.0, - sampling_top_k: int = -100, - sampling_top_p: float = 1.0, - sampling_min_p: float = 0.0, - sampling_repetition_penalty: float = 1.0, - sampling_repetition_penalty_decay: float = 0.0, - sampling_length_penalty: float = 0.0, - sampling_beam_width: int = 0, - sampling_mirostat_tau: float = 0.0, - sampling_mirostat_eta: float = 0.1, - sampling_dry_multiplier=0.0, - sampling_dry_base=1.75, - sampling_dry_allowed_length=2, - sampling_entropix=False, - - sampling_layer_skip: bool = False, - sampling_layer_skip_exit_layer: int = -1, - sampling_layer_skip_entropy_threshold: float = -1, - sampling_layer_skip_varentropy_threshold: float = -1, - - sampling_refine_on_stop: bool = False, - - disable_tqdm=False, - use_lora=None, - ): - text_task = [ "stt" ] - - if text_list is not None: - default_task = "tts" - device = text_list[0].device - batch_size = len(text_list) - else: - default_task = "stt" - device = resps_list[0].device - batch_size = len(resps_list) - - # generate task list if not provided - if task_list is None: - task_list = [ default_task for _ in range(batch_size) ] - - has_none = resps_list is None or text_list is None - if not has_none: - for i, task in enumerate( task_list ): - if resps_list[i] is None or text_list[i] is None: - has_none = True - break - - # is training or NAR - if not has_none: - n_levels_set = {r.shape[-1] for r in resps_list} - n_levels = next(iter(n_levels_set)) - - # implicit - if training is None: - training = 0 if n_levels == self.n_resp_levels else None - - # is training - if training is not None: - len_train_p = self.config.experimental.len_train_p if self.config is not None else 0.05 - - n_levels_set = {r.shape[-1] for r in resps_list} - n_levels = next(iter(n_levels_set)) - - # assert n_levels == self.n_resp_levels - - # to-do: make this YAML configurable - def sample_task(): - return "len" if random.random() < 
len_train_p else "tts" - - # generate task list to train against - task_list = [ sample_task() for _ in range(batch_size) ] - - # specifies how to sample probabilities of which RVQ levels to train against - rvq_levels_p = self.config.experimental.rvq_levels_p if self.config is not None else "equal" - # determines which RVQ level to target per batch - quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels - 1 ] - # rate to perform token dropout errors - token_dropout_error = self.config.experimental.token_dropout_error - # RVQ levels to apply token dropout on - token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels - # CFG - cfg_text_dropout_p = self.config.experimental.cfg_text_dropout_p if self.config is not None else 0.0 - cfg_cond_dropout_p = self.config.experimental.cfg_cond_dropout_p if self.config is not None else 0.0 - cfg_prom_dropout_p = self.config.experimental.cfg_prom_dropout_p if self.config is not None else 0.0 - # implicitly set it to all levels - if not token_dropout_rvq_levels: - token_dropout_rvq_levels = [0, self.resp_levels - 1] - # allow passing a specific distribution of RVQ levels - rvq_levels_p = rvq_levels_p if isinstance(rvq_levels_p, list) else [] - if not rvq_levels_p: - lo, hi = quant_level_range[0], quant_level_range[1] + 1 - # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR) - if rvq_levels_p == "equal": - rvq_levels_p = [ i for i in range( lo, hi ) ] - else: - # yuck - rvq_levels_p = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], []) - - # input RVQ levels - quant_levels = [ random.choice( rvq_levels_p ) for i in range(batch_size) ] - for i, task in enumerate( task_list ): - if task in text_task: - quant_levels[i] = 0 # self.n_resp_levels - 1 - - # trim resps to only contain all levels below the target level - resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)] - # empty string for CFG - text_start_stop_sequence = torch.tensor([1, 2], device=device, dtype=torch.int16) - # I hate python's value/reference semantics so much - for i, quant_level, text, resps, proms, task in zip(range(batch_size), quant_levels, text_list, resps_list, proms_list, task_list): - # cap quant_level if it exceeds its corresponding resp/prom - if quant_level >= resps.shape[-1]: - quant_levels[i] = resps.shape[-1] - 1 - - # proms could be a Tensor, list[Tensor], or None - if isinstance( proms, torch.Tensor ): - if quant_level >= proms.shape[-1]: - quant_levels[i] = proms.shape[-1] - 1 - - elif isinstance( proms, list ): - for j, prom in enumerate( proms ): - if not isinstance( prom, torch.Tensor ): - continue - if quant_level >= prom.shape[-1]: - quant_levels[i] = prom.shape[-1] - 1 - - # apply token dropout error compensation - if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]): - steps = resps.shape[0] - for l in range( quant_level ): - for t in range( steps ): - token = resps[t, l].item() - - if random.random() < token_dropout_error: - offset = 1 * ( 1 if random.random() < 0.5 else -1 ) - resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1 - - # only apply stop token for RVQ level 0 - if quant_level <= 0: - # append stop tokens for AR - if task in text_task: - #text_list[i] = torch.cat([ resps, text_stop_sequence ]) - ... 
- else: - #resps_list[i] = torch.cat([ resps, audio_stop_sequence ]) - ... - - # apply CFG (should probably only apply to NAR quant level 0) - if task not in text_task + ["len"]: - drop_text = False - drop_audio = False - - if random.random() < cfg_prom_dropout_p: - drop_audio = True - - if random.random() < cfg_cond_dropout_p: - drop_audio = True - drop_text = True - - if drop_text: - text_list[i] = text_start_stop_sequence - - if drop_audio: - proms_list[i] = None - - inputs = self.inputs( - text_list=text_list, - proms_list=proms_list, - resps_list=resps_list, - lang_list=lang_list, - tone_list=tone_list, - task_list=task_list, - - quant_levels=quant_levels, - ) - - return super().forward( - inputs=inputs, - quant_levels=quant_levels, - ) - - - if len_list is not None: - sampling_layer_skip_variables = {} if sampling_layer_skip else None - - if max_levels == 0: - max_levels = self.n_max_levels - 1 - - if sampling_layer_skip: - if sampling_layer_skip_entropy_threshold >= 0: - sampling_layer_skip_variables["entropy_threshold"] = sampling_layer_skip_entropy_threshold - if sampling_layer_skip_varentropy_threshold >= 0: - sampling_layer_skip_variables["varentropy_threshold"] = sampling_layer_skip_varentropy_threshold - if sampling_layer_skip_exit_layer >= 0: - sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer - - # initial condition - """ - print( len_list ) - len_list = [ clamp(1, max_steps, l) for l in len_list ] - print( len_list ) - """ - metrics = [] - - mask_token = torch.tensor([self.stop_token], dtype=torch.int16, device=device) - prev_list = [ torch.concat([ mask_token for _ in range( resp_len ) ]) for resp_len in len_list ] - - # special "scheduling" to inference RVQ-level 0 - level = 0 - if cfg.lora is not None: - enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora ) - - def log(x, eps = 1e-20): - return torch.log(x.clamp(min = eps)) - - def gumbel_sample(x, temperature = 1., dim = -1): - return ((x / max(temperature, 1e-10)) + -log(-log(torch.zeros_like(x).uniform_(0, 1)))).argmax(dim = dim) - - _super = super() - def demask_sampling( batch_index, seq_len ): - # overrides - max_steps = 10 - temperature = 0.3 - cfg_strength = 1.0 - sampling_repetition_penalty = 1.0 # force rep pen off, because this caused false positives due to how rep pen was being naively applied...... 
- sampling_top_p = 0.9 # a lot of demasking samplers use a top-k of seq_len * 0.9 - - # if we're denoising from an existing sequence - if denoise_start > 0.0 and resps_list is not None: - start_noise = denoise_start - noise_p = math.cos( start_noise * math.pi * 0.5 ) - mask = torch.tensor( [ random.random() < noise_p for _ in range( seq_len ) ], dtype=torch.bool, device=device ) - input_ids = torch.where( mask, self.stop_token, resps_list[batch_index][:, 0] ) - else: - input_ids = torch.ones((seq_len,), dtype=torch.int16, device=device) * self.stop_token - - scores = torch.zeros((seq_len,), dtype=torch.float32, device=device) - - quant_levels = [ level for _ in range(batch_size) ] - prev_list = [ input_ids ] - - start_temperature = temperature - start_noise = 0.0 - end_noise = 1.0 - - null_text = torch.tensor([1, 2], device=device, dtype=torch.int16) - null_prom = None - - for timestep, steps_until_x0 in zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))): - # anneal temperature - temperature = start_temperature * (steps_until_x0 / max_steps) - # get noise level, per cosine scheduling - noise_p = math.cos( timestep * math.pi * 0.5 ) - # number of tokens to mask off to "noise" the input sequence - masked_tokens_n = max(int( noise_p * seq_len ), 1) - # pick the worst scoring tokens to mask off - masked_indices = scores.topk( masked_tokens_n, dim=-1 ).indices - # mask off inputs - input_ids = input_ids.scatter(0, masked_indices, self.stop_token) - # boolean mask - is_masked = input_ids == self.stop_token - # setup inputs - - inputs = _super.inputs( - text_list=text_list, - proms_list=proms_list, - resps_list=[ input_ids ], - lang_list=lang_list, - tone_list=tone_list, - time_list=[ timestep ], - quant_levels=quant_levels, - ) - output = _super.forward( - inputs=inputs, - quant_levels=quant_levels, - layer_skip_variables=sampling_layer_skip_variables, - ) - - logits = output.logits - - if cfg_strength > 0: - null_inputs = _super.inputs( - text_list=[ null_text ], - proms_list=[ null_prom ], - resps_list=[ input_ids ], - lang_list=lang_list, - tone_list=tone_list, - time_list=[ timestep ], - quant_levels=quant_levels, - ) - null_output = _super.forward( - inputs=null_inputs, - quant_levels=quant_levels, - layer_skip_variables=sampling_layer_skip_variables, - ) - for logit, null_logits in zip(output.logits, null_output.logits): - logit[-seq_len:] = logit[-seq_len:] + ( logit[-seq_len:] - null_logits[-seq_len:] ) * cfg_strength - - # sample with sampler settings - filtered_sampled = _super.sample( - logits=logits, - prev_list=prev_list, - quant_levels=quant_levels, - - temperature=temperature, - min_temperature=sampling_min_temperature, - top_p=sampling_top_p, - top_k=sampling_top_k, - min_p=sampling_min_p, - repetition_penalty=sampling_repetition_penalty, - repetition_penalty_decay=sampling_repetition_penalty_decay, - length_penalty=sampling_length_penalty, - ) - - # retrieves unfiltered logits - unfiltered_sampled = _super.sample( - logits=logits, - prev_list=prev_list, - quant_levels=quant_levels, - temperature=0.0, - ) - # update previous list of tokens - prev_list = [ input_ids ] - - # extract logits - filtered_logits = filtered_sampled.logits[0] - unfiltered_logits = unfiltered_sampled.logits[0] - - # extract scores - filtered_scores = filtered_sampled.scores[0] - unfiltered_scores = unfiltered_sampled.scores[0] - - # extract sampled tokens - filtered_tokens = filtered_sampled[0][0] - unfiltered_tokens = unfiltered_sampled[0][0] - - # sample with gumbelnoise - 
# I actually feel like this doesn't matter? it's hard to judge with a partially trained NAR-len model - sampled_ids = gumbel_sample( filtered_logits, temperature=temperature, dim=-1 ) - #sampled_ids = filtered_tokens - - # keep unmasked tokens - input_ids = torch.where( is_masked, sampled_ids, input_ids ) - # update scores (conjugated to put the worst scores at the top) - scores = 1.0 - torch.tensor([score for score in unfiltered_scores], device=device) - - if cfg.experimental: - print( timestep, steps_until_x0, noise_p, masked_tokens_n, input_ids, scores ) - - return input_ids - - # perform demasked sampling (mock diffusion) - prev_list = [ demask_sampling( batch_index=i, seq_len=l ) for i, l in enumerate( len_list ) ] - - # expand if given a raw 1D tensor - for i, resp in enumerate(prev_list): - if resp.dim() == 1: - prev_list[i] = resp.unsqueeze(-1) - - for n in trange( max_levels, desc="NAR", disable=disable_tqdm ): - level = prev_list[0].shape[-1] - if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels - break - - if cfg.lora is not None: - enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora ) - - quant_levels = [ level for _ in range(batch_size) ] # torch.full((len(text_list),), level) - - inputs = self.inputs( - text_list=text_list, - proms_list=proms_list, - resps_list=prev_list, - lang_list=lang_list, - tone_list=tone_list, - quant_levels=quant_levels, - ) - - output = super().forward( - inputs=inputs, - quant_levels=quant_levels, - - layer_skip_variables=sampling_layer_skip_variables, - ) - logits, state = output.logits, output.state - - sampled = super().sample( - logits=logits, - prev_list=prev_list, - quant_levels=quant_levels, - - temperature=0.0, # sampling_temperature, - #min_temperature=sampling_min_temperature, - #top_p=sampling_top_p, - #top_k=sampling_top_k, - #min_p=sampling_min_p, - #repetition_penalty=sampling_repetition_penalty, - #repetition_penalty_decay=sampling_repetition_penalty_decay, - #length_penalty=sampling_length_penalty, - #beam_width=sampling_beam_width, - #mirostat=mirostat, - ) - - resps_list = sampled[0] - prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device=device)], dim=-1) for rs, r in zip(prev_list, resps_list) ] - - return prev_list - - # is AR - if cfg.lora is not None: - enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora ) - - sequence_list = [ torch.tensor([0], device=device,dtype=torch.int16) for _ in range(batch_size) ] - stopped = torch.zeros(batch_size, device=device).bool() - - stop_token = 10 - task_list = [ "len" for _ in range(batch_size) ] - - for n in trange(10, desc="AR", disable=disable_tqdm): - len_list = sequence_list - - inputs = self.inputs( - text_list=text_list, - proms_list=proms_list, - resps_list=resps_list, - lang_list=lang_list, - tone_list=tone_list, - len_list=len_list, - task_list=task_list, - quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ] - ) - - output = super().forward( - inputs=inputs, - ) - logits = output.logits - - r = [ logit[-1:].argmax(dim=1) for logit in logits ] - # sanitize - for i, token in enumerate(r): - if token > 10: - r[i][0] = stop_token - - # append tokens - for i, ri in enumerate(r): - if stop_token in ri: - stopped[i] = True - sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)]) - - # stop token found - stopped |= r == stop_token - if stopped.all().item(): - break - - # convert tokens into int - return [ int("".join([ 
str(token.item()) for token in r if token != stop_token ])) for r in sequence_list ] - - -def example_usage(): - cfg.trainer.backend = "local" - cfg.hyperparameters.gradient_accumulation_steps = 1 - if cfg.audio_backend == "dac": - cfg.sample_rate = 44_100 - - from functools import partial - from einops import repeat - from tqdm import tqdm - - from ..emb.qnt import decode_to_file, unload_model - from ..engines import Engine - from ..utils import wrapper as ml - - import numpy as np - import re - - device = "cuda" - - def load_artifact( path ): - artifact = np.load(path, allow_pickle=True)[()] - - text = torch.tensor( cfg.tokenizer.encode( artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=device) - audio = torch.from_numpy(artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=device) - - return text, audio - - text, audio = load_artifact(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}") - - text_list = [ text ] - proms_list = [ audio[:cfg.dataset.frames_per_second, :] ] - resps_list = [ audio ] - - # rentet-full is the only configuration with BitNet's BitLinear that converges despite the grad_norm saying otherwise - kwargs = { - 'n_text_tokens': 256, - 'n_audio_tokens': 1024, - - 'd_model': 1024, # 256, # 1024, # 1536 - 'n_heads': 16, # 4, # 16, # 24 - 'n_layers': 12, # 32 - 'n_experts': 1, - - 'p_dropout': 0.1, - - 'l_padding': 8 if cfg.optimizations.fp8 else 0, - - 'config': cfg.model - } - - """ - try: - kwargs['config'] = cfg.model - except Exception as e: - pass - """ - - model = NAR(**kwargs).to(device) - steps = 250 - - optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy" - scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else "" - learning_rate = cfg.hyperparameters.learning_rate if cfg.yaml_path is not None else None - - if cfg.optimizations.dadaptation: - # do not combine the two - if scheduler == "schedulefree": - scheduler = "" - - learning_rate = 1.0 - - if optimizer == "prodigy": - if learning_rate is None: - learning_rate = 1.0 - - optimizer = ml.Prodigy - elif optimizer == "adagrad": - if learning_rate is None: - learning_rate = 1.0e-2 - - optimizer = ml.Adagrad - elif optimizer == "adamw": - if learning_rate is None: - learning_rate = 1.0e-4 - - optimizer = ml.AdamW - elif optimizer == "sdg": - if learning_rate is None: - learning_rate = 1.0e-4 - - optimizer = ml.SGD - else: - raise ValueError(f"Unrecognized optimizer: {optimizer}") - - _logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}") - - optimizer = optimizer(model.parameters(), lr=learning_rate) - - if scheduler == "schedulefree": - if isinstance(optimizer, ml.AdamW): - scheduler = ml.schedulefree.AdamWScheduleFree - elif isinstance(optimizer, ml.SGD): - scheduler = ml.schedulefree.SGDScheduleFree - else: - scheduler = None - - if scheduler is not None: - _logger.info(f"Scheduler: {scheduler}") - optimizer = scheduler( model.parameters(), lr = learning_rate ) - - if cfg.optimizations.replace and cfg.optimizations.linear: - model = ml.replace_linear( model ) - - if cfg.optimizations.replace and cfg.optimizations.embedding: - model = ml.replace_embedding( model ) - - engine = Engine(model=model, optimizer=optimizer) - - """ - torch.save( { - 'module': model.state_dict() - }, f"./data/{cfg.model.arch_type}.pth" ) - """ - - _logger.info(f"NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") - - @torch.inference_mode() - def sample( name, 
steps=1000 ): - if cfg.audio_backend == "dac" and name == "init": - return - - engine.eval() - - len_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 ) - resps_list = engine( text_list, proms_list, len_list=len_list, sampling_temperature=0.2 ) - - len_list = [ min(l, 500) for l in len_list ] - - for i, o in enumerate(resps_list): - _ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device) - - unload_model() - - def train(): - engine.train() - t = trange(steps) - for i in t: - stats = {"step": i} - stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list) - stats |= {"grad_norm": engine.get_global_grad_norm()} - - tqdm.write(f"{stats}") - - """ - torch.save( { - 'module': model.state_dict() - }, f"./data/{cfg.model.arch_type}.pth" ) - """ - - #sample("init", 5) - train() - sample("final") - -if __name__ == "__main__": - example_usage() \ No newline at end of file diff --git a/vall_e/samplers.py b/vall_e/samplers.py index baff3ab..ae4ef84 100644 --- a/vall_e/samplers.py +++ b/vall_e/samplers.py @@ -47,8 +47,8 @@ def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=F start = i + 1 # apply either up to limit tokens, or to the end end = start + limit if limit > 0 else seq_len - start = clamp(0, seq_len - 1, start) - end = clamp(0, seq_len - 1, end) + start = clamp(start, 0, seq_len - 1) + end = clamp(end, 0, seq_len - 1) for j in range( start, end ): distance = j - i logits[j, token] /= factor * (distance ** decay)
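
For reference, the new `len` inference path in `models/ar_nar.py` above treats duration prediction as a tiny autoregressive problem over an 11-token vocabulary: digits 0 through 9 plus a stop token (10), decoded for at most 10 steps and then concatenated into an integer frame count. A minimal sketch of the decode step (the tensor values here are hypothetical):

```python
import torch

STOP_TOKEN = 10  # the len head's vocabulary is digits 0..9 plus stop (10)

def decode_len_tokens(sequence_list: list[torch.Tensor]) -> list[int]:
	# join each batch entry's digit tokens into a base-10 integer;
	# the leading 0 used to seed the sequence is harmless to int()
	return [
		int("".join(str(token.item()) for token in seq if token != STOP_TOKEN))
		for seq in sequence_list
	]

# e.g. the token sequence [0, 2, 5, 0, 10] decodes to 250 frames
print(decode_len_tokens([torch.tensor([0, 2, 5, 0, 10])]))  # [250]
```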
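With `models/nar.py` deleted, the unified `AR_NAR.forward` dispatches every call: a response tensor carrying the full stack of RVQ levels implies training, text plus a `len_list` (or a partial response to refine) routes to the NAR path, and everything else falls through to the AR path. A schematic mirror of that routing, with `n_resp_levels` as a stand-in parameter rather than the real model attribute:

```python
import torch

def route(text_list, resps_list, len_list, n_resp_levels, training=None):
	# implicit training detection: a full stack of RVQ levels means training
	if training is None and text_list is not None and resps_list is not None:
		training = resps_list[0].shape[-1] == n_resp_levels
	if training:
		return "train"
	# NAR: a known/predicted duration, or an existing response to refine
	if (len_list is not None or resps_list is not None) and text_list is not None:
		return "nar"
	# AR: decoding from scratch (tts, stt, or len)
	return "ar"

resps = torch.zeros(75, 8, dtype=torch.int16)
print(route([torch.zeros(10)], [resps], None, n_resp_levels=8))  # "train"
print(route([torch.zeros(10)], None, [75], n_resp_levels=8))     # "nar"
print(route([torch.zeros(10)], None, None, n_resp_levels=8))     # "ar"
```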
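The deleted training branch also documents how each batch item picks which RVQ level to train against. With `rvq_levels_p = "equal"` every level in the configured range is equally likely; otherwise a skewed pool is built in which level `i` appears `hi - i` times, biasing training toward the lower levels (which carry most of the audio information). A sketch of the pool construction:

```python
import random

lo, hi = 0, 8  # quant_level_range for an 8-level model

# "equal": uniform over levels lo..hi-1
equal_pool = [i for i in range(lo, hi)]

# skewed: level i appears (hi - i) times, so level 0 is sampled most often
skewed_pool = sum([[i for _ in range(hi - i)] for i in range(lo, hi)], [])
# -> [0]*8 + [1]*7 + ... + [7]*1

quant_level = random.choice(skewed_pool)
```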
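The demasked sampling loop (formerly `demask_sampling` in the deleted `models/nar.py`, and the basis for the new `masking_train_p` path) follows a MaskGIT-style schedule: walk a cosine noise curve, re-mask the lowest-confidence positions at each step, and resample only those positions. A simplified sketch of the scheduling skeleton, with `sample_fn` standing in for the actual model call and `MASK_TOKEN` as a placeholder for the model's stop/mask token id:

```python
import math
import torch

MASK_TOKEN = 1024  # placeholder id for the stop/mask token

def demask_schedule(seq_len: int, max_steps: int, sample_fn) -> torch.Tensor:
	# start fully masked, with no confidence anywhere
	input_ids = torch.full((seq_len,), MASK_TOKEN, dtype=torch.long)
	scores = torch.zeros(seq_len)

	for timestep in torch.linspace(0.0, 1.0, max_steps):
		# cosine schedule: the noise level decays from 1 to 0 over the walk
		noise_p = math.cos(timestep.item() * math.pi * 0.5)
		masked_n = max(int(noise_p * seq_len), 1)
		# re-mask the worst-scoring positions
		masked_indices = scores.topk(masked_n, dim=-1).indices
		input_ids = input_ids.scatter(0, masked_indices, MASK_TOKEN)
		is_masked = input_ids == MASK_TOKEN

		# the model proposes tokens and per-position confidences
		sampled_ids, confidence = sample_fn(input_ids)
		# accept proposals only at masked positions, keep the rest
		input_ids = torch.where(is_masked, sampled_ids, input_ids)
		# conjugate so that low confidence sorts first under topk
		scores = 1.0 - confidence

	return input_ids

# dummy sampler for illustration: random tokens, random confidences
dummy = lambda ids: (torch.randint(0, 1024, ids.shape), torch.rand(ids.shape))
print(demask_schedule(seq_len=8, max_steps=4, sample_fn=dummy))
```

This mirrors the training-side change in `models/base.py`, where `dropout_mask` is drawn with `p = cos(timestep * π / 2)` whenever a timestep is present in the inputs.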
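The same deleted code shows how classifier-free guidance is applied inside the demasking loop: a second forward pass is run with a null text/prompt condition, and the conditional logits are extrapolated away from the unconditional ones. The core combination, sketched with hypothetical tensors:

```python
import torch

cfg_strength = 1.0
seq_len = 4

# conditional and unconditional logits over the same masked positions
logit = torch.randn(seq_len, 1024)       # conditioned on text + audio prompt
null_logit = torch.randn(seq_len, 1024)  # conditioned on the null/empty prompt

# guided = cond + (cond - uncond) * strength, per demask_sampling above
guided = logit + (logit - null_logit) * cfg_strength
```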
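Finally, the `vall_e/samplers.py` hunk is a plain argument-order fix: the `clamp` helper is called elsewhere as `(value, low, high)`, but the repetition-penalty window bounds were being passed as `(0, seq_len - 1, start)`, which pins the result to `seq_len - 1` regardless of `start` and collapses the penalty window to nothing. A minimal reproduction, assuming the usual `(value, low, high)` definition:

```python
def clamp(value, low, high):
	# assumed to match the utility's (value, low, high) signature
	return max(low, min(value, high))

seq_len, start = 100, 42

# before: clamp(0, seq_len - 1, start) == max(99, min(0, 42)) == 99,
# no matter what `start` actually was
buggy = clamp(0, seq_len - 1, start)   # 99

# after: the index itself is bounded to [0, seq_len - 1]
fixed = clamp(start, 0, seq_len - 1)   # 42

print(buggy, fixed)
```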