diff --git a/vall_e/config.py b/vall_e/config.py
index 6430ab7..b4fbd38 100644
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -317,6 +317,8 @@ class ModelExperimentalSettings:
 	use_sliding_attention_mask: bool = False # when used with above, applies a sliding mask within the current segment
 	# this is a flag since I am cautious
 	use_streamlined_calc_loss: bool = False # explicitly request the faster pathway for loss calc, in case doing loss one by one instead of one batch is a bottleneck
+	audio_decoder_ffn_expansion_size: int = 2 # need to do something awful with this
+	audio_encoder_ffn_expansion_size: int = 2 # need to do something awful with this
 
 	# performs token dropout to compensate for errors
 	# currently unused, since this might be the wrong way to go about it
@@ -329,6 +331,7 @@ class ModelExperimentalSettings:
 	cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training
 	cfg_text_dropout_p: float = 0.0 # 0.0 # probability to drop out input audio prompt during training
 	cfg_prom_dropout_p: float = 0.0 # 0.3 # probability to drop out input audio prompt during training
+	lang_cond_dropout_p: float = 0.0 # probability to drop out language token during training
 
 	use_raw_text_p: float = 0.0 # probability to use raw text as the input prompt instead
 
@@ -810,6 +813,8 @@ class Trainer:
 	wandb_params: dict = field(default_factory=lambda: dict)
 
 	weight_dtype: str = "float16" # dtype to have the model under
+	audio_device: str = "auto"
+	decode_non_resp_audio: bool = True
 
 	amp: bool = False # automatic mixed precision
 	ddp: bool = False # torch's internal DDP, automatically set if local backend is used and multiple GPUs are requested
diff --git a/vall_e/models/ar_nar_v2.py b/vall_e/models/ar_nar_v2.py
index 17d829e..99baf21 100644
--- a/vall_e/models/ar_nar_v2.py
+++ b/vall_e/models/ar_nar_v2.py
@@ -76,10 +76,12 @@ class AR_NAR_V2(Base_V2):
 		# RVQ levels to apply masking training on
 		masking_train_rvq_levels = [0,self.n_resp_levels] # self.config.experimental.masking_train_rvq_levels
 
+		# CFG
 		cfg_text_dropout_p = self.config.experimental.cfg_text_dropout_p if self.config is not None else 0.0
 		cfg_cond_dropout_p = self.config.experimental.cfg_cond_dropout_p if self.config is not None else 0.0
 		cfg_prom_dropout_p = self.config.experimental.cfg_prom_dropout_p if self.config is not None else 0.0
+		lang_cond_dropout_p = self.config.experimental.lang_cond_dropout_p if self.config is not None else 0.0
 		use_raw_text_p = self.config.experimental.use_raw_text_p if self.config is not None else 0.0
 
 		# rate to train RVQ level AR-ly or NAR-ly
 		masking_train_p = self.config.experimental.masking_train_p if self.config is not None else 0.5
@@ -154,6 +156,9 @@ class AR_NAR_V2(Base_V2):
 			if task == "len":
 				quant_levels[i] = 0
 
+			if random.random() < lang_cond_dropout_p:
+				lang_list[i] = None
+
 			# apply CFG (should probably only apply to NAR quant level 0)
 			if task not in text_task + ["len"]:
 				drop_text = False
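For reference, a minimal sketch (not part of the patch) of what the new lang_cond_dropout_p amounts to during training: each sample independently loses its language token with that probability, assuming a None entry is treated downstream as an absent language input.

	import random

	def drop_lang_tokens(lang_list, lang_cond_dropout_p=0.1):
		# per-sample conditioning dropout on the language token;
		# a None entry falls back to whatever unconditioned handling the model has
		return [ None if random.random() < lang_cond_dropout_p else lang for lang in lang_list ]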
diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index c5e9759..7b3611f 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -115,12 +115,15 @@ class FiniteAudioEncoder(nn.Module):
 			self.proj = nn.Identity()
 
 		self.level_weights = nn.Parameter(torch.ones(n_levels) / math.sqrt(n_levels))
+		self.use_ffn = use_ffn
 
 		# explicit initialization
+		# this is actually BAD BAD BAD
+		"""
 		for emb in self.embs:
 			torch.nn.init.normal_(emb.weight, mean=0.0, std=0.02)
+		"""
 
-		self.use_ffn = use_ffn
 		if use_ffn:
 			nn.init.xavier_uniform_(self.proj[0].weight)
 			nn.init.xavier_uniform_(self.proj[2].weight)
@@ -131,7 +134,7 @@
 			nn.init.xavier_uniform_(self.proj.weight)
 			nn.init.zeros_(self.proj.bias)
 
-	def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
+	def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None, stability = None ) -> Tensor:
 		# empty
 		if xi.shape[0] == 0:
 			dim = self.embs[0].weight.shape[-1] # self.proj.weight.shape[0]
@@ -139,6 +142,10 @@
 		if dropout_mask is not None:
 			xi = _dropout_codes( xi, dropout_mask, dropout_token )
 
+		# some cronge
+		if stability is None:
+			stability = xi.dtype == torch.bfloat16
+
 		x = torch.stack([ emb(xi[:, i]) for i, emb in enumerate(self.embs) ], dim=1)
 		x = x + self.pos_embedding
 		x = self.norm(x)
@@ -147,8 +154,12 @@
 		else:
 			x = self.proj( x )
 
-		weights = F.softmax(self.level_weights.float(), dim=0).view(1, -1, 1)
-		x = (x.float() * weights).sum(dim=1).to(xi.dtype)
+		if stability:
+			weights = F.softmax(self.level_weights.float(), dim=0).view(1, -1, 1)
+			x = (x.float() * weights).sum(dim=1).to(xi.dtype)
+		else:
+			weights = F.softmax(self.level_weights, dim=0).view(1, -1, 1)
+			x = (x * weights).sum(dim=1)
 
 		return x
 
@@ -300,6 +311,8 @@ class Base_V2(nn.Module):
 		len_loss_factor = config.experimental.len_loss_factor if config is not None else 0.00001
 		logit_normalization = config.experimental.logit_normalization if config is not None else 0
 		per_level_normalization = config.experimental.per_level_normalization if config is not None else True
+		audio_decoder_ffn_expansion_size = config.experimental.audio_decoder_ffn_expansion_size if config is not None else 2
+		audio_encoder_ffn_expansion_size = config.experimental.audio_encoder_ffn_expansion_size if config is not None else 2
 		use_segmented_attention_mask = config.experimental.use_segmented_attention_mask if config is not None else True
 		use_sliding_attention_mask = config.experimental.use_sliding_attention_mask if config is not None else True
 		parallel_attention_mask_dropout = config.experimental.parallel_attention_mask_dropout if config is not None else 0.0
@@ -417,17 +430,20 @@
 				n_tokens=n_audio_tokens + 2, # stop + masked token
 				n_levels=self.n_resp_levels,
 				token_dim=d_model,
+				d_ffn=audio_encoder_ffn_expansion_size,
 			)
 		else:
 			self.proms_emb = AudioEncoder(
 				n_tokens=n_audio_tokens,
 				n_levels=self.n_resp_levels,
 				token_dim=d_model,
+				d_ffn=audio_encoder_ffn_expansion_size,
 			)
 			self.resps_emb = AudioEncoder(
 				n_tokens=n_audio_tokens + 2, # stop + masked token
 				n_levels=self.n_resp_levels,
 				token_dim=d_model,
+				d_ffn=audio_encoder_ffn_expansion_size,
 			)
 
 		self.audio_decoder = AudioDecoder(
@@ -435,6 +451,7 @@ class Base_V2(nn.Module):
 			(n_audio_tokens + 1),
 			self.n_resp_levels,
 			use_ln=per_level_normalization,
+			d_ffn=audio_decoder_ffn_expansion_size,
 		)
 		self.len_decoder = AuxDecoder( d_model, 1 if not len_use_logits else (10 * 5) )
 		self.phn_decoder = AuxDecoder( d_model, n_phn_tokens )
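A standalone sketch (not part of the patch) of the level mixing that the new stability flag gates: when the encoder runs under bfloat16, the softmax over level_weights and the weighted sum are promoted to float32 and then cast back; otherwise they stay in the working dtype.

	import torch
	import torch.nn.functional as F

	def mix_levels(x: torch.Tensor, level_weights: torch.Tensor, stability: bool) -> torch.Tensor:
		# x: [N, n_levels, dim], level_weights: [n_levels]
		if stability:
			# promote the softmax and weighted sum to float32, then cast back
			weights = F.softmax(level_weights.float(), dim=0).view(1, -1, 1)
			return (x.float() * weights).sum(dim=1).to(x.dtype)
		weights = F.softmax(level_weights, dim=0).view(1, -1, 1)
		return (x * weights).sum(dim=1)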
diff --git a/vall_e/train.py b/vall_e/train.py
index fd90d95..4ab4feb 100644
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -144,13 +144,14 @@ def run_eval(engines, eval_name, dl, args=None):
 			ref_path.parent.mkdir(parents=True, exist_ok=True)
 			prom_path.parent.mkdir(parents=True, exist_ok=True)
 
-			hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path)
+			audio_device = cfg.device if cfg.trainer.audio_device == "auto" else cfg.trainer.audio_device
+			hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path, audio_device)
 			if ref is not None:
-				ref_audio, sr = qnt.decode_to_file(ref, ref_path)
+				ref_audio, sr = qnt.decode_to_file(ref, ref_path, audio_device)
 			if prom is not None:
-				prom_audio, sr = qnt.decode_to_file(prom, prom_path)
+				prom_audio, sr = qnt.decode_to_file(prom, prom_path, audio_device)
 
 			# naive loss calculation
 			# to-do: find a better way to calculate this / a better metric
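And a small sketch (again not part of the patch) of how the new trainer.audio_device setting resolves in run_eval: "auto" defers to the main training device, anything else is used verbatim for decoding.

	def resolve_audio_device(audio_device: str, main_device: str) -> str:
		# "auto" falls back to the device the trainer is already using
		return main_device if audio_device == "auto" else audio_device

	# resolve_audio_device("auto", "cuda:0") -> "cuda:0"
	# resolve_audio_device("cpu", "cuda:0")  -> "cpu"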