the culprit was initializing the level_weights for killing newly trained models.............

This commit is contained in:
mrq 2025-04-10 23:06:16 -05:00
parent 6c6a34dd21
commit f144389920
4 changed files with 35 additions and 7 deletions

View File

@ -317,6 +317,8 @@ class ModelExperimentalSettings:
use_sliding_attention_mask: bool = False # when used with above, applies a sliding mask within the current segment
# this is a flag since I am cautious
use_streamlined_calc_loss: bool = False # explicitly request the faster pathway for loss calc, in case doing loss one by one instead of one batch is a bottleneck
audio_decoder_ffn_expansion_size: int = 2 # need to do something awful with this
audio_encoder_ffn_expansion_size: int = 2 # need to do something awful with this
# performs token dropout to compensate for errors
# currently unused, since this might be the wrong way to go about it
@ -329,6 +331,7 @@ class ModelExperimentalSettings:
cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training
cfg_text_dropout_p: float = 0.0 # 0.0 # probability to drop out input audio prompt during training
cfg_prom_dropout_p: float = 0.0 # 0.3 # probability to drop out input audio prompt during training
lang_cond_dropout_p: float = 0.0 # probability to drop out language token during training
use_raw_text_p: float = 0.0 # probability to use raw text as the input prompt instead
@ -810,6 +813,8 @@ class Trainer:
wandb_params: dict = field(default_factory=lambda: dict)
weight_dtype: str = "float16" # dtype to have the model under
audio_device: str = "auto"
decode_non_resp_audio: bool = True
amp: bool = False # automatic mixed precision
ddp: bool = False # torch's internal DDP, automatically set if local backend is used and multiple GPUs are requested

View File

@ -76,10 +76,12 @@ class AR_NAR_V2(Base_V2):
# RVQ levels to apply masking training on
masking_train_rvq_levels = [0,self.n_resp_levels] # self.config.experimental.masking_train_rvq_levels
# CFG
cfg_text_dropout_p = self.config.experimental.cfg_text_dropout_p if self.config is not None else 0.0
cfg_cond_dropout_p = self.config.experimental.cfg_cond_dropout_p if self.config is not None else 0.0
cfg_prom_dropout_p = self.config.experimental.cfg_prom_dropout_p if self.config is not None else 0.0
lang_cond_dropout_p = self.config.experimental.lang_cond_dropout_p if self.config is not None else 0.0
use_raw_text_p = self.config.experimental.use_raw_text_p if self.config is not None else 0.0
# rate to train RVQ level AR-ly or NAR-ly
masking_train_p = self.config.experimental.masking_train_p if self.config is not None else 0.5
@ -154,6 +156,9 @@ class AR_NAR_V2(Base_V2):
if task == "len":
quant_levels[i] = 0
if random.random() < lang_cond_dropout_p:
lang_list[i] = None
# apply CFG (should probably only apply to NAR quant level 0)
if task not in text_task + ["len"]:
drop_text = False

View File

