more settings bloat because there seems to be instability with the encoder as-is
This commit is contained in:
parent f144389920
commit 814146a5e0
@@ -317,6 +317,9 @@ class ModelExperimentalSettings:
 	use_sliding_attention_mask: bool = False # when used with above, applies a sliding mask within the current segment
 	# this is a flag since I am cautious
 	use_streamlined_calc_loss: bool = False # explicitly request the faster pathway for loss calc, in case doing loss one by one instead of one batch is a bottleneck
+	use_audio_encoder_level_weights: bool = True # flag to maintain backwards compat
+	use_audio_encoder_ffn: bool = True # flag to maintain backwards compat
+	use_audio_encoder_norm: bool = True # flag to maintain backwards compat
 	audio_decoder_ffn_expansion_size: int = 2 # need to do something awful with this
 	audio_encoder_ffn_expansion_size: int = 2 # need to do something awful with this
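The three new flags gate sub-modules of FiniteAudioEncoder, with the old behavior kept as the default. A minimal sketch of toggling them, assuming the settings dataclass is instantiated directly (the import path is an assumption; the field names are from the hunk above):

	# assumed import path; adjust to where ModelExperimentalSettings actually lives
	from vall_e.config import ModelExperimentalSettings

	settings = ModelExperimentalSettings(
		use_audio_encoder_level_weights=True,  # keep per-level softmax weights (old behavior)
		use_audio_encoder_ffn=True,            # keep the residual FFN after the embeddings
		use_audio_encoder_norm=True,           # keep the post-embedding LayerNorm
	)
	# setting any flag to False strips that piece out of the encoder, which is
	# the point of the extra settings: isolating what causes the instability
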
@@ -89,8 +89,10 @@ class FiniteAudioEncoder(nn.Module):
 		n_tokens: int,
 		n_levels: int,
 		token_dim: int,
+		monolithic: bool = False,
 		use_ln: bool = True, # whether to perform a post-embedding pre-norm or not (I'm not sure if this is redundant)
 		use_ffn: bool = True, # whether to employ a residual feed forward network or not
+		use_level_weights: bool = False,

 		d_model: int = None,
 		d_ffn: int = 2, # feed forward expansion value
@@ -101,8 +103,12 @@ class FiniteAudioEncoder(nn.Module):
 			d_model = token_dim

 		self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
-		self.pos_embedding = nn.Parameter(torch.randn(1, n_levels, token_dim) * 0.02)
+
+		# there needs to be some information when separating between the proms and the resps
+		self.pos_embedding = nn.Parameter(torch.randn(2 if monolithic else 1, n_levels, token_dim) * 0.02)
+
 		self.norm = nn.LayerNorm(token_dim) if use_ln else nn.Identity()

 		if use_ffn:
 			self.proj = nn.Sequential(
 				nn.Linear(token_dim, token_dim * d_ffn),
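With `monolithic=True` the shared encoder now owns two positional tables, one per segment type; `forward()` later indexes row 0 for proms and row 1 for resps. A standalone shape sketch (sizes are illustrative, not from the diff):

	import torch
	import torch.nn as nn

	n_levels, token_dim = 8, 1024  # illustrative sizes
	pos_embedding = nn.Parameter(torch.randn(2, n_levels, token_dim) * 0.02)  # monolithic: two tables

	x = torch.randn(4, n_levels, token_dim)     # (seq, levels, dim), matching the stack() in forward()
	x_prom = x + pos_embedding[0].unsqueeze(0)  # (1, n_levels, token_dim) broadcast over seq
	x_resp = x + pos_embedding[1].unsqueeze(0)
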
@@ -114,16 +120,9 @@ class FiniteAudioEncoder(nn.Module):
 		else:
 			self.proj = nn.Identity()

-		self.level_weights = nn.Parameter(torch.ones(n_levels) / math.sqrt(n_levels))
+		self.level_weights = nn.Parameter(torch.ones(n_levels) / math.sqrt(n_levels)) if use_level_weights else None
 		self.use_ffn = use_ffn

-		# explicit initialization
-		# this is actually BAD BAD BAD
-		"""
-		for emb in self.embs:
-			torch.nn.init.normal_(emb.weight, mean=0.0, std=0.02)
-		"""
-
 		if use_ffn:
 			nn.init.xavier_uniform_(self.proj[0].weight)
 			nn.init.xavier_uniform_(self.proj[2].weight)
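One property of the `ones(n_levels) / sqrt(n_levels)` initialization that `use_level_weights` now gates: since the vector is constant, the softmax applied in `forward()` starts out exactly uniform, so at initialization the weighted merge equals a plain average and the 1/sqrt(n) scaling only shapes gradients once training moves the weights. A quick check:

	import math
	import torch
	import torch.nn.functional as F

	n_levels = 8  # illustrative
	level_weights = torch.ones(n_levels) / math.sqrt(n_levels)
	weights = F.softmax(level_weights, dim=0)
	# softmax of any constant vector is uniform: each level starts at 1/n_levels
	assert torch.allclose(weights, torch.full((n_levels,), 1.0 / n_levels))
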
@@ -134,7 +133,7 @@ class FiniteAudioEncoder(nn.Module):
 			nn.init.xavier_uniform_(self.proj.weight)
 			nn.init.zeros_(self.proj.bias)

-	def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None, stability = None ) -> Tensor:
+	def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None, stability = None, mode = None ) -> Tensor:
 		# empty
 		if xi.shape[0] == 0:
 			dim = self.embs[0].weight.shape[-1] # self.proj.weight.shape[0]
@@ -147,14 +146,23 @@ class FiniteAudioEncoder(nn.Module):
 			stability = xi.dtype == torch.bfloat16

 		x = torch.stack([ emb(xi[:, i]) for i, emb in enumerate(self.embs) ], dim=1)
-		x = x + self.pos_embedding
+
+		if mode == "prom":
+			x = x + self.pos_embedding[0].unsqueeze(0)
+		elif mode == "resp":
+			x = x + self.pos_embedding[1].unsqueeze(0)
+		else:
+			x = x + self.pos_embedding
+
 		x = self.norm(x)
 		if self.use_ffn:
 			x = x + self.proj( x )
 		else:
 			x = self.proj( x )

-		if stability:
+		if self.level_weights is None:
+			x = x.sum(dim=1)
+		elif stability:
 			weights = F.softmax(self.level_weights.float(), dim=0).view(1, -1, 1)
 			x = (x.float() * weights).sum(dim=1).to(xi.dtype)
 		else:
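The merge over levels now has three branches: no weights at all (plain sum), a float32 path when bf16 stability is a concern, and the original dtype-native path. A minimal sketch of the logic as a free function; note the final `else:` body is cut off by the hunk, so its reconstruction here is an assumption that it mirrors the stability branch without the float32 round-trip:

	import torch
	import torch.nn.functional as F

	def merge_levels(x, level_weights, stability):
		# x: (seq, n_levels, dim); level_weights: (n_levels,) or None
		if level_weights is None:
			return x.sum(dim=1)  # flat sum when level weights are disabled
		if stability:
			# bf16-safe: softmax and weighted sum in float32, cast back after
			weights = F.softmax(level_weights.float(), dim=0).view(1, -1, 1)
			return (x.float() * weights).sum(dim=1).to(x.dtype)
		# assumed body of the truncated else-branch: same math, native dtype
		weights = F.softmax(level_weights, dim=0).view(1, -1, 1)
		return (x * weights).sum(dim=1)
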
@@ -313,6 +321,9 @@ class Base_V2(nn.Module):
 		per_level_normalization = config.experimental.per_level_normalization if config is not None else True
 		audio_decoder_ffn_expansion_size = config.experimental.audio_decoder_ffn_expansion_size if config is not None else 2
 		audio_encoder_ffn_expansion_size = config.experimental.audio_encoder_ffn_expansion_size if config is not None else 2
+		use_audio_encoder_ffn = config.experimental.use_audio_encoder_ffn if config is not None else True
+		use_audio_encoder_norm = config.experimental.use_audio_encoder_norm if config is not None else True
+		use_audio_encoder_level_weights = config.experimental.use_audio_encoder_level_weights if config is not None else True
 		use_segmented_attention_mask = config.experimental.use_segmented_attention_mask if config is not None else True
 		use_sliding_attention_mask = config.experimental.use_sliding_attention_mask if config is not None else True
 		parallel_attention_mask_dropout = config.experimental.parallel_attention_mask_dropout if config is not None else 0.0
@@ -430,20 +441,33 @@ class Base_V2(nn.Module):
 				n_tokens=n_audio_tokens + 2, # stop + masked token
 				n_levels=self.n_resp_levels,
 				token_dim=d_model,
+				monolithic=True,
+				d_ffn=audio_encoder_ffn_expansion_size,
+				use_ln=use_audio_encoder_norm,
+				use_ffn=use_audio_encoder_ffn,
+				use_level_weights=use_audio_encoder_level_weights,
 			)

+			self.proms_emb = lambda *args, **kwargs: self.audio_emb( *args, **kwargs, mode="prom" )
+			self.resps_emb = lambda *args, **kwargs: self.audio_emb( *args, **kwargs, mode="resp" )
 		else:
 			self.proms_emb = AudioEncoder(
 				n_tokens=n_audio_tokens,
 				n_levels=self.n_resp_levels,
 				token_dim=d_model,
+				d_ffn=audio_encoder_ffn_expansion_size,
+				use_ln=use_audio_encoder_norm,
+				use_ffn=use_audio_encoder_ffn,
+				use_level_weights=use_audio_encoder_level_weights,
 			)
 			self.resps_emb = AudioEncoder(
 				n_tokens=n_audio_tokens + 2, # stop + masked token
 				n_levels=self.n_resp_levels,
 				token_dim=d_model,
+				d_ffn=audio_encoder_ffn_expansion_size,
+				use_ln=use_audio_encoder_norm,
+				use_ffn=use_audio_encoder_ffn,
+				use_level_weights=use_audio_encoder_level_weights,
 			)

 		self.audio_decoder = AudioDecoder(
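In the monolithic case a single shared encoder serves both streams, and the two lambdas just curry `mode` so call sites keep using the `proms_emb` / `resps_emb` names; that is what lets the `audio_emb is not None` special-casing be deleted in the two hunks below. A sketch of the call shape after the change (sizes illustrative; `model` is assumed to be an instantiated Base_V2):

	import torch

	seq_len, n_levels = 500, 8  # illustrative
	codes = torch.randint(0, 1024, (seq_len, n_levels))  # quantized audio codes

	prom = model.proms_emb( codes )  # routes to audio_emb(codes, mode="prom"), positional table row 0
	resp = model.resps_emb( codes, dropout_mask=None, dropout_token=None )  # mode="resp", row 1

One caveat of assigning plain lambdas as attributes: they are not registered sub-modules, which is fine here since `audio_emb` already owns all the parameters.
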
@@ -721,9 +745,6 @@ class Base_V2(nn.Module):
 			if isinstance(input, str):
 				return self.tasks_emb( torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16) )

-			if self.audio_emb is not None:
-				return self.audio_emb( input )
-
 			return self.proms_emb( input )

 		x_list = []
@@ -774,10 +795,7 @@ class Base_V2(nn.Module):
 			elif name == "tone" and self.tones_emb is not None:
 				embedding = self.tones_emb( input )
 			elif name == "resp":
-				if self.audio_emb is not None:
-					embedding = self.audio_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token )
-				else:
-					embedding = self.resps_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token )
+				embedding = self.resps_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token )
 			elif name == "timestep" and self.time_emb is not None:
 				embedding = self.time_emb( input )
 			elif name == "len" and self.len_emb is not None: