more tweaks to the new implementation (properly trim the len stuff to save some params, set the decoder's d_ffn expansion to 2 to maybe also make it faster, etc.)
parent 9a8a8e3195
commit 5479d2eacc
@@ -299,9 +299,11 @@ class ModelExperimentalSettings:
     len_parallel_training: bool = True # used for version >= 7, computes len loss alongside normal training through using the input sequence (surely nothing can go wrong)
     len_loss_factor: float = 0.00001 # loss factor for len calculation, very small because it mucks up loss scaling under float16
     parallel_attention_mask_dropout: float = 0.0 # randomly sets to a causal attention mask when training NAR-len demasking
+    layer_dropout_p: float = 0.0 # performs layer dropout, which I readded because it might actually help since the reference model had this at 0.1
+    codebook_dropout_p: float = 0.0 # perform codebook dropout, which I added as an explicit feature since it seems the reference model had this in a way
+    # although I don't know how I could easily implement this for the new implementation

-    #
-    logit_normalization: float = 0 # performs logit normalization against the norms per the paper (https://arxiv.org/abs/2205.09310) per https://arxiv.org/abs/2406.05298
+    logit_normalization: float = 0 # performs logit normalization against the norms per the paper (https://arxiv.org/abs/2205.09310) per https://arxiv.org/abs/2406.05298 (this actually is very bad)
     per_level_normalization: bool = True # moves the final norm out from the underlying model into the decoder
     audio_level_loss_factors: list[float] | str = "auto" # the loss factors per-level when training
     # "auto" will pick best for codec
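For context on the logit_normalization flag above: the referenced LogitNorm paper (https://arxiv.org/abs/2205.09310) rescales logits by their L2 norm and a temperature before the cross-entropy, so the loss depends on the direction of the logits rather than their magnitude. A minimal sketch of that idea, assuming the float here acts as the temperature (normalize_logits is an illustrative name, not a function from this codebase):

import torch
import torch.nn.functional as F

def normalize_logits(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    # scale logits by their L2 norm so only their direction matters,
    # with `temperature` controlling how sharp the normalized distribution is
    if temperature <= 0:
        return logits  # a factor of 0 disables the normalization
    norms = torch.norm(logits, p=2, dim=-1, keepdim=True) + 1e-7
    return logits / (norms * temperature)

# usage sketch: loss = F.cross_entropy(normalize_logits(logits, 0.1), targets)

Per the comment added in this diff, this hurt results here, which is presumably why the default stays at 0 (disabled).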
@@ -313,13 +315,14 @@ class ModelExperimentalSettings:
     # this is a flag since I am cautious
     use_streamlined_calc_loss: bool = False # explicitly request the faster pathway for loss calc, in case doing loss one by one instead of one batch is a bottleneck

-    # these technically should be as hyperparameters
     # performs token dropout to compensate for errors
+    # currently unused, since this might be the wrong way to go about it
     token_dropout_error: float = 0.0 # probability to nudge a token by ±1
     token_dropout_rate: float = 0.0 # probability to randomly set a token to a special dropout value
     token_dropout_rvq_levels: list = field(default_factory=lambda: [1,8]) # determines which levels to do dropout, by default do not do dropout on RVQ level 0
-    # these technically should be as hyperparameters
+
     # classifier-free guidance training settings
+    # too large might actually be a problem
     cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training
     cfg_text_dropout_p: float = 0.0 # 0.0 # probability to drop out input audio prompt during training
     cfg_prom_dropout_p: float = 0.0 # 0.3 # probability to drop out input audio prompt during training
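The cfg_*_dropout_p settings are classifier-free-guidance-style condition dropout: with some probability a sample's text and/or audio-prompt conditioning is dropped during training, so the model also learns an unconditional path that can be used for guidance at inference. A rough sketch of the pattern, using illustrative names and the commented-out defaults from above rather than the codebase's actual helper:

import random

def apply_cfg_dropout(text_tokens, prom_tokens,
                      cond_dropout_p=0.2, text_dropout_p=0.0, prom_dropout_p=0.3):
    # sketch of condition dropout during training; "dropping" a condition here
    # just means replacing it with an empty/null input for this sample
    if random.random() < cond_dropout_p:
        return None, None          # drop both the text and the audio prompt
    if random.random() < text_dropout_p:
        text_tokens = None         # drop only the text condition
    if random.random() < prom_dropout_p:
        prom_tokens = None         # drop only the audio prompt condition
    return text_tokens, prom_tokens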
@@ -25,6 +25,7 @@ class Config(BaseConfig):
         attn_mode = "sdpa",
         output_norm = True,
         causal = True,
+        layer_dropout = 0.0,
         *args, **kwargs
     ):
         super().__init__(*args, **kwargs)
@@ -32,6 +33,7 @@ class Config(BaseConfig):
         self.attn_mode = attn_mode
         self.output_norm = output_norm
         self.causal = causal
+        self.layer_dropout = layer_dropout

 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
@@ -156,18 +158,10 @@ class Attention(nn.Module):
         elif self.attn_mode == "sdpa":
             self.attn_mode = torch.nn.attention.SDPBackend.MATH

-        self.q_proj = nn.Linear(
-            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
-        )
-        self.k_proj = nn.Linear(
-            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
-        )
-        self.v_proj = nn.Linear(
-            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
-        )
-        self.o_proj = nn.Linear(
-            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
-        )
+        self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias )
+        self.k_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias )
+        self.v_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias )
+        self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias )

     # extracts inputs from a batch based on requested causality
     def split_forward(
@@ -576,6 +570,10 @@ class Model(LlamaPreTrainedModel):
         self.rotary_emb = RotaryEmbedding(config=config)
         self.gradient_checkpointing = False

+        # setup layer dropout LUT
+        LN_2 = 0.69314718056
+        self.layer_dropouts = [ (math.exp((l * LN_2) / (self.layers_n - 1)) - 1) * self.config.layer_dropout for l in range(self.layers_n) ]
+
         # Initialize weights and apply final processing
         self.post_init()
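The LUT added above ramps the drop probability from 0 at the first layer up to the configured layer_dropout at the last layer, since exp(l · ln 2 / (N − 1)) − 1 runs from 0 to 1 across the stack. A standalone check of the schedule (12 layers and a 0.1 rate are example values, not the model's actual config):

import math

LN_2 = 0.69314718056
layers_n, layer_dropout = 12, 0.1  # example values only

layer_dropouts = [
    (math.exp((l * LN_2) / (layers_n - 1)) - 1) * layer_dropout
    for l in range(layers_n)
]
print([round(p, 4) for p in layer_dropouts])
# first entry is 0.0 (the first layer is never dropped), last entry is ~0.1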
@@ -725,6 +723,9 @@ class Model(LlamaPreTrainedModel):

         return causal_mask

+    def dropout_layer( self, l ):
+        return random.random() < self.layer_dropouts[l] if self.training else False
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -829,6 +830,7 @@ class Model(LlamaPreTrainedModel):
                 position_embeddings=position_embeddings,
             )

+            if not self.dropout_layer( l ):
                 hidden_states = layer_outputs[0]

             if use_cache:
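Combined with dropout_layer() from the previous hunk, the guard above gives stochastic depth: when a layer is dropped during training, hidden_states is simply never overwritten with that layer's output, so the block's input flows through unchanged. A toy illustration of that skip semantics, with plain Linear layers standing in for the decoder blocks:

import random
import torch
import torch.nn as nn

layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
layer_dropouts = [0.0, 0.05, 0.1]  # example schedule, ramping up with depth

def dropout_layer(l, training=True):
    return random.random() < layer_dropouts[l] if training else False

hidden_states = torch.randn(2, 8)
for l, layer in enumerate(layers):
    layer_outputs = layer(hidden_states)
    if not dropout_layer(l):
        hidden_states = layer_outputs  # keep this layer's output
    # when dropped, hidden_states keeps the layer's input, i.e. the layer is skipped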
@@ -89,20 +89,25 @@ class FiniteAudioEncoder(nn.Module):
         n_tokens: int,
         n_levels: int,
         token_dim: int,
-        use_ln: bool = True,
-        use_ffn: bool = True,
-        training: bool = True,
+        use_ln: bool = True, # whether to perform a post-embedding pre-norm or not (I'm not sure if this is redundant)
+        use_ffn: bool = True, # whether to employ a residual feed forward network or not
+        d_model: int = None,
+        d_ffn: int = 2, # feed forward expansion value
     ):
         super().__init__()

+        if not d_model:
+            d_model = token_dim
+
         self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
         self.pos_embedding = nn.Parameter(torch.randn(1, n_levels, token_dim) * 0.02)
         self.norm = nn.LayerNorm(token_dim) if use_ln else nn.Identity()
         self.proj = nn.Sequential(
-            nn.Linear(token_dim, token_dim * 2),
+            nn.Linear(token_dim, token_dim * d_ffn),
             nn.GELU(),
-            nn.Linear(token_dim * 2, token_dim),
-            #nn.Dropout(0.1 if training else 0.0)
-        ) if use_ffn else nn.Linear(token_dim, token_dim)
+            nn.Linear(token_dim * d_ffn, d_model),
+        ) if use_ffn else nn.Linear(token_dim, d_model)

         self.level_weights = nn.Parameter(torch.ones(n_levels) / math.sqrt(n_levels))
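The new d_ffn knob makes the encoder's post-embedding projection token_dim → token_dim·d_ffn → d_model, so its parameter count scales roughly linearly with d_ffn; this is where the "save some params / maybe also make it faster" note in the commit message comes from. A rough count, using 1024 only as an example width:

def ffn_param_count(token_dim: int, d_ffn: int, d_model: int = None) -> int:
    # weights + biases of nn.Linear(token_dim, token_dim * d_ffn)
    # followed by nn.Linear(token_dim * d_ffn, d_model)
    if not d_model:
        d_model = token_dim
    hidden = token_dim * d_ffn
    return (token_dim * hidden + hidden) + (hidden * d_model + d_model)

print(ffn_param_count(1024, d_ffn=4))  # 8393728
print(ffn_param_count(1024, d_ffn=2))  # 4197376 -- roughly half, per codebook level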
@@ -151,15 +156,17 @@ class FiniteAudioDecoder(nn.Module):
         d_model: int,
         vocab_size: int,
         n_levels: int,
-        d_ffn: int = 4,
-        use_ln: bool = True,
-        shared_levels: bool = False,
-        training: bool = False,
+        d_ffn: int = 2, # feed forward expansion value
+        use_ln: bool = True, # perform layer normalization here
+        use_ffn: bool = True, # use a feed forward network post-norm pre-classifier
+        shared_levels: bool = False, # whether to have one set of weights for all codebook levels, or separate weights for each layer
     ):
         super().__init__()
         self.n_levels = n_levels
         self.shared_levels = shared_levels

+        if use_ffn:
             if not shared_levels:
                 self.head = nn.ModuleList([nn.Sequential(
                     # ln
@@ -180,6 +187,19 @@ class FiniteAudioDecoder(nn.Module):
                     # head
                     nn.Linear(d_model, vocab_size * n_levels)
                 )
+        else:
+            if not shared_levels:
+                self.head = nn.ModuleList([nn.Sequential(
+                    # ln
+                    (nn.LayerNorm(d_model) if use_ln else nn.Identity()),
+                    # head
+                    nn.Linear(d_model, vocab_size)
+                ) for _ in range(n_levels)])
+            else:
+                self.head = nn.Sequential(
+                    # head
+                    nn.Linear(d_model, vocab_size * n_levels)
+                )

     def forward(self, x: Tensor) -> Tensor:
         batch_size, seq_len, _ = x.shape
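Both branches above are presumably meant to land on the same output shape: either a per-level stack of (d_model → vocab_size) heads, or one shared (d_model → vocab_size·n_levels) head that gets reshaped. A minimal shape check under that assumption (the actual forward() is not part of this hunk, so this is an approximation):

import torch
import torch.nn as nn

d_model, vocab_size, n_levels, B, T = 512, 1025, 8, 2, 10  # example sizes
x = torch.randn(B, T, d_model)

# per-level heads (shared_levels=False): one classifier per codebook level
heads = nn.ModuleList([nn.Linear(d_model, vocab_size) for _ in range(n_levels)])
per_level = torch.stack([head(x) for head in heads], dim=2)

# shared head (shared_levels=True): one big classifier, reshaped afterwards
shared = nn.Linear(d_model, vocab_size * n_levels)(x).view(B, T, n_levels, vocab_size)

print(per_level.shape, shared.shape)  # both: torch.Size([2, 10, 8, 1025])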
@@ -376,50 +396,37 @@ class Base_V2(nn.Module):
         self.langs_emb = ml.Embedding(n_langs, d_model) if n_langs > 0 else None
         self.tasks_emb = ml.Embedding(n_tasks, d_model) if n_tasks > 0 else None
         self.tones_emb = ml.Embedding(n_tones, d_model) if n_tones > 0 else None
-        self.len_emb = ml.Embedding(11, d_model) # unused
+        self.len_emb = None # ml.Embedding(11, d_model)

         self.audio_emb = None
         self.proms_emb = None
         self.resps_emb = None

-        """
-        if n_audio_tokens == 1000 or:
-            AudioEncoder = FiniteAudioEncoder
-            AudioDecoder = FiniteAudioDecoder
-        else:
-            AudioEncoder = ResidualAudioEncoder
-            AudioDecoder = ResidualAudioDecoder
-        """
-
         if monolithic_audio_encoder:
             self.audio_emb = AudioEncoder(
                 n_tokens=n_audio_tokens + 2, # stop + masked token
                 n_levels=self.n_resp_levels,
                 token_dim=d_model,
-                training=training,
             )
         else:
             self.proms_emb = AudioEncoder(
                 n_tokens=n_audio_tokens,
                 n_levels=self.n_resp_levels,
                 token_dim=d_model,
-                training=training,
             )
             self.resps_emb = AudioEncoder(
                 n_tokens=n_audio_tokens + 2, # stop + masked token
                 n_levels=self.n_resp_levels,
                 token_dim=d_model,
-                training=training,
             )

         self.audio_decoder = AudioDecoder(
             d_model,
             (n_audio_tokens + 1),
             self.n_resp_levels,
-            training=training,
             use_ln=per_level_normalization,
         )
-        self.len_decoder = AuxDecoder( d_model, 11 ) # to-do: adjust this
+        self.len_decoder = AuxDecoder( d_model, 1 )
         self.phn_decoder = AuxDecoder( d_model, n_phn_tokens )
         self.text_decoder = AuxDecoder( d_model, n_text_tokens )
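Shrinking len_decoder from 11 outputs to 1 matches the "properly trim the len stuff" note in the commit message: instead of an 11-way, digit-style classification head, the auxiliary decoder now emits a single value, which reads like a direct length regression paired with the tiny len_loss_factor above. A hedged sketch of how such a head might be trained; the real loss wiring lives elsewhere in the model code and may differ:

import torch
import torch.nn as nn
import torch.nn.functional as F

d_model, B, T = 512, 2, 10                        # example sizes
len_decoder = nn.Linear(d_model, 1)               # stand-in for AuxDecoder( d_model, 1 )

hidden = torch.randn(B, T, d_model)               # model hidden states
pred_len = len_decoder(hidden).mean(dim=(1, 2))   # pool to one predicted duration per sample
target_len = torch.tensor([75.0, 120.0])          # made-up target durations in frames

loss = F.l1_loss(pred_len, target_len) * 0.00001  # scaled down like len_loss_factor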