@ -115,12 +115,15 @@ class FiniteAudioEncoder(nn.Module):
self.proj = nn.Identity()
self.level_weights = nn.Parameter(torch.ones(n_levels) / math.sqrt(n_levels))
self.use_ffn = use_ffn
# explicit initialization
# this is actually BAD BAD BAD
"""
for emb in self.embs:
torch.nn.init.normal_(emb.weight, mean=0.0, std=0.02)
"""
self.use_ffn = use_ffn
if use_ffn:
nn.init.xavier_uniform_(self.proj[0].weight)
nn.init.xavier_uniform_(self.proj[2].weight)
@ -131,7 +134,7 @@ class FiniteAudioEncoder(nn.Module):
nn.init.xavier_uniform_(self.proj.weight)
nn.init.zeros_(self.proj.bias)
def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None, stability = None ) -> Tensor:
# empty
if xi.shape[0] == 0:
dim = self.embs[0].weight.shape[-1] # self.proj.weight.shape[0]
@ -139,6 +142,10 @@ class FiniteAudioEncoder(nn.Module):
if dropout_mask is not None:
xi = _dropout_codes( xi, dropout_mask, dropout_token )
# some cronge
if stability is None:
stability = xi.dtype == torch.bfloat16
x = torch.stack([ emb(xi[:, i]) for i, emb in enumerate(self.embs) ], dim=1)
x = x + self.pos_embedding
x = self.norm(x)
@ -147,8 +154,12 @@ class FiniteAudioEncoder(nn.Module):
else:
x = self.proj( x )
weights = F.softmax(self.level_weights.float(), dim=0).view(1, -1, 1)
x = (x.float() * weights).sum(dim=1).to(xi.dtype)
if stability:
weights = F.softmax(self.level_weights.float(), dim=0).view(1, -1, 1)
x = (x.float() * weights).sum(dim=1).to(xi.dtype)
else:
weights = F.softmax(self.level_weights, dim=0).view(1, -1, 1)
x = (x * weights).sum(dim=1)
return x
@ -300,6 +311,8 @@ class Base_V2(nn.Module):
len_loss_factor = config.experimental.len_loss_factor if config is not None else 0.00001
logit_normalization = config.experimental.logit_normalization if config is not None else 0
per_level_normalization = config.experimental.per_level_normalization if config is not None else True
audio_decoder_ffn_expansion_size = config.experimental.audio_decoder_ffn_expansion_size if config is not None else 2
audio_encoder_ffn_expansion_size = config.experimental.audio_encoder_ffn_expansion_size if config is not None else 2
use_segmented_attention_mask = config.experimental.use_segmented_attention_mask if config is not None else True
use_sliding_attention_mask = config.experimental.use_sliding_attention_mask if config is not None else True
parallel_attention_mask_dropout = config.experimental.parallel_attention_mask_dropout if config is not None else 0.0
@ -417,17 +430,20 @@ class Base_V2(nn.Module):
n_tokens=n_audio_tokens + 2, # stop + masked token
n_levels=self.n_resp_levels,
token_dim=d_model,
d_ffn=audio_encoder_ffn_expansion_size,
)
else:
self.proms_emb = AudioEncoder(
n_tokens=n_audio_tokens,
n_levels=self.n_resp_levels,
token_dim=d_model,
d_ffn=audio_encoder_ffn_expansion_size,
)
self.resps_emb = AudioEncoder(
n_tokens=n_audio_tokens + 2, # stop + masked token
n_levels=self.n_resp_levels,
token_dim=d_model,
d_ffn=audio_encoder_ffn_expansion_size,
)
self.audio_decoder = AudioDecoder(
@ -435,6 +451,7 @@ class Base_V2(nn.Module):
(n_audio_tokens + 1),
self.n_resp_levels,
use_ln=per_level_normalization,
d_ffn=audio_decoder_ffn_expansion_size,
)
self.len_decoder = AuxDecoder( d_model, 1 if not len_use_logits else (10 * 5) )
self.phn_decoder = AuxDecoder( d_model, n_phn_tokens )

View File

@ -144,13 +144,14 @@ def run_eval(engines, eval_name, dl, args=None):
ref_path.parent.mkdir(parents=True, exist_ok=True)
prom_path.parent.mkdir(parents=True, exist_ok=True)
hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path)
audio_device = cfg.device if cfg.trainer.audio_device == "auto" else cfg.trainer.audio_device
hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path, audio_device)
if ref is not None:
ref_audio, sr = qnt.decode_to_file(ref, ref_path)
ref_audio, sr = qnt.decode_to_file(ref, ref_path, audio_device)
if prom is not None:
prom_audio, sr = qnt.decode_to_file(prom, prom_path)
prom_audio, sr = qnt.decode_to_file(prom, prom_path, audio_device)
# naive loss calculation
# to-do: find a better way to calculate this / a better metric