From 814146a5e0d812e76ed87d776c57952d402edb1c Mon Sep 17 00:00:00 2001
From: mrq
Date: Sat, 12 Apr 2025 12:53:44 -0500
Subject: [PATCH] more settings bloat because there seems to be instability
 with the encoder as-is

---
 vall_e/config.py         |  3 +++
 vall_e/models/base_v2.py | 56 ++++++++++++++++++++++++++--------------
 2 files changed, 40 insertions(+), 19 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index b4fbd38..d7fc2c3 100644
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -317,6 +317,9 @@ class ModelExperimentalSettings:
 	use_sliding_attention_mask: bool = False # when used with above, applies a sliding mask within the current segment
 	# this is a flag since I am cautious
 	use_streamlined_calc_loss: bool = False # explicitly request the faster pathway for loss calc, in case doing loss one by one instead of one batch is a bottleneck
+	use_audio_encoder_level_weights: bool = True # flag to maintain backwards compat
+	use_audio_encoder_ffn: bool = True #
+	use_audio_encoder_norm: bool = True #
 
 	audio_decoder_ffn_expansion_size: int = 2 # need to do something awful with this
 	audio_encoder_ffn_expansion_size: int = 2 # need to do something awful with this
diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index 7b3611f..66d387e 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -89,8 +89,10 @@ class FiniteAudioEncoder(nn.Module):
 		n_tokens: int,
 		n_levels: int,
 		token_dim: int,
+		monolithic: bool = False,
 		use_ln: bool = True, # whether to perform a post-embedding pre-norm or not (I'm not sure if this is redundant)
 		use_ffn: bool = True, # whether to employ a residual feed forward network or not
+		use_level_weights: bool = False,
 		d_model: int = None,
 		d_ffn: int = 2, # feed forward expansion value
@@ -101,8 +103,12 @@
 			d_model = token_dim
 
 		self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
-		self.pos_embedding = nn.Parameter(torch.randn(1, n_levels, token_dim) * 0.02)
+
+		# there needs to be some information when separating between the proms and the resps
+		self.pos_embedding = nn.Parameter(torch.randn(2 if monolithic else 1, n_levels, token_dim) * 0.02)
+
 		self.norm = nn.LayerNorm(token_dim) if use_ln else nn.Identity()
+
 		if use_ffn:
 			self.proj = nn.Sequential(
 				nn.Linear(token_dim, token_dim * d_ffn),
@@ -114,16 +120,9 @@
 		else:
 			self.proj = nn.Identity()
 
-		self.level_weights = nn.Parameter(torch.ones(n_levels) / math.sqrt(n_levels))
+		self.level_weights = nn.Parameter(torch.ones(n_levels) / math.sqrt(n_levels)) if use_level_weights else None
 		self.use_ffn = use_ffn
 
-		# explicit initialization
-		# this is actually BAD BAD BAD
-		"""
-		for emb in self.embs:
-			torch.nn.init.normal_(emb.weight, mean=0.0, std=0.02)
-		"""
-
 		if use_ffn:
 			nn.init.xavier_uniform_(self.proj[0].weight)
 			nn.init.xavier_uniform_(self.proj[2].weight)
@@ -134,7 +133,7 @@
 			nn.init.xavier_uniform_(self.proj.weight)
 			nn.init.zeros_(self.proj.bias)
 
-	def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None, stability = None ) -> Tensor:
+	def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None, stability = None, mode = None ) -> Tensor:
 		# empty
 		if xi.shape[0] == 0:
 			dim = self.embs[0].weight.shape[-1] # self.proj.weight.shape[0]
@@ -147,14 +146,23 @@
 			stability = xi.dtype == torch.bfloat16
 
 		x = torch.stack([ emb(xi[:, i]) for i, emb in enumerate(self.embs) ], dim=1)
-		x = x + self.pos_embedding
+
+		if mode == "prom":
+			x = x + self.pos_embedding[0].unsqueeze(0)
+		elif mode == "resp":
+			x = x + self.pos_embedding[1].unsqueeze(0)
+		else:
+			x = x + self.pos_embedding
+
 		x = self.norm(x)
 		if self.use_ffn:
 			x = x + self.proj( x )
 		else:
 			x = self.proj( x )
 
-		if stability:
+		if self.level_weights is None:
+			x = x.sum(dim=1)
+		elif stability:
 			weights = F.softmax(self.level_weights.float(), dim=0).view(1, -1, 1)
 			x = (x.float() * weights).sum(dim=1).to(xi.dtype)
 		else:
@@ -313,6 +321,9 @@ class Base_V2(nn.Module):
 		per_level_normalization = config.experimental.per_level_normalization if config is not None else True
 		audio_decoder_ffn_expansion_size = config.experimental.audio_decoder_ffn_expansion_size if config is not None else 2
 		audio_encoder_ffn_expansion_size = config.experimental.audio_encoder_ffn_expansion_size if config is not None else 2
+		use_audio_encoder_ffn = config.experimental.use_audio_encoder_ffn if config is not None else True
+		use_audio_encoder_norm = config.experimental.use_audio_encoder_norm if config is not None else True
+		use_audio_encoder_level_weights = config.experimental.use_audio_encoder_level_weights if config is not None else True
 		use_segmented_attention_mask = config.experimental.use_segmented_attention_mask if config is not None else True
 		use_sliding_attention_mask = config.experimental.use_sliding_attention_mask if config is not None else True
 		parallel_attention_mask_dropout = config.experimental.parallel_attention_mask_dropout if config is not None else 0.0
@@ -430,20 +441,33 @@
 				n_tokens=n_audio_tokens + 2, # stop + masked token
 				n_levels=self.n_resp_levels,
 				token_dim=d_model,
+				monolithic=True,
 				d_ffn=audio_encoder_ffn_expansion_size,
+				use_ln=use_audio_encoder_norm,
+				use_ffn=use_audio_encoder_ffn,
+				use_level_weights=use_audio_encoder_level_weights,
 			)
+
+			self.proms_emb = lambda *args, **kwargs: self.audio_emb( *args, **kwargs, mode="prom" )
+			self.resps_emb = lambda *args, **kwargs: self.audio_emb( *args, **kwargs, mode="resp" )
 		else:
 			self.proms_emb = AudioEncoder(
 				n_tokens=n_audio_tokens,
 				n_levels=self.n_resp_levels,
 				token_dim=d_model,
 				d_ffn=audio_encoder_ffn_expansion_size,
+				use_ln=use_audio_encoder_norm,
+				use_ffn=use_audio_encoder_ffn,
+				use_level_weights=use_audio_encoder_level_weights,
 			)
 			self.resps_emb = AudioEncoder(
 				n_tokens=n_audio_tokens + 2, # stop + masked token
 				n_levels=self.n_resp_levels,
 				token_dim=d_model,
 				d_ffn=audio_encoder_ffn_expansion_size,
+				use_ln=use_audio_encoder_norm,
+				use_ffn=use_audio_encoder_ffn,
+				use_level_weights=use_audio_encoder_level_weights,
 			)
 
 		self.audio_decoder = AudioDecoder(
@@ -721,9 +745,6 @@
 			if isinstance(input, str):
 				return self.tasks_emb( torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16) )
 
-			if self.audio_emb is not None:
-				return self.audio_emb( input )
-
 			return self.proms_emb( input )
 
 		x_list = []
@@ -774,10 +795,7 @@
 			elif name == "tone" and self.tones_emb is not None:
 				embedding = self.tones_emb( input )
 			elif name == "resp":
-				if self.audio_emb is not None:
-					embedding = self.audio_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token )
-				else:
-					embedding = self.resps_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token )
+				embedding = self.resps_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token )
 			elif name == "timestep" and self.time_emb is not None:
 				embedding = self.time_emb( input )
 			elif name == "len" and self.len_emb is not None:
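
For reference when reading the change without the repo at hand, a minimal
standalone sketch of what the patched FiniteAudioEncoder computes. The class
name SketchAudioEncoder is hypothetical, plain nn.Embedding stands in for the
repo's ml.Embedding wrapper, nn.GELU() is assumed as the activation between
the two projection linears (only indices 0 and 2 of self.proj appear in the
hunks), and the dropout/stability paths are omitted:

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class SketchAudioEncoder(nn.Module):
	def __init__(self, n_tokens=1026, n_levels=8, token_dim=512,
			monolithic=False, use_ln=True, use_ffn=True,
			use_level_weights=False, d_ffn=2):
		super().__init__()
		# one embedding table per codebook level
		self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
		# monolithic mode keeps two positional tables: row 0 for proms, row 1 for resps
		self.pos_embedding = nn.Parameter(torch.randn(2 if monolithic else 1, n_levels, token_dim) * 0.02)
		self.norm = nn.LayerNorm(token_dim) if use_ln else nn.Identity()
		self.use_ffn = use_ffn
		self.proj = nn.Sequential(
			nn.Linear(token_dim, token_dim * d_ffn),
			nn.GELU(), # assumption: activation not shown in the hunks
			nn.Linear(token_dim * d_ffn, token_dim),
		) if use_ffn else nn.Identity()
		# level weights are now optional; None selects a plain sum over levels
		self.level_weights = nn.Parameter(torch.ones(n_levels) / math.sqrt(n_levels)) if use_level_weights else None

	def forward(self, xi, mode=None):
		# xi: [seq_len, n_levels] of codebook indices
		x = torch.stack([emb(xi[:, i]) for i, emb in enumerate(self.embs)], dim=1)
		# pick the positional table by role when one encoder serves both streams;
		# with monolithic=True a mode must be given, else the [2, ...] table
		# fails to broadcast against [seq_len, ...]
		if mode == "prom":
			x = x + self.pos_embedding[0].unsqueeze(0)
		elif mode == "resp":
			x = x + self.pos_embedding[1].unsqueeze(0)
		else:
			x = x + self.pos_embedding
		x = self.norm(x)
		if self.use_ffn:
			x = x + self.proj(x) # residual FFN
		else:
			x = self.proj(x)
		# collapse the level axis: plain sum, or softmax-weighted sum
		if self.level_weights is None:
			return x.sum(dim=1)
		weights = F.softmax(self.level_weights, dim=0).view(1, -1, 1)
		return (x * weights).sum(dim=1)

enc = SketchAudioEncoder(monolithic=True, use_level_weights=True)
codes = torch.randint(0, 1024, (100, 8)) # 100 frames, 8 codebook levels
proms = enc(codes, mode="prom")          # [100, 512]
resps = enc(codes, mode="resp")          # [100, 512]

One side note on the monolithic path: proms_emb/resps_emb become plain lambdas
rather than nn.Modules. Parameter registration is unaffected since audio_emb
itself is a registered submodule, and state_dict saving works as usual, but
pickling the whole module object (torch.save(model) rather than
torch.save(model.state_dict())) would fail on the lambda attributes.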
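The three new flags are plumbed through ModelExperimentalSettings, so an
ablation can be toggled from the config rather than by editing the model. A
sketch of constructing the settings directly, assuming the settings class is
a dataclass with defaults throughout, as its annotated fields suggest:

from vall_e.config import ModelExperimentalSettings

# hypothetical ablation: drop the softmax level weights (encoder falls back
# to a plain sum over levels) while keeping the LayerNorm and residual FFN
settings = ModelExperimentalSettings(
	use_audio_encoder_level_weights=False,
	use_audio_encoder_ffn=True,
	use_audio_encoder_norm=True,
)

Note that disabling use_audio_encoder_level_weights removes the level_weights
parameter entirely, so checkpoints trained with it are not drop-in loadable
under the new setting; presumably this is why all three flags default to True
for backwards compatibility.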