diff --git a/vall_e/config.py b/vall_e/config.py
index 657e64a..6bfcd52 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -214,6 +214,11 @@ class ModelExperimentalSettings:
 	rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range for LoRAs, isn't necesary
 	unified_position_ids: bool = True # False will generate position IDs partitioned for each section
 	tie_classifier_to_embedding: bool = False # Ties the classifier output to their respective embeddings, this does not seem to do anything good in testing
+
+	# performs token dropout to compensate for errors
+	token_dropout_error: float = 0.0 # probability to nudge a token by ±1
+	token_dropout_rate: float = 0.0 # probability to randomly set a token to a special dropout value
+	token_dropout_rvq_levels: list = field(default_factory=lambda: [1,8]) # determines which levels to do dropout, by default do not do dropout on RVQ level 0
 
 # I really need to clean this up
 @dataclass()
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index f6d17be..0d15863 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -22,6 +22,9 @@ from ..emb.qnt import trim, encode_as_embedding
 
 from .lora import enable_lora
 
+def clamp(n, lo, hi):
+	return max(lo, min(n, hi))
+
 class AR_NAR(Base):
 	@property
 	def capabilities(self) -> list[str]:
@@ -139,6 +142,11 @@ class AR_NAR(Base):
 		# determines which RVQ level to target per batch
 		quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels ]
 
+		token_dropout_error = self.config.experimental.token_dropout_error
+		token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
+		if not token_dropout_rvq_levels:
+			token_dropout_rvq_levels = [0, self.resp_levels]
+
 		if p_rvq_levels == "equal":
 			# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
 			quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
@@ -165,39 +173,49 @@
 			quant_levels = [ random.choice( pool ) for i in range(batch_size) ]
 
 		# these two are techinically equivalent if the audio embeddings handle things properly
+		"""
 		resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
 		stop_sequence = torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16)
-		"""
+		"""
 		resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
 		stop_sequence = torch.Tensor([[self.stop_token] * 1]).to(device=device, dtype=torch.int16)
-		"""
-
-		for i in range(batch_size):
+
+		for i, quant_level, resps, proms in zip(range(batch_size), quant_levels, resps_list, proms_list):
 			# cap quant_level if it exceeds its corresponding resp/prom
-			if quant_levels[i] >= resps_list[i].shape[-1]:
-				quant_levels[i] = resps_list[i].shape[-1] - 1
+			if quant_level >= resps.shape[-1]:
+				quant_levels[i] = resps.shape[-1] - 1
 
-			# proms_list[i] could be a Tensor, list[Tensor], or None
-			if isinstance( proms_list[i], torch.Tensor ):
-				if quant_levels[i] >= proms_list[i].shape[-1]:
-					quant_levels[i] = proms_list[i].shape[-1] - 1
+			# proms could be a Tensor, list[Tensor], or None
+			if isinstance( proms, torch.Tensor ):
+				if quant_level >= proms.shape[-1]:
+					quant_levels[i] = proms.shape[-1] - 1
 
-			elif isinstance( proms_list[i], list ):
-				for j, prom in enumerate( proms_list[i] ):
+			elif isinstance( proms, list ):
+				for j, prom in enumerate( proms ):
 					if not isinstance( prom, torch.Tensor ): continue
-					if quant_levels[i] >= prom.shape[-1]:
+					if quant_level >= prom.shape[-1]:
 						quant_levels[i] = prom.shape[-1] - 1
 
-			# only apply stop token for RVQ level 0
-			if quant_levels[i] > 0:
-				continue
+			# apply token dropout error compensation
+			if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
+				steps = resps.shape[0]
+				for l in range( quant_level ):
+					for t in range( steps ):
+						token = resps[t, l].item()
 
-			# append stop tokens for AR
-			# could technically do it in the .inputs call
-			resps_list[i] = torch.cat([ resps_list[i], stop_sequence ])
+						if random.random() < token_dropout_error:
+							offset = 1 * ( 1 if random.random() < 0.5 else -1 )
+							resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
+
+			# only apply stop token for RVQ level 0
+			if quant_level <= 0:
+				# append stop tokens for AR
+				# could technically do it in the .inputs call
+				resps_list[i] = torch.cat([ resps, stop_sequence ])
+
 		inputs = self.inputs(
 			text_list=text_list,
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 09ba222..691ce22 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -12,7 +12,7 @@ Additional functionality (preparing inputs, generating full audio) should be del
 import math
 import torch
 import torch.nn.functional as F
-import traceback
+import random
 
 import numpy as np
 import re
@@ -439,6 +439,10 @@ class Base(nn.Module):
 		self.tasks_emb = None
 		self.rvq_l_emb = None
 		self.len_emb = None
+
+		# it would be nicer for these to be a token or live inside an embedding
+		self.sep = nn.Parameter(torch.randn(d_model))
+		self.dropout_token = nn.Parameter(torch.zeros(d_model)) # zeros sounds nicer than randn for a special value
 
 		if self.version == 1: # legacy
 			n_audio_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom
@@ -484,9 +488,6 @@
 		# experimental NAR-only mode
 		self.len_emb = Embedding(11, d_model) if "len" in self.capabilities else None
 
-		# this would be nicer to be a stop token or live inside an embedding
-		self.sep = nn.Parameter(torch.randn(d_model))
-
 		# ick, there has to be a better way
 		hf_attention = self.config.attention if self.config is not None else None
@@ -970,6 +971,16 @@ class Base(nn.Module):
 			return self.proms_emb(
 				input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level],
 				offset = 0
 			)
 
+		# yuck
+		token_dropout_rate = self.config.experimental.token_dropout_rate if self.config else 0.0
+		token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels if self.config else None
+
+		if self.dropout_token is None or not self.training:
+			token_dropout_rate = 0.0
+
+		if not token_dropout_rvq_levels:
+			token_dropout_rvq_levels = [1, self.resp_levels]
+
 		x_list = []
 		for batch_index, batch_input in enumerate(inputs):
 			batch = []
@@ -1018,6 +1029,16 @@
 					input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level],
 					offset = 0 if quant_level == 0 or "len" in self.capabilities else 1
 				)
+
+				# apply token dropout
+				if token_dropout_rate > 0.0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
+					steps = embedding.shape[0] - (1 if quant_level == 0 else 0) # do not mess with stop token
+					for i in range( steps ):
+						if random.random() > token_dropout_rate:
+							continue
+
+						embedding[i] = self.dropout_token
+
 			elif name == "len" and self.len_emb is not None:
 				embedding = self.len_emb( input )
 			else:
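
As a rough illustration of what the two new knobs do (a standalone sketch, not code from the repository): token_dropout_error perturbs raw RVQ token ids before they are embedded, nudging a token by ±1 within the 1-1022 range the patch clamps to, while token_dropout_rate replaces an already-embedded timestep with the learned dropout_token vector during training. The tensor shapes, helper names, and constants below are assumptions made for the example.

# Standalone sketch of the two dropout mechanisms introduced by the patch.
# Assumptions: `resps` is a [timesteps, rvq_levels] tensor of token ids,
# `embedding` is a [timesteps, d_model] tensor; names are illustrative only.
import random
import torch

def clamp(n, lo, hi):
	return max(lo, min(n, hi))

def nudge_tokens(resps: torch.Tensor, error_rate: float) -> torch.Tensor:
	"""With probability `error_rate`, shift a token id by +-1, clamped to the assumed codebook range."""
	resps = resps.clone()
	for t in range(resps.shape[0]):
		for l in range(resps.shape[1]):
			if random.random() < error_rate:
				offset = 1 if random.random() < 0.5 else -1
				resps[t, l] = clamp(resps[t, l].item() + offset, 1, 1022)
	return resps

def drop_embeddings(embedding: torch.Tensor, dropout_token: torch.Tensor, rate: float) -> torch.Tensor:
	"""With probability `rate`, replace a timestep's embedding with the learned dropout vector."""
	embedding = embedding.clone()
	for i in range(embedding.shape[0]):
		if random.random() < rate:
			embedding[i] = dropout_token
	return embedding

if __name__ == "__main__":
	resps = torch.randint(1, 1023, (75, 8))    # roughly one second of 8-level RVQ codes
	noisy = nudge_tokens(resps, error_rate=0.05)
	print("tokens changed:", (noisy != resps).sum().item())

	d_model = 512
	emb = torch.randn(75, d_model)
	dropout_token = torch.zeros(d_model)       # mirrors the zero-initialised dropout_token in the patch
	dropped = drop_embeddings(emb, dropout_token, rate=0.1)
	print("timesteps dropped:", (dropped == 0).all(dim=-1).sum().item())

In the patch itself, the ±1 nudge runs over the training targets in ar_nar.py before the inputs are built, while the embedding-level dropout runs in base.py, gated on self.training and the configured token_dropout_rvq_levels range.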