the culprit was initializing the level_weights, which was killing newly trained models
This commit is contained in:
parent 6c6a34dd21
commit f144389920
@@ -317,6 +317,8 @@ class ModelExperimentalSettings:
     use_sliding_attention_mask: bool = False # when used with above, applies a sliding mask within the current segment
     # this is a flag since I am cautious
     use_streamlined_calc_loss: bool = False # explicitly request the faster pathway for loss calc, in case doing loss one by one instead of one batch is a bottleneck
+    audio_decoder_ffn_expansion_size: int = 2 # need to do something awful with this
+    audio_encoder_ffn_expansion_size: int = 2 # need to do something awful with this
 
     # performs token dropout to compensate for errors
     # currently unused, since this might be the wrong way to go about it
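The `use_streamlined_calc_loss` flag above alludes to per-sample loss loops being slower than one batched call. A minimal sketch of the two pathways it presumably switches between (illustrative names, not the repo's actual loss code):

```python
import torch
import torch.nn.functional as F

B, T, V = 4, 100, 1025
logits = torch.randn(B, T, V)
targets = torch.randint(0, V, (B, T))

# per-sample pathway: one cross_entropy call per sequence
loss_slow = torch.stack([
    F.cross_entropy(logits[i], targets[i]) for i in range(B)
]).mean()

# streamlined pathway: a single batched call over the flattened tokens
loss_fast = F.cross_entropy(logits.view(-1, V), targets.view(-1))

# identical when every sequence has the same length
assert torch.allclose(loss_slow, loss_fast, atol=1e-5)
```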
@@ -329,6 +331,7 @@ class ModelExperimentalSettings:
     cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training
     cfg_text_dropout_p: float = 0.0 # 0.0 # probability to drop out input text during training
     cfg_prom_dropout_p: float = 0.0 # 0.3 # probability to drop out input audio prompt during training
+    lang_cond_dropout_p: float = 0.0 # probability to drop out language token during training
 
     use_raw_text_p: float = 0.0 # probability to use raw text as the input prompt instead
 
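These `cfg_*_dropout_p` fields drive classifier-free-guidance-style conditioning dropout during training. A sketch of how such probabilities are typically consumed (hypothetical helper, not the repo's training loop):

```python
import random

# hypothetical helper: each conditioning input is independently dropped
# with its configured probability, so the model also learns an
# unconditional distribution
def apply_cond_dropout(text, prom, lang, p_text=0.0, p_prom=0.0, p_lang=0.0):
    if random.random() < p_text:
        text = None
    if random.random() < p_prom:
        prom = None
    if random.random() < p_lang:
        lang = None
    return text, prom, lang
```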
@@ -810,6 +813,8 @@ class Trainer:
     wandb_params: dict = field(default_factory=lambda: dict)
 
     weight_dtype: str = "float16" # dtype to have the model under
+    audio_device: str = "auto"
+    decode_non_resp_audio: bool = True
 
     amp: bool = False # automatic mixed precision
     ddp: bool = False # torch's internal DDP, automatically set if local backend is used and multiple GPUs are requested
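`weight_dtype` is a string field, so somewhere it has to be mapped onto a torch dtype; a sketch of the assumed mapping (the repo's actual resolution code isn't shown in this diff):

```python
import torch

# assumed mapping from the string field onto a torch dtype
DTYPE_MAP = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}

def resolve_weight_dtype(name: str) -> torch.dtype:
    return DTYPE_MAP[name]
```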
@@ -76,10 +76,12 @@ class AR_NAR_V2(Base_V2):
         # RVQ levels to apply masking training on
         masking_train_rvq_levels = [0,self.n_resp_levels] # self.config.experimental.masking_train_rvq_levels
 
 
         # CFG
+        cfg_text_dropout_p = self.config.experimental.cfg_text_dropout_p if self.config is not None else 0.0
         cfg_cond_dropout_p = self.config.experimental.cfg_cond_dropout_p if self.config is not None else 0.0
         cfg_prom_dropout_p = self.config.experimental.cfg_prom_dropout_p if self.config is not None else 0.0
+        lang_cond_dropout_p = self.config.experimental.lang_cond_dropout_p if self.config is not None else 0.0
         use_raw_text_p = self.config.experimental.use_raw_text_p if self.config is not None else 0.0
         # rate to train RVQ level AR-ly or NAR-ly
         masking_train_p = self.config.experimental.masking_train_p if self.config is not None else 0.5
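Per the comment, `masking_train_p` is the rate at which a sample is trained NAR-ly (masked) rather than AR-ly; a sketch of that per-sample coin flip (illustrative only):

```python
import random

# illustrative only: per sample, flip between masked (NAR) and
# autoregressive (AR) training
def pick_training_mode(masking_train_p: float = 0.5) -> str:
    return "nar-masked" if random.random() < masking_train_p else "ar"
```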
@@ -154,6 +156,9 @@ class AR_NAR_V2(Base_V2):
         if task == "len":
             quant_levels[i] = 0
+
+        if random.random() < lang_cond_dropout_p:
+            lang_list[i] = None
 
         # apply CFG (should probably only apply to NAR quant level 0)
         if task not in text_task + ["len"]:
             drop_text = False
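The dropout above is what makes classifier-free guidance usable at inference: the model also learns an unconditional distribution it can be steered away from. A generic sketch of the guidance step (not the repo's sampling code):

```python
import torch

# generic classifier-free guidance: push past the unconditional logits
# toward the conditional ones by `scale`
def cfg_mix(cond: torch.Tensor, uncond: torch.Tensor, scale: float = 3.0) -> torch.Tensor:
    return uncond + scale * (cond - uncond)
```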
@@ -115,12 +115,15 @@ class FiniteAudioEncoder(nn.Module):
             self.proj = nn.Identity()
 
         self.level_weights = nn.Parameter(torch.ones(n_levels) / math.sqrt(n_levels))
-        self.use_ffn = use_ffn
 
         # explicit initialization
+        # this is actually BAD BAD BAD
+        """
         for emb in self.embs:
             torch.nn.init.normal_(emb.weight, mean=0.0, std=0.02)
+        """
 
+        self.use_ffn = use_ffn
         if use_ffn:
             nn.init.xavier_uniform_(self.proj[0].weight)
             nn.init.xavier_uniform_(self.proj[2].weight)
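One reason the `torch.ones(n_levels) / math.sqrt(n_levels)` init is safe: the weights only enter through a softmax, and a constant vector softmaxes to a uniform distribution regardless of scale, so every RVQ level starts equally weighted. A standalone check:

```python
import math
import torch
import torch.nn.functional as F

n_levels = 8
level_weights = torch.ones(n_levels) / math.sqrt(n_levels)
weights = F.softmax(level_weights, dim=0)

# constant init -> uniform mixing over levels, whatever the constant is
assert torch.allclose(weights, torch.full((n_levels,), 1 / n_levels))
```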
@@ -131,7 +134,7 @@ class FiniteAudioEncoder(nn.Module):
             nn.init.xavier_uniform_(self.proj.weight)
             nn.init.zeros_(self.proj.bias)
 
-    def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
+    def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None, stability = None ) -> Tensor:
         # empty
         if xi.shape[0] == 0:
             dim = self.embs[0].weight.shape[-1] # self.proj.weight.shape[0]
@@ -139,6 +142,10 @@ class FiniteAudioEncoder(nn.Module):
         if dropout_mask is not None:
             xi = _dropout_codes( xi, dropout_mask, dropout_token )
 
+        # some cronge
+        if stability is None:
+            stability = xi.dtype == torch.bfloat16
+
         x = torch.stack([ emb(xi[:, i]) for i, emb in enumerate(self.embs) ], dim=1)
         x = x + self.pos_embedding
         x = self.norm(x)
@@ -147,8 +154,12 @@ class FiniteAudioEncoder(nn.Module):
         else:
             x = self.proj( x )
 
-        weights = F.softmax(self.level_weights.float(), dim=0).view(1, -1, 1)
-        x = (x.float() * weights).sum(dim=1).to(xi.dtype)
+        if stability:
+            weights = F.softmax(self.level_weights.float(), dim=0).view(1, -1, 1)
+            x = (x.float() * weights).sum(dim=1).to(xi.dtype)
+        else:
+            weights = F.softmax(self.level_weights, dim=0).view(1, -1, 1)
+            x = (x * weights).sum(dim=1)
 
         return x
 
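The `stability` path exists because bfloat16 carries only ~8 bits of mantissa, so a softmax over near-equal level weights can lose precision; upcasting to float32 for the softmax and the weighted sum, then casting back, avoids that. The upcast pattern in isolation:

```python
import torch
import torch.nn.functional as F

x = torch.randn(10, 8, 512, dtype=torch.bfloat16)   # (tokens, levels, dim)
level_weights = torch.randn(8, dtype=torch.bfloat16)

# softmax and the weighted reduction run in float32, result is cast back
w = F.softmax(level_weights.float(), dim=0).view(1, -1, 1)
y = (x.float() * w).sum(dim=1).to(x.dtype)
assert y.dtype == torch.bfloat16 and y.shape == (10, 512)
```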
@@ -300,6 +311,8 @@ class Base_V2(nn.Module):
         len_loss_factor = config.experimental.len_loss_factor if config is not None else 0.00001
         logit_normalization = config.experimental.logit_normalization if config is not None else 0
         per_level_normalization = config.experimental.per_level_normalization if config is not None else True
+        audio_decoder_ffn_expansion_size = config.experimental.audio_decoder_ffn_expansion_size if config is not None else 2
+        audio_encoder_ffn_expansion_size = config.experimental.audio_encoder_ffn_expansion_size if config is not None else 2
         use_segmented_attention_mask = config.experimental.use_segmented_attention_mask if config is not None else True
         use_sliding_attention_mask = config.experimental.use_sliding_attention_mask if config is not None else True
         parallel_attention_mask_dropout = config.experimental.parallel_attention_mask_dropout if config is not None else 0.0
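The new `audio_{encoder,decoder}_ffn_expansion_size` knobs read as width multipliers for the encoder/decoder FFN, which squares with the `self.proj[0]` / `self.proj[2]` Xavier init above (Linear, activation, Linear). A sketch under that assumption:

```python
import torch.nn as nn

# sketch: hidden width = token_dim * expansion, giving the
# Linear / activation / Linear stack indexed as proj[0] and proj[2]
def make_ffn(token_dim: int, expansion: int = 2) -> nn.Sequential:
    return nn.Sequential(
        nn.Linear(token_dim, token_dim * expansion),
        nn.GELU(),
        nn.Linear(token_dim * expansion, token_dim),
    )
```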
@@ -417,17 +430,20 @@ class Base_V2(nn.Module):
                 n_tokens=n_audio_tokens + 2, # stop + masked token
                 n_levels=self.n_resp_levels,
                 token_dim=d_model,
+                d_ffn=audio_encoder_ffn_expansion_size,
             )
         else:
             self.proms_emb = AudioEncoder(
                 n_tokens=n_audio_tokens,
                 n_levels=self.n_resp_levels,
                 token_dim=d_model,
+                d_ffn=audio_encoder_ffn_expansion_size,
             )
             self.resps_emb = AudioEncoder(
                 n_tokens=n_audio_tokens + 2, # stop + masked token
                 n_levels=self.n_resp_levels,
                 token_dim=d_model,
+                d_ffn=audio_encoder_ffn_expansion_size,
             )
 
         self.audio_decoder = AudioDecoder(
@@ -435,6 +451,7 @@ class Base_V2(nn.Module):
             (n_audio_tokens + 1),
             self.n_resp_levels,
             use_ln=per_level_normalization,
+            d_ffn=audio_decoder_ffn_expansion_size,
         )
         self.len_decoder = AuxDecoder( d_model, 1 if not len_use_logits else (10 * 5) )
         self.phn_decoder = AuxDecoder( d_model, n_phn_tokens )
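The `(10 * 5)` output size suggests that with `len_use_logits` the length head classifies five decimal digits with ten classes each instead of regressing a scalar; decoding such a head might look like this (the digit layout is an assumption):

```python
import torch

# assumed layout: five digit positions, ten classes each
def decode_len(logits: torch.Tensor) -> int:
    digits = logits.view(5, 10).argmax(dim=-1)
    return int("".join(str(d.item()) for d in digits))

assert decode_len(torch.zeros(10 * 5)) == 0  # all-zero logits -> length 0
```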
@@ -144,13 +144,14 @@ def run_eval(engines, eval_name, dl, args=None):
         ref_path.parent.mkdir(parents=True, exist_ok=True)
         prom_path.parent.mkdir(parents=True, exist_ok=True)
 
-        hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path)
+        audio_device = cfg.device if cfg.trainer.audio_device == "auto" else cfg.trainer.audio_device
+        hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path, audio_device)
 
         if ref is not None:
-            ref_audio, sr = qnt.decode_to_file(ref, ref_path)
+            ref_audio, sr = qnt.decode_to_file(ref, ref_path, audio_device)
 
         if prom is not None:
-            prom_audio, sr = qnt.decode_to_file(prom, prom_path)
+            prom_audio, sr = qnt.decode_to_file(prom, prom_path, audio_device)
 
         # naive loss calculation
         # to-do: find a better way to calculate this / a better metric
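Threading `audio_device` through `decode_to_file` lets the vocoder run somewhere other than the training device during eval. The `"auto"` convention itself is just a fallback to the run's main device:

```python
# the "auto" convention, as used above: defer to the run's main device
def resolve_audio_device(audio_device: str, main_device: str) -> str:
    return main_device if audio_device == "auto" else audio_device

assert resolve_audio_device("auto", "cuda:0") == "cuda:0"
assert resolve_audio_device("cpu", "cuda:0") == "cpu"
```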