more tweaks to the new implementation (properly trim the len stuff to save some params, set the decoder's d_ffn expansion to 2 to maybe also make it faster, etc.)

This commit is contained in:
mrq 2025-03-18 19:34:37 -05:00
parent 9a8a8e3195
commit 5479d2eacc
4 changed files with 79 additions and 67 deletions

View File

@ -299,9 +299,11 @@ class ModelExperimentalSettings:
len_parallel_training: bool = True # used for version >= 7, computes len loss alongside normal training through using the input sequence (surely nothing can go wrong)
len_loss_factor: float = 0.00001 # loss factor for len calculation, very small because it mucks up loss scaling under float16
parallel_attention_mask_dropout: float = 0.0 # randomly sets to a causal attention mask when training NAR-len demasking
layer_dropout_p: float = 0.0 # performs layer dropout, which I readded because it might actually help since the reference model had this at 0.1
codebook_dropout_p: float = 0.0 # perform codebook dropout, which I added as an explicit feature since the reference model seems to have had something like this
# although I don't know how I could easily implement this for the new implementation
#
logit_normalization: float = 0 # performs logit normalization against the logit norms per the paper (https://arxiv.org/abs/2205.09310), as used in https://arxiv.org/abs/2406.05298
logit_normalization: float = 0 # performs logit normalization against the logit norms per the paper (https://arxiv.org/abs/2205.09310), as used in https://arxiv.org/abs/2406.05298 (this actually turned out to be very bad)
per_level_normalization: bool = True # moves the final norm out from the underlying model into the decoder
audio_level_loss_factors: list[float] | str = "auto" # the loss factors per-level when training
# "auto" will pick best for codec
@ -313,16 +315,17 @@ class ModelExperimentalSettings:
# this is a flag since I am cautious
use_streamlined_calc_loss: bool = False # explicitly request the faster pathway for loss calc, in case computing the loss one sample at a time instead of in one batch is a bottleneck
# these should technically be hyperparameters
# performs token dropout to compensate for errors
# currently unused, since this might be the wrong way to go about it
token_dropout_error: float = 0.0 # probability to nudge a token by ±1
token_dropout_rate: float = 0.0 # probability to randomly set a token to a special dropout value
token_dropout_rvq_levels: list = field(default_factory=lambda: [1,8]) # determines which levels to do dropout, by default do not do dropout on RVQ level 0
# these should technically be hyperparameters
# classifier-free guidance training settings
# too large might actually be a problem
cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training
cfg_text_dropout_p: float = 0.0 # 0.0 # probability to drop out input text during training
cfg_prom_dropout_p: float = 0.0 # 0.3 # probability to drop out input audio prompt during training
use_raw_text_p: float = 0.0 # probability to use raw text as the input prompt instead
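
For context on the logit_normalization setting in this hunk: the referenced LogitNorm paper (https://arxiv.org/abs/2205.09310) divides each logit vector by its L2 norm and a temperature before the cross-entropy. How this option is actually wired into the repo's loss isn't shown here, so the snippet below is only an illustrative sketch of the technique, with an assumed temperature value:

import torch
import torch.nn.functional as F

def logit_norm_cross_entropy( logits, targets, tau = 0.04, eps = 1e-7 ):
    # LogitNorm: scale each logit vector to unit L2 norm (tempered by tau) before the CE
    norms = torch.norm( logits, p = 2, dim = -1, keepdim = True ) + eps
    return F.cross_entropy( logits / ( norms * tau ), targets )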

View File

@ -25,6 +25,7 @@ class Config(BaseConfig):
attn_mode = "sdpa",
output_norm = True,
causal = True,
layer_dropout = 0.0,
*args, **kwargs
):
super().__init__(*args, **kwargs)
@ -32,6 +33,7 @@ class Config(BaseConfig):
self.attn_mode = attn_mode
self.output_norm = output_norm
self.causal = causal
self.layer_dropout = layer_dropout
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
@ -156,18 +158,10 @@ class Attention(nn.Module):
elif self.attn_mode == "sdpa":
self.attn_mode = torch.nn.attention.SDPBackend.MATH
self.q_proj = nn.Linear(
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
)
self.k_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.v_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias )
self.k_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias )
self.v_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias )
self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias )
# extracts inputs from a batch based on requested causality
def split_forward(
@ -576,6 +570,10 @@ class Model(LlamaPreTrainedModel):
self.rotary_emb = RotaryEmbedding(config=config)
self.gradient_checkpointing = False
# setup layer dropout LUT
LN_2 = 0.69314718056 # ln(2)
self.layer_dropouts = [ (math.exp((l * LN_2) / (self.layers_n - 1)) - 1) * self.config.layer_dropout for l in range(self.layers_n) ]
# Initialize weights and apply final processing
self.post_init()
@ -725,6 +723,9 @@ class Model(LlamaPreTrainedModel):
return causal_mask
def dropout_layer( self, l ):
return random.random() < self.layer_dropouts[l] if self.training else False
def forward(
self,
input_ids: torch.LongTensor = None,
@ -829,10 +830,11 @@ class Model(LlamaPreTrainedModel):
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
if not self.dropout_layer( l ):
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
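
For reference, the layer dropout LUT added above produces an exponential ramp: layer 0 is never dropped and the final layer is dropped with probability config.layer_dropout. A standalone sketch of the schedule (the layer count and dropout value here are made-up illustration values, not the project's defaults):

import math, random

LN_2 = 0.69314718056 # ln(2)
layers_n = 12 # assumed layer count, for illustration only
layer_dropout = 0.1 # stand-in for config.layer_dropout

# p(l) = (exp(l * ln(2) / (layers_n - 1)) - 1) * layer_dropout, so p(0) = 0 and p(layers_n - 1) = layer_dropout
layer_dropouts = [ (math.exp((l * LN_2) / (layers_n - 1)) - 1) * layer_dropout for l in range(layers_n) ]

def dropout_layer( l, training = True ):
    # during training, skip a layer's contribution with its scheduled probability
    return random.random() < layer_dropouts[l] if training else False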

View File

@ -89,20 +89,25 @@ class FiniteAudioEncoder(nn.Module):
n_tokens: int,
n_levels: int,
token_dim: int,
use_ln: bool = True,
use_ffn: bool = True,
training: bool = True,
use_ln: bool = True, # whether to perform a post-embedding pre-norm or not (I'm not sure if this is redundant)
use_ffn: bool = True, # whether to employ a residual feed forward network or not
d_model: int = None,
d_ffn: int = 2, # feed forward expansion value
):
super().__init__()
if not d_model:
d_model = token_dim
self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
self.pos_embedding = nn.Parameter(torch.randn(1, n_levels, token_dim) * 0.02)
self.norm = nn.LayerNorm(token_dim) if use_ln else nn.Identity()
self.proj = nn.Sequential(
nn.Linear(token_dim, token_dim * 2),
nn.Linear(token_dim, token_dim * d_ffn),
nn.GELU(),
nn.Linear(token_dim * 2, token_dim),
#nn.Dropout(0.1 if training else 0.0)
) if use_ffn else nn.Linear(token_dim, token_dim)
nn.Linear(token_dim * d_ffn, d_model),
) if use_ffn else nn.Linear(token_dim, d_model)
self.level_weights = nn.Parameter(torch.ones(n_levels) / math.sqrt(n_levels))
@ -151,35 +156,50 @@ class FiniteAudioDecoder(nn.Module):
d_model: int,
vocab_size: int,
n_levels: int,
d_ffn: int = 4,
use_ln: bool = True,
shared_levels: bool = False,
training: bool = False,
d_ffn: int = 2, # feed forward expansion value
use_ln: bool = True, # perform layer normalization here
use_ffn: bool = True, # whether to insert a feed forward network between the norm and the classifier head
shared_levels: bool = False, # whether to have one set of weights for all codebook levels, or separate weights for each level
):
super().__init__()
self.n_levels = n_levels
self.shared_levels = shared_levels
if not shared_levels:
self.head = nn.ModuleList([nn.Sequential(
# ln
(nn.LayerNorm(d_model) if use_ln else nn.Identity()),
# ffn
nn.Linear(d_model, d_ffn * d_model),
nn.GELU(),
nn.Linear(d_ffn * d_model, d_model),
# head
nn.Linear(d_model, vocab_size)
) for _ in range(n_levels)])
if use_ffn:
if not shared_levels:
self.head = nn.ModuleList([nn.Sequential(
# ln
(nn.LayerNorm(d_model) if use_ln else nn.Identity()),
# ffn
nn.Linear(d_model, d_ffn * d_model),
nn.GELU(),
nn.Linear(d_ffn * d_model, d_model),
# head
nn.Linear(d_model, vocab_size)
) for _ in range(n_levels)])
else:
self.head = nn.Sequential(
# ffn
nn.Linear(d_model, d_ffn * d_model),
nn.GELU(),
nn.Linear(d_ffn * d_model, d_model),
# head
nn.Linear(d_model, vocab_size * n_levels)
)
else:
self.head = nn.Sequential(
# ffn
nn.Linear(d_model, d_ffn * d_model),
nn.GELU(),
nn.Linear(d_ffn * d_model, d_model),
# head
nn.Linear(d_model, vocab_size * n_levels)
)
if not shared_levels:
self.head = nn.ModuleList([nn.Sequential(
# ln
(nn.LayerNorm(d_model) if use_ln else nn.Identity()),
# head
nn.Linear(d_model, vocab_size)
) for _ in range(n_levels)])
else:
self.head = nn.Sequential(
# head
nn.Linear(d_model, vocab_size * n_levels)
)
def forward(self, x: Tensor) -> Tensor:
batch_size, seq_len, _ = x.shape
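
To put the decoder's d_ffn default change (4 -> 2) in perspective, a rough parameter count for a single per-level FFN head; d_model and vocab_size below are placeholder values, not the project's actual configuration:

d_model, vocab_size = 1024, 1025 # placeholder sizes, for illustration only

def head_params( d_ffn ):
    ln = 2 * d_model # LayerNorm weight + bias
    up = d_model * (d_ffn * d_model) + (d_ffn * d_model) # nn.Linear(d_model, d_ffn * d_model)
    down = (d_ffn * d_model) * d_model + d_model # nn.Linear(d_ffn * d_model, d_model)
    cls = d_model * vocab_size + vocab_size # nn.Linear(d_model, vocab_size)
    return ln + up + down + cls

print( head_params(4), head_params(2) ) # roughly 9.4M vs 5.3M parameters per level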
@ -376,50 +396,37 @@ class Base_V2(nn.Module):
self.langs_emb = ml.Embedding(n_langs, d_model) if n_langs > 0 else None
self.tasks_emb = ml.Embedding(n_tasks, d_model) if n_tasks > 0 else None
self.tones_emb = ml.Embedding(n_tones, d_model) if n_tones > 0 else None
self.len_emb = ml.Embedding(11, d_model) # unused
self.len_emb = None # ml.Embedding(11, d_model)
self.audio_emb = None
self.proms_emb = None
self.resps_emb = None
"""
if n_audio_tokens == 1000 or:
AudioEncoder = FiniteAudioEncoder
AudioDecoder = FiniteAudioDecoder
else:
AudioEncoder = ResidualAudioEncoder
AudioDecoder = ResidualAudioDecoder
"""
if monolithic_audio_encoder:
self.audio_emb = AudioEncoder(
n_tokens=n_audio_tokens + 2, # stop + masked token
n_levels=self.n_resp_levels,
token_dim=d_model,
training=training,
)
else:
self.proms_emb = AudioEncoder(
n_tokens=n_audio_tokens,
n_levels=self.n_resp_levels,
token_dim=d_model,
training=training,
)
self.resps_emb = AudioEncoder(
n_tokens=n_audio_tokens + 2, # stop + masked token
n_levels=self.n_resp_levels,
token_dim=d_model,
training=training,
)
self.audio_decoder = AudioDecoder(
d_model,
(n_audio_tokens + 1),
self.n_resp_levels,
training=training,
use_ln=per_level_normalization,
)
self.len_decoder = AuxDecoder( d_model, 11 ) # to-do: adjust this
self.len_decoder = AuxDecoder( d_model, 1 )
self.phn_decoder = AuxDecoder( d_model, n_phn_tokens )
self.text_decoder = AuxDecoder( d_model, n_text_tokens )
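
The len_decoder change above (11 outputs -> 1) lines up with the commit message's "trim the len stuff": presumably the duration is now predicted as a single scalar rather than 11-way logits. The actual len loss isn't shown in this diff, so the following is only a guess at what a scalar length objective could look like, reusing the len_loss_factor default from the config:

import torch
import torch.nn.functional as F

def len_loss( len_logits, resp_lens, len_loss_factor = 0.00001 ):
    # len_logits: assumed [batch, 1] scalar prediction; resp_lens: ground-truth sequence lengths
    target = torch.tensor( resp_lens, dtype = len_logits.dtype, device = len_logits.device )
    return F.mse_loss( len_logits.squeeze( -1 ), target ) * len_loss_factor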

View File

@ -85,7 +85,7 @@ def gradio_wrapper(inputs):
return decorated
# returns a list of models, assuming the models are placed under ./training/ or ./models/ or ./data/models/
def get_model_paths( paths=[Path("./training/"), Path("./models/"), Path("./data/models/")] ):
def get_model_paths(paths=[Path("./training/"), Path("./models/"), Path("./data/models/")] ):
configs = []
for path in paths: