From 5479d2eacc0a869dc58472d1cac4cc7ffd8ed831 Mon Sep 17 00:00:00 2001
From: mrq
Date: Tue, 18 Mar 2025 19:34:37 -0500
Subject: [PATCH] more tweaks to the new implementation (properly trim the len
 stuff to save some params, decoder to d_ffn expansion to 2 to maybe also make
 it faster, etc.)

---
 vall_e/config.py            | 15 +++---
 vall_e/models/arch/llama.py | 32 ++++++------
 vall_e/models/base_v2.py    | 97 ++++++++++++++++++++-----------------
 vall_e/webui.py             |  2 +-
 4 files changed, 79 insertions(+), 67 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index 951214a..743a03e 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -299,9 +299,11 @@ class ModelExperimentalSettings:
     len_parallel_training: bool = True # used for version >= 7, computes len loss alongside normal training through using the input sequence (surely nothing can go wrong)
     len_loss_factor: float = 0.00001 # loss factor for len calculation, very small because it mucks up loss scaling under float16
     parallel_attention_mask_dropout: float = 0.0 # randomly sets to a causal attention mask when training NAR-len demasking
+    layer_dropout_p: float = 0.0 # performs layer dropout, which I readded because it might actually help since the reference model had this at 0.1
+    codebook_dropout_p: float = 0.0 # perform codebook dropout, which I added as an explicit feature since it seems the reference model had this in a way
+    # although I don't know how I could easily implement this for the new implementation

-    #
-    logit_normalization: float = 0 # performs logit normalization against the norms per the paper (https://arxiv.org/abs/2205.09310) per https://arxiv.org/abs/2406.05298
+    logit_normalization: float = 0 # performs logit normalization against the norms per the paper (https://arxiv.org/abs/2205.09310) per https://arxiv.org/abs/2406.05298 (this actually is very bad)
     per_level_normalization: bool = True # moves the final norm out from the underlying model into the decoder
     audio_level_loss_factors: list[float] | str = "auto" # the loss factors per-level when training
     # "auto" will pick best for codec
@@ -313,16 +315,17 @@ class ModelExperimentalSettings:

     # this is a flag since I am cautious
     use_streamlined_calc_loss: bool = False # explicitly request the faster pathway for loss calc, in case doing loss one by one instead of one batch is a bottleneck

-    # these technically should be as hyperparameters
     # performs token dropout to compensate for errors
+    # currently unused, since this might be the wrong way to go about it
     token_dropout_error: float = 0.0 # probability to nudge a token by ±1
     token_dropout_rate: float = 0.0 # probability to randomly set a token to a special dropout value
     token_dropout_rvq_levels: list = field(default_factory=lambda: [1,8]) # determines which levels to do dropout, by default do not do dropout on RVQ level 0

-    # these technically should be as hyperparameters
+    # classifier-free guidance training settings
+    # too large might actually be a problem
     cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training
-    cfg_text_dropout_p: float = 0.0 # 0.0 # probability to drop out input audio prompt during training
-    cfg_prom_dropout_p: float = 0.0 # 0.3 # probability to drop out input audio prompt during training
+    cfg_text_dropout_p: float = 0.0 # 0.0 # probability to drop out input audio prompt during training
+    cfg_prom_dropout_p: float = 0.0 # 0.3 # probability to drop out input audio prompt during training
     use_raw_text_p: float = 0.0 # probability to use raw text as the input prompt instead

diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index 8a17517..821fdc4 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -25,6 +25,7 @@ class Config(BaseConfig):
         attn_mode = "sdpa",
         output_norm = True,
         causal = True,
+        layer_dropout = 0.0,
         *args, **kwargs
     ):
         super().__init__(*args, **kwargs)
@@ -32,6 +33,7 @@ class Config(BaseConfig):
         self.attn_mode = attn_mode
         self.output_norm = output_norm
         self.causal = causal
+        self.layer_dropout = layer_dropout

 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
@@ -156,18 +158,10 @@ class Attention(nn.Module):
         elif self.attn_mode == "sdpa":
             self.attn_mode = torch.nn.attention.SDPBackend.MATH

-        self.q_proj = nn.Linear(
-            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
-        )
-        self.k_proj = nn.Linear(
-            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
-        )
-        self.v_proj = nn.Linear(
-            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
-        )
-        self.o_proj = nn.Linear(
-            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
-        )
+        self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias )
+        self.k_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias )
+        self.v_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias )
+        self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias )

     # extracts inputs from a batch based on requested causality
     def split_forward(
@@ -576,6 +570,10 @@
         self.rotary_emb = RotaryEmbedding(config=config)
         self.gradient_checkpointing = False

+        # setup layer dropout LUT
+        LN_2 = 0.69314718056
+        self.layer_dropouts = [ (math.exp((l * LN_2) / (self.layers_n - 1)) - 1) * self.config.layer_dropout for l in range(self.layers_n) ]
+
         # Initialize weights and apply final processing
         self.post_init()
@@ -725,6 +723,9 @@
         return causal_mask

+    def dropout_layer( self, l ):
+        return random.random() < self.layer_dropouts[l] if self.training else False
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -829,10 +830,11 @@
                 position_embeddings=position_embeddings,
             )

-            hidden_states = layer_outputs[0]
+            if not self.dropout_layer( l ):
+                hidden_states = layer_outputs[0]

-            if use_cache:
-                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+                if use_cache:
+                    next_decoder_cache = layer_outputs[2 if output_attentions else 1]

             if output_attentions:
                 all_self_attns += (layer_outputs[1],)

diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index 5cd90fe..55a2eb9 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -89,20 +89,25 @@ class FiniteAudioEncoder(nn.Module):
         n_tokens: int,
         n_levels: int,
         token_dim: int,
-        use_ln: bool = True,
-        use_ffn: bool = True,
-        training: bool = True,
+        use_ln: bool = True, # whether to perform a post-embedding pre-norm or not (I'm not sure if this is redundant)
+        use_ffn: bool = True, # whether to employ a residual feed forward network or not
+
+        d_model: int = None,
+        d_ffn: int = 2, # feed forward expansion value
     ):
         super().__init__()
+
+        if not d_model:
+            d_model = token_dim
+
         self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
         self.pos_embedding = nn.Parameter(torch.randn(1, n_levels, token_dim) * 0.02)
         self.norm = nn.LayerNorm(token_dim) if use_ln else nn.Identity()
         self.proj = nn.Sequential(
-            nn.Linear(token_dim, token_dim * 2),
+            nn.Linear(token_dim, token_dim * d_ffn),
             nn.GELU(),
-            nn.Linear(token_dim * 2, token_dim),
-            #nn.Dropout(0.1 if training else 0.0)
-        ) if use_ffn else nn.Linear(token_dim, token_dim)
+            nn.Linear(token_dim * d_ffn, d_model),
+        ) if use_ffn else nn.Linear(token_dim, d_model)

         self.level_weights = nn.Parameter(torch.ones(n_levels) / math.sqrt(n_levels))
@@ -151,35 +156,50 @@ class FiniteAudioDecoder(nn.Module):
         d_model: int,
         vocab_size: int,
         n_levels: int,
-        d_ffn: int = 4,
-        use_ln: bool = True,
-        shared_levels: bool = False,
-        training: bool = False,
+
+        d_ffn: int = 2, # feed forward expansion value
+        use_ln: bool = True, # perform layer normalization here
+        use_ffn: bool = True, # use a feed forward network post-norm pre-classifier
+        shared_levels: bool = False, # whether to have one set of weights for all codebook levels, or separate weights for each layer
     ):
         super().__init__()
         self.n_levels = n_levels
         self.shared_levels = shared_levels

-        if not shared_levels:
-            self.head = nn.ModuleList([nn.Sequential(
-                # ln
-                (nn.LayerNorm(d_model) if use_ln else nn.Identity()),
-                # ffn
-                nn.Linear(d_model, d_ffn * d_model),
-                nn.GELU(),
-                nn.Linear(d_ffn * d_model, d_model),
-                # head
-                nn.Linear(d_model, vocab_size)
-            ) for _ in range(n_levels)])
+        if use_ffn:
+            if not shared_levels:
+                self.head = nn.ModuleList([nn.Sequential(
+                    # ln
+                    (nn.LayerNorm(d_model) if use_ln else nn.Identity()),
+                    # ffn
+                    nn.Linear(d_model, d_ffn * d_model),
+                    nn.GELU(),
+                    nn.Linear(d_ffn * d_model, d_model),
+                    # head
+                    nn.Linear(d_model, vocab_size)
+                ) for _ in range(n_levels)])
+            else:
+                self.head = nn.Sequential(
+                    # ffn
+                    nn.Linear(d_model, d_ffn * d_model),
+                    nn.GELU(),
+                    nn.Linear(d_ffn * d_model, d_model),
+                    # head
+                    nn.Linear(d_model, vocab_size * n_levels)
+                )
         else:
-            self.head = nn.Sequential(
-                # ffn
-                nn.Linear(d_model, d_ffn * d_model),
-                nn.GELU(),
-                nn.Linear(d_ffn * d_model, d_model),
-                # head
-                nn.Linear(d_model, vocab_size * n_levels)
-            )
+            if not shared_levels:
+                self.head = nn.ModuleList([nn.Sequential(
+                    # ln
+                    (nn.LayerNorm(d_model) if use_ln else nn.Identity()),
+                    # head
+                    nn.Linear(d_model, vocab_size)
+                ) for _ in range(n_levels)])
+            else:
+                self.head = nn.Sequential(
+                    # head
+                    nn.Linear(d_model, vocab_size * n_levels)
+                )

     def forward(self, x: Tensor) -> Tensor:
         batch_size, seq_len, _ = x.shape
@@ -376,50 +396,37 @@ class Base_V2(nn.Module):
         self.langs_emb = ml.Embedding(n_langs, d_model) if n_langs > 0 else None
         self.tasks_emb = ml.Embedding(n_tasks, d_model) if n_tasks > 0 else None
         self.tones_emb = ml.Embedding(n_tones, d_model) if n_tones > 0 else None
-        self.len_emb = ml.Embedding(11, d_model) # unused
+        self.len_emb = None # ml.Embedding(11, d_model)

         self.audio_emb = None
         self.proms_emb = None
         self.resps_emb = None

-        """
-        if n_audio_tokens == 1000 or:
-            AudioEncoder = FiniteAudioEncoder
-            AudioDecoder = FiniteAudioDecoder
-        else:
-            AudioEncoder = ResidualAudioEncoder
-            AudioDecoder = ResidualAudioDecoder
-        """
-
         if monolithic_audio_encoder:
             self.audio_emb = AudioEncoder(
                 n_tokens=n_audio_tokens + 2, # stop + masked token
                 n_levels=self.n_resp_levels,
                 token_dim=d_model,
-                training=training,
             )
         else:
             self.proms_emb = AudioEncoder(
                 n_tokens=n_audio_tokens,
                 n_levels=self.n_resp_levels,
                 token_dim=d_model,
-                training=training,
             )
             self.resps_emb = AudioEncoder(
                 n_tokens=n_audio_tokens + 2, # stop + masked token
                 n_levels=self.n_resp_levels,
                 token_dim=d_model,
-                training=training,
             )

         self.audio_decoder = AudioDecoder(
             d_model,
             (n_audio_tokens + 1),
             self.n_resp_levels,
-            training=training,
             use_ln=per_level_normalization,
         )
-        self.len_decoder = AuxDecoder( d_model, 11 ) # to-do: adjust this
+        self.len_decoder = AuxDecoder( d_model, 1 )
         self.phn_decoder = AuxDecoder( d_model, n_phn_tokens )
         self.text_decoder = AuxDecoder( d_model, n_text_tokens )

diff --git a/vall_e/webui.py b/vall_e/webui.py
index ba83886..8b730e8 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -85,7 +85,7 @@ def gradio_wrapper(inputs):
     return decorated

 # returns a list of models, assuming the models are placed under ./training/ or ./models/ or ./data/models/
-def get_model_paths( paths=[Path("./training/"), Path("./models/"), Path("./data/models/")] ):
+def get_model_paths(paths=[Path("./training/"), Path("./models/"), Path("./data/models/")] ):
     configs = []

     for path in paths:
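
Note on the layer-dropout LUT added in vall_e/models/arch/llama.py: since exp(l * ln 2 / (layers_n - 1)) equals 2 ** (l / (layers_n - 1)), the skip probability ramps exponentially from 0 at the first layer up to the full layer_dropout rate at the last layer. Below is a minimal standalone sketch of that schedule; the 12-layer depth and 0.1 rate are illustrative assumptions, not values taken from the patch.

import math
import random

layers_n = 12         # hypothetical depth, for illustration only
layer_dropout = 0.1   # the rate the reference model reportedly used

LN_2 = 0.69314718056  # ln(2), so exp(l * LN_2 / (layers_n - 1)) == 2 ** (l / (layers_n - 1))
layer_dropouts = [ (math.exp((l * LN_2) / (layers_n - 1)) - 1) * layer_dropout for l in range(layers_n) ]

def dropout_layer( l, training=True ):
    # mirrors Model.dropout_layer in the patch: layers are only ever skipped while training
    return random.random() < layer_dropouts[l] if training else False

for l in range(layers_n):
    # layer 0 is never dropped (2**0 - 1 == 0); the final layer is dropped at the full rate
    print(f"layer {l:2d}: p(skip) = {layer_dropouts[l]:.4f}")

This also clarifies why the forward loop now guards both the hidden_states assignment and the cache update under the same check: a skipped layer leaves hidden_states untouched instead of overwriting it with that layer's output.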