diff --git a/vall_e/config.py b/vall_e/config.py
index 2143897..1565e52 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -277,6 +277,7 @@ class ModelExperimentalSettings:
 	predict_causally: bool = False # predicts the next token even for the non-causal/NAR tasks, in theory this should also bolster the model, as
 	# * NAR-demask would semi-doubly train for AR
 	# * the model wouldn't also need to learn when to predict the token in place
+	audio_encoder_mode: str = "sum" # audio encoder mode for version >= 7, because I cannot make up my damn mind

 	# these technically should be as hyperparameters
 	# performs token dropout to compensate for errors
@@ -737,6 +738,7 @@ class Trainer:
 	activation_checkpointing: bool | None = None # deprecated, should technically be used for only on activations and not the entire gradients, but HF only has gradient checkpointing
 	gradient_checkpointing: bool = True # enables gradient checkpointing to save VRAM at the cost of slightly reduced performance when training

+	detect_grad_anomaly: bool = False # torch.autograd.set_detect_anomaly
 	check_for_oom: bool = True # checks for OOMs thrown during forward/backwards
 	gc_mode: str | None = None # deprecated, but marks when to do GC

diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index 28e52c8..cbba2e9 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -88,20 +88,42 @@ class AudioEncoder(nn.Module):
 		n_tokens: int,
 		n_levels: int,
 		token_dim: int,
-		enc_mode: str = "sum"
+		enc_mode: str = "sum",
+		l_weights: list[float] | None = None,
 	):
 		super().__init__()
 		self.enc_mode = enc_mode
+
+		d_ffn = 4
+		if not l_weights:
+			l_weights = [1.0 for _ in range(n_levels)]

 		if enc_mode == "sum":
 			self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
 			self.proj = None
+			self.weights = nn.Parameter(torch.tensor(l_weights))
 		elif enc_mode == "sub_interleave":
 			self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim // n_levels) for l in range(n_levels)])
 			self.proj = None
 		elif enc_mode == "interleave":
 			self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
-			self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
+			#self.proj = nn.Linear(n_levels * token_dim, token_dim)
+			self.proj = nn.Sequential(
+				nn.Linear(n_levels * token_dim, d_ffn * token_dim),
+				nn.GELU(),
+				nn.Linear(d_ffn * token_dim, token_dim)
+			)
+		elif enc_mode == "attn":
+			self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
+			self.cross_attn = nn.MultiheadAttention(embed_dim=token_dim, num_heads=n_levels, dropout=0.1)
+			self.proj = nn.Sequential(
+				nn.Linear(n_levels * token_dim, d_ffn * token_dim),
+				nn.GELU(),
+				nn.Linear(d_ffn * token_dim, token_dim)
+			)
+
+		for emb in self.embs:
+			nn.init.normal_(emb.weight, mean=0.0, std=0.02)

 	def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
 		# empty
@@ -114,12 +136,26 @@ class AudioEncoder(nn.Module):
 		# old way
 		# in theory RVQ-based codecs should prefer this, but this doesn't yield good results
 		if self.enc_mode == "sum":
-			x = sum([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
+			weights = F.softmax( self.weights, dim=0 )
+			x = sum([ weights[l] * emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
+		# attention-based crunge
+		elif self.enc_mode == "attn":
+			x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=1)
+			attn, _ = self.cross_attn(
+				x.permute(1, 0, 2),
+				x.permute(1, 0, 2),
+				x.permute(1, 0, 2),
+			)
+			attn = attn.permute(1, 0, 2)
+			x = x + attn
+			x = x.view(x.shape[0], -1)
+			# x = attn.reshape(x.shape[0], -1)
 		# encode by interleaving embeddings into one "token"
 		# this "works" but I imagine it being excessive and doesn't seem to help the model all that much
 		else:
 			x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=1)
 			x = x.view(x.shape[0], -1)
+
 		if self.proj is not None:
 			x = self.proj(x)

@@ -207,6 +243,7 @@ class Base_V2(nn.Module):
 		if not attention:
 			attention = config.attention if config is not None else "auto"
+		n_resp_levels = config.resp_levels if config is not None else 8

 		attention_backend = attention
 		unified_position_ids = config.experimental.unified_position_ids if config is not None else True
 		noncausal_masks = config.experimental.noncausal_masks if config is not None else False
@@ -218,6 +255,8 @@ class Base_V2(nn.Module):
 		resp_parallel_training = config.experimental.resp_parallel_training if config is not None else True
 		predict_causally = config.experimental.predict_causally if config is not None else False
 		monolithic_audio_encoder = config.experimental.monolithic_audio_encoder if config is not None else False
+		audio_encoder_mode = config.experimental.audio_encoder_mode if config is not None else "sum"
+		audio_level_weights = [1.0 / (i + 1) for i in range(n_resp_levels)] # to-do: find the weights for FSQ

 		n_vocab = 256
 		n_tasks = config.tasks if config is not None else 8
@@ -283,6 +322,7 @@ class Base_V2(nn.Module):
 		self.masking_ratio = masking_ratio
 		self.ignore_inputs_for_loss = ignore_inputs_for_loss
 		self.noncausal_masks = noncausal_masks
+		self.audio_level_weights = audio_level_weights

 		self.sep = nn.Parameter(torch.randn(d_model))

@@ -302,17 +342,23 @@ class Base_V2(nn.Module):
 				n_tokens=n_audio_tokens + 2, # stop + masked token
 				n_levels=self.n_resp_levels,
 				token_dim=d_model,
+				enc_mode=audio_encoder_mode,
+				l_weights=audio_level_weights,
 			)
 		else:
 			self.proms_emb = AudioEncoder(
 				n_tokens=n_audio_tokens,
 				n_levels=self.n_resp_levels,
 				token_dim=d_model,
+				enc_mode=audio_encoder_mode,
+				l_weights=audio_level_weights,
 			)
 			self.resps_emb = AudioEncoder(
 				n_tokens=n_audio_tokens + 2, # stop + masked token
 				n_levels=self.n_resp_levels,
 				token_dim=d_model,
+				enc_mode=audio_encoder_mode,
+				l_weights=audio_level_weights,
 			)

 		self.audio_decoder = AudioDecoder(
@@ -747,6 +793,7 @@ class Base_V2(nn.Module):
 		device = logits[0].device
 		batch_size = len(logits)
 		classifier_levels = self.get_input( inputs, "classifier_level" )
+		level_weights = self.audio_level_weights

 		# handles tasks where the prompt has task tokens injected in the middle
 		def prompt_input_to_token( input, quant_level ):
@@ -755,7 +802,7 @@ class Base_V2(nn.Module):

 			return input

-		def _calc_loss( logit, sequence, causal = True ):
+		def _calc_loss( logit, sequence, causal = True, level = None ):
 			# filter tokens that exceed the vocab size
 			sequence = torch.where( sequence >= logit.shape[-1], self.ignore_index, sequence )
 			# drop if all tokens are ignored
@@ -769,7 +816,8 @@ class Base_V2(nn.Module):
 				sequence = sequence[..., l:] # ...predicts token n + 1

 			# flatten batch
-			if sequence.dim() > 1:
+			parallel = sequence.dim() > 1
+			if parallel:
 				logit = logit.reshape(-1, logit.shape[-1])
 				sequence = sequence.reshape(-1)

@@ -777,7 +825,11 @@ class Base_V2(nn.Module):
 			metrics = None

 			if compute_hard_loss:
-				nll = F.cross_entropy( logit, sequence, ignore_index=self.ignore_index )
+				nll = F.cross_entropy( logit, sequence, ignore_index=self.ignore_index, reduction='mean' if not parallel else 'none' ) * (level_weights[level] if level is not None and not parallel else 1)
+
+				# manually weigh each level
+				if parallel:
+					nll = nll.view( self.n_resp_levels, -1 ).mean(dim=-1) * torch.tensor(level_weights, device=device)

 			if compute_acc:
 				accuracy_metric = MulticlassAccuracy(
@@ -875,9 +927,6 @@ class Base_V2(nn.Module):

 					if logits[batch_index].dim() < 3:
 						nll, metrics = _calc_loss( logits[batch_index][start:end], token.long(), causal )
-
-						if name == "resp":
-							name = f'{name}[{quant_level}]'
 					elif not self.resp_parallel_training:
 						# cringe way to deduce "requested" level
 						level = quant_level
@@ -885,24 +934,35 @@ class Base_V2(nn.Module):
 							if classifier_level.endswith(f':{i}:{i}'):
 								level = i
 								break
-						"""
+
 						if name == "resp":
 							name = f'{name}[{level}]'
-						"""
+
 						sequence = token if token.dim() <= 1 else token[:, level]
-						nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal )
+						nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal, level )
 					else:
-						sequence = token.t()
+						sequence = token.t()
 						nll, metrics = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )
+
+						for level in range(self.n_resp_levels):
+							loss_key = f'{name}[{level}].nll'
+							if loss_key not in loss:
+								loss[loss_key] = []
+							loss[loss_key].append( nll[level] * loss_factor )
+
+						nll = None
+
+					loss_key = f'{name}.nll'
+					acc_key = f'{name}.acc'
 					if nll is not None:
-						if f'{name}.nll' not in loss:
-							loss[f'{name}.nll'] = []
-						loss[f"{name}.nll"].append( nll * loss_factor )
+						if loss_key not in loss:
+							loss[loss_key] = []
+						loss[loss_key].append( nll * loss_factor )

 					if metrics is not None:
-						if f'{name}.acc' not in stats:
-							stats[f'{name}.acc'] = []
-						stats[f"{name}.acc"].append( metrics )
+						if acc_key not in stats:
+							stats[acc_key] = []
+						stats[acc_key].append( metrics )
 				# add to list
 				else:
 					target.append( token )
@@ -922,7 +982,7 @@ class Base_V2(nn.Module):
 					sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
 					sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
-					nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal )
+					nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal, level )
 				else:
 					nlls = []
 					accs = []
@@ -930,7 +990,7 @@ class Base_V2(nn.Module):
 					for level, logit in enumerate( logits[batch_index] ):
 						sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
 						sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
-						nll, metrics = _calc_loss( logit, sequence, causal )
+						nll, metrics = _calc_loss( logit, sequence, causal, level )

 						if nll:
 							nlls.append( nll )
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index 90e0854..09609e7 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -180,7 +180,9 @@ def train(
 				break

 			#batch = to_device(batch, torch.cuda.current_device())
-			stats = engines.step(batch=batch, feeder=train_feeder)
+			with torch.autograd.set_detect_anomaly(cfg.trainer.detect_grad_anomaly):
+				stats = engines.step(batch=batch, feeder=train_feeder)
+
 			stats['epoch'] = engines.global_samples / (len(train_dl.dataset.paths) * world_size())

 			elapsed_time = stats.get("elapsed_time", 0)
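For context, a minimal, self-contained sketch of the two mechanisms this patch adds: the "sum" encoder mode with softmax-normalized learned per-level weights, and the per-level weighting of the unreduced cross entropy used when all RVQ levels are trained in parallel. This is not code from the repo; `nn.Embedding` stands in for the repo's `ml.Embedding` wrapper, and the class/function names and shapes are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class WeightedSumEncoder(nn.Module):
	"""Per-level embedding tables blended with softmax-normalized learned weights ("sum" mode)."""
	def __init__(self, n_tokens: int, n_levels: int, token_dim: int, l_weights: list[float] | None = None):
		super().__init__()
		if not l_weights:
			l_weights = [1.0 for _ in range(n_levels)]
		self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
		# learned mixing weights, one per RVQ level, normalized at forward time
		self.weights = nn.Parameter(torch.tensor(l_weights, dtype=torch.float32))

	def forward(self, xi: torch.Tensor) -> torch.Tensor:
		# xi: (seq_len, n_levels) codes -> (seq_len, token_dim)
		w = F.softmax(self.weights, dim=0)
		return sum(w[l] * emb(xi[:, l]) for l, emb in enumerate(self.embs))

def level_weighted_nll(logits: torch.Tensor, targets: torch.Tensor, level_weights: list[float], ignore_index: int = -100) -> torch.Tensor:
	# logits: (n_levels, seq_len, vocab), targets: (n_levels, seq_len)
	n_levels = logits.shape[0]
	nll = F.cross_entropy(
		logits.reshape(-1, logits.shape[-1]),  # flatten level-major, as the patch does
		targets.reshape(-1),
		ignore_index=ignore_index,
		reduction='none',
	)
	per_level = nll.view(n_levels, -1).mean(dim=-1)  # mean NLL per level
	return per_level * torch.tensor(level_weights, device=logits.device)

if __name__ == "__main__":
	enc = WeightedSumEncoder(n_tokens=1024, n_levels=8, token_dim=512)
	codes = torch.randint(0, 1024, (75, 8))      # 75 frames of 8-level codes
	print(enc(codes).shape)                      # torch.Size([75, 512])

	logits = torch.randn(8, 75, 1026)
	targets = torch.randint(0, 1024, (8, 75))
	weights = [1.0 / (i + 1) for i in range(8)]  # same heuristic as audio_level_weights
	print(level_weighted_nll(logits, targets, weights))  # one weighted NLL per level
```

The `1.0 / (i + 1)` schedule mirrors the `audio_level_weights` heuristic in the patch: finer RVQ levels, which mostly encode residual detail, are down-weighted so the coarser levels dominate the gradient.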