more settings bloat because there seems to be instability with the encoder as-is

mrq 2025-04-12 12:53:44 -05:00
parent f144389920
commit 814146a5e0
2 changed files with 40 additions and 19 deletions


@@ -317,6 +317,9 @@ class ModelExperimentalSettings:
use_sliding_attention_mask: bool = False # when used with above, applies a sliding mask within the current segment
# this is a flag since I am cautious
use_streamlined_calc_loss: bool = False # explicitly request the faster pathway for loss calc, in case computing the loss one sequence at a time instead of in one batch is a bottleneck
use_audio_encoder_level_weights: bool = True # flag to maintain backwards compat
use_audio_encoder_ffn: bool = True # whether the audio encoder applies its residual feed-forward network (flag to maintain backwards compat)
use_audio_encoder_norm: bool = True # whether the audio encoder applies its post-embedding LayerNorm (flag to maintain backwards compat)
audio_decoder_ffn_expansion_size: int = 2 # FFN expansion factor for the audio decoder (need to do something awful with this)
audio_encoder_ffn_expansion_size: int = 2 # FFN expansion factor for the audio encoder (need to do something awful with this)
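
For reference, a minimal sketch of toggling the new knobs, assuming ModelExperimentalSettings is a plain dataclass that accepts these fields as keyword arguments; the field names come from the hunk above, the construction itself is illustrative:

    # hypothetical override: strip the encoder back to bare embeddings
    # while chasing down the instability mentioned in the commit message
    settings = ModelExperimentalSettings(
        use_audio_encoder_ffn=False,            # drop the residual feed-forward network
        use_audio_encoder_norm=False,           # drop the post-embedding LayerNorm
        use_audio_encoder_level_weights=False,  # plain sum over quant levels instead
        audio_encoder_ffn_expansion_size=2,     # FFN hidden width = token_dim * 2
    )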


@@ -89,8 +89,10 @@ class FiniteAudioEncoder(nn.Module):
n_tokens: int,
n_levels: int,
token_dim: int,
monolithic: bool = False,
use_ln: bool = True, # whether to perform a post-embedding pre-norm or not (I'm not sure if this is redundant)
use_ffn: bool = True, # whether to employ a residual feed forward network or not
use_level_weights: bool = False,
d_model: int = None,
d_ffn: int = 2, # feed forward expansion value
@@ -101,8 +103,12 @@ class FiniteAudioEncoder(nn.Module):
d_model = token_dim
self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
self.pos_embedding = nn.Parameter(torch.randn(1, n_levels, token_dim) * 0.02)
# some positional information is needed to separate the proms from the resps when they share one encoder
self.pos_embedding = nn.Parameter(torch.randn(2 if monolithic else 1, n_levels, token_dim) * 0.02)
self.norm = nn.LayerNorm(token_dim) if use_ln else nn.Identity()
if use_ffn:
self.proj = nn.Sequential(
nn.Linear(token_dim, token_dim * d_ffn),
@@ -114,16 +120,9 @@ class FiniteAudioEncoder(nn.Module):
else:
self.proj = nn.Identity()
self.level_weights = nn.Parameter(torch.ones(n_levels) / math.sqrt(n_levels))
self.level_weights = nn.Parameter(torch.ones(n_levels) / math.sqrt(n_levels)) if use_level_weights else None
self.use_ffn = use_ffn
# explicit initialization
# this is actually BAD BAD BAD
"""
for emb in self.embs:
torch.nn.init.normal_(emb.weight, mean=0.0, std=0.02)
"""
if use_ffn:
nn.init.xavier_uniform_(self.proj[0].weight)
nn.init.xavier_uniform_(self.proj[2].weight)
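
The removed block re-initialized every embedding table with a normal distribution (the "BAD BAD BAD" note above), while Xavier initialization survives for the FFN's two Linear layers at indices 0 and 2 of the Sequential. A standalone sketch of that surviving path, with illustrative sizes and the activation at index 1 assumed, since the hunk truncates before it:

    import torch.nn as nn

    token_dim, d_ffn = 1024, 2  # illustrative sizes
    proj = nn.Sequential(
        nn.Linear(token_dim, token_dim * d_ffn),
        nn.GELU(),  # assumption: some activation sits between the two Linears
        nn.Linear(token_dim * d_ffn, token_dim),
    )
    nn.init.xavier_uniform_(proj[0].weight)  # only the Linear weights are touched
    nn.init.xavier_uniform_(proj[2].weight)  # biases keep their default init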
@@ -134,7 +133,7 @@ class FiniteAudioEncoder(nn.Module):
nn.init.xavier_uniform_(self.proj.weight)
nn.init.zeros_(self.proj.bias)
def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None, stability = None ) -> Tensor:
def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None, stability = None, mode = None ) -> Tensor:
# empty
if xi.shape[0] == 0:
dim = self.embs[0].weight.shape[-1] # self.proj.weight.shape[0]
@@ -147,14 +146,23 @@ class FiniteAudioEncoder(nn.Module):
stability = xi.dtype == torch.bfloat16
x = torch.stack([ emb(xi[:, i]) for i, emb in enumerate(self.embs) ], dim=1)
x = x + self.pos_embedding
if mode == "prom":
x = x + self.pos_embedding[0].unsqueeze(0)
elif mode == "resp":
x = x + self.pos_embedding[1].unsqueeze(0)
else:
x = x + self.pos_embedding
x = self.norm(x)
if self.use_ffn:
x = x + self.proj( x )
else:
x = self.proj( x )
if stability:
if self.level_weights is None:
x = x.sum(dim=1)
elif stability:
weights = F.softmax(self.level_weights.float(), dim=0).view(1, -1, 1)
x = (x.float() * weights).sum(dim=1).to(xi.dtype)
else:
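
Two behaviors land in the forward pass here: a per-mode row of the positional embedding (row 0 for proms, row 1 for resps, falling back to the whole tensor when no mode is given), and a plain sum over levels when level weighting is disabled. A self-contained sketch with illustrative shapes; none of the names below are the repo's API:

    import torch
    import torch.nn.functional as F

    n_levels, token_dim = 8, 1024
    x = torch.randn(4, n_levels, token_dim)           # (batch, levels, dim)
    pos = torch.randn(2, n_levels, token_dim) * 0.02  # row 0: proms, row 1: resps

    mode = "prom"
    x = x + (pos[0] if mode == "prom" else pos[1]).unsqueeze(0)

    # with use_level_weights=False the reduction is just x.sum(dim=1);
    # otherwise the stability path upcasts to float and softmax-weights the levels
    level_weights = torch.ones(n_levels) / n_levels ** 0.5
    w = F.softmax(level_weights.float(), dim=0).view(1, -1, 1)
    y = (x.float() * w).sum(dim=1).to(x.dtype)        # (batch, dim)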
@@ -313,6 +321,9 @@ class Base_V2(nn.Module):
per_level_normalization = config.experimental.per_level_normalization if config is not None else True
audio_decoder_ffn_expansion_size = config.experimental.audio_decoder_ffn_expansion_size if config is not None else 2
audio_encoder_ffn_expansion_size = config.experimental.audio_encoder_ffn_expansion_size if config is not None else 2
use_audio_encoder_ffn = config.experimental.use_audio_encoder_ffn if config is not None else True
use_audio_encoder_norm = config.experimental.use_audio_encoder_norm if config is not None else True
use_audio_encoder_level_weights = config.experimental.use_audio_encoder_level_weights if config is not None else True
use_segmented_attention_mask = config.experimental.use_segmented_attention_mask if config is not None else True
use_sliding_attention_mask = config.experimental.use_sliding_attention_mask if config is not None else True
parallel_attention_mask_dropout = config.experimental.parallel_attention_mask_dropout if config is not None else 0.0
@@ -430,20 +441,33 @@ class Base_V2(nn.Module):
n_tokens=n_audio_tokens + 2, # stop + masked token
n_levels=self.n_resp_levels,
token_dim=d_model,
monolithic=True,
d_ffn=audio_encoder_ffn_expansion_size,
use_ln=use_audio_encoder_norm,
use_ffn=use_audio_encoder_ffn,
use_level_weights=use_audio_encoder_level_weights,
)
self.proms_emb = lambda *args, **kwargs: self.audio_emb( *args, **kwargs, mode="prom" )
self.resps_emb = lambda *args, **kwargs: self.audio_emb( *args, **kwargs, mode="resp" )
else:
self.proms_emb = AudioEncoder(
n_tokens=n_audio_tokens,
n_levels=self.n_resp_levels,
token_dim=d_model,
d_ffn=audio_encoder_ffn_expansion_size,
use_ln=use_audio_encoder_norm,
use_ffn=use_audio_encoder_ffn,
use_level_weights=use_audio_encoder_level_weights,
)
self.resps_emb = AudioEncoder(
n_tokens=n_audio_tokens + 2, # stop + masked token
n_levels=self.n_resp_levels,
token_dim=d_model,
d_ffn=audio_encoder_ffn_expansion_size,
use_ln=use_audio_encoder_norm,
use_ffn=use_audio_encoder_ffn,
use_level_weights=use_audio_encoder_level_weights,
)
self.audio_decoder = AudioDecoder(
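
The proms_emb/resps_emb lambdas above are plain partial application over the one shared encoder; an equivalent, purely illustrative spelling with functools.partial, assuming audio_emb is the FiniteAudioEncoder built above:

    from functools import partial

    # both views reuse the same weights and differ only in which
    # positional-embedding row forward() selects via `mode`
    proms_emb = partial(audio_emb, mode="prom")
    resps_emb = partial(audio_emb, mode="resp")

Since lambdas (and partials) are not nn.Modules, these names register no parameters of their own; the weights live solely under audio_emb.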
@@ -721,9 +745,6 @@ class Base_V2(nn.Module):
if isinstance(input, str):
return self.tasks_emb( torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16) )
if self.audio_emb is not None:
return self.audio_emb( input )
return self.proms_emb( input )
x_list = []
@@ -774,10 +795,7 @@ class Base_V2(nn.Module):
elif name == "tone" and self.tones_emb is not None:
embedding = self.tones_emb( input )
elif name == "resp":
if self.audio_emb is not None:
embedding = self.audio_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token )
else:
embedding = self.resps_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token )
embedding = self.resps_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token )
elif name == "timestep" and self.time_emb is not None:
embedding = self.time_emb( input )
elif name == "len" and self.len_emb is not None: