diff --git a/vall_e/config.py b/vall_e/config.py
index 743a03e..3c07c8a 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -313,6 +313,8 @@ class ModelExperimentalSettings:
 	# list of floats to manually set
 	use_segmented_attention_mask: bool = False # instead of naively using a full attention mask, use one where each segment cannot attend after itself
 	# this is a flag since I am cautious
+	use_sliding_attention_mask: bool = False # when used with above, applies a sliding mask within the current segment
+	# this is a flag since I am cautious
 	use_streamlined_calc_loss: bool = False # explicitly request the faster pathway for loss calc, in case doing loss one by one instead of one batch is a bottleneck
 
 	# performs token dropout to compensate for errors
diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index 821fdc4..89bba6e 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -600,12 +600,28 @@ class Model(LlamaPreTrainedModel):
 		inverted_mask = 1.0 - expanded_mask
 		return inverted_mask.masked_fill( inverted_mask.to(dtype=torch.bool), torch.finfo(inputs_embeds.dtype).min )
 
+	def _apply_sliding_window(self, mask, start_idx, end_idx, window_size):
+		window_size = int(window_size // 2) # ick
+
+		for i in range(start_idx, end_idx):
+			if not window_size:
+				break
+
+			window_left = max(start_idx, i - window_size)
+			window_right = min(end_idx, i + window_size + 1)
+
+			mask[..., i, start_idx:window_left] = 0.0
+			mask[..., i, window_right:end_idx] = 0.0
+
+		return mask
+
 	# some funky segmented-attention mask because my gut says to do this
 	def _update_segmented_mask(
 		self,
 		attention_mask,
 		inputs_embeds,
 		aux_lens, # (bsz, lens), where [batch_index, 0] = text_len, and [batch_index, 1] = prom_len
+		window_sizes = None, # (bsz, lens), same as above
 		past_key_values_length=0,
 	):
 		# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
@@ -621,17 +637,28 @@
 		)
 
 		for batch_index, aux_len in enumerate( aux_lens ):
-			text_start, text_end = 0, aux_len[0]
+			window_size = window_sizes[batch_index] if window_sizes is not None else None
+			text_len = aux_len[0]
+			prom_len = aux_len[1]
+			output_len = aux_len[2]
 
-			prom_start, prom_end = text_end, text_end + aux_len[1]
-			output_start, output_end = prom_end, prom_end + aux_len[2]
+			text_window = window_size[0] if window_size is not None else 0
+			prom_window = window_size[1] if window_size is not None else 0
+			output_window = window_size[2] if window_size is not None else 0
+
+			text_start, text_end = 0, text_len
+			prom_start, prom_end = text_end, text_end + prom_len
+			output_start, output_end = prom_end, prom_end + output_len
 
-			if aux_len[0]:
+			if text_len:
 				expanded_mask[batch_index, 0, text_start:text_end, text_start:text_end] = 1.0
-			if aux_len[1]:
+				expanded_mask[batch_index, 0] = self._apply_sliding_window( expanded_mask[batch_index, 0], text_start, text_end, text_window )
+			if prom_len:
 				expanded_mask[batch_index, 0, prom_start:prom_end, text_start:prom_end] = 1.0
-			if aux_len[2]:
+				expanded_mask[batch_index, 0] = self._apply_sliding_window( expanded_mask[batch_index, 0], prom_start, prom_end, prom_window )
+			if output_len:
 				expanded_mask[batch_index, 0, output_start:output_end, text_start:output_end] = 1.0
+				expanded_mask[batch_index, 0] = self._apply_sliding_window( expanded_mask[batch_index, 0], output_start, output_end, output_window )
 
 		# apply the original attention mask
 		expanded_mask = expanded_mask * attention_mask[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
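Not part of the patch: a minimal standalone sketch, under made-up segment lengths and window sizes, of the mask the two helpers above produce. It reuses the same windowing rule as `_apply_sliding_window` on a toy 12-position sequence; the real code applies it to the `[bsz, 1, seq_len, seq_len]` expanded mask and then multiplies in the original attention mask.

```python
import torch

def apply_sliding_window(mask, start_idx, end_idx, window_size):
	# same rule as the patch's _apply_sliding_window: each query position i in the
	# segment keeps only keys within +/- (window_size // 2) of itself, clamped to
	# the segment; keys *before* the segment are left untouched
	half = int(window_size // 2)
	for i in range(start_idx, end_idx):
		if not half:
			break
		left = max(start_idx, i - half)
		right = min(end_idx, i + half + 1)
		mask[..., i, start_idx:left] = 0.0
		mask[..., i, right:end_idx] = 0.0
	return mask

# toy lengths: text=4, prom=4, resp=4 (purely illustrative)
text_len, prom_len, resp_len = 4, 4, 4
seq_len = text_len + prom_len + resp_len
mask = torch.zeros(seq_len, seq_len)

text_s, text_e = 0, text_len
prom_s, prom_e = text_e, text_e + prom_len
resp_s, resp_e = prom_e, prom_e + resp_len

# segmented part: each segment attends to itself and every earlier segment
mask[text_s:text_e, text_s:text_e] = 1.0
mask[prom_s:prom_e, text_s:prom_e] = 1.0
mask[resp_s:resp_e, text_s:resp_e] = 1.0

# sliding part: restrict attention *within* the audio segments to a local window
# (window size 2 here is arbitrary for the demo)
mask = apply_sliding_window(mask, prom_s, prom_e, 2)
mask = apply_sliding_window(mask, resp_s, resp_e, 2)

print(mask.int())
```

Note that the response rows still see all of the text and prompt segments; only the within-segment region becomes banded, since `_apply_sliding_window` never touches columns before `start_idx`.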
diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index 1c5d6b5..258205f 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -300,6 +300,7 @@ class Base_V2(nn.Module):
 		logit_normalization = config.experimental.logit_normalization if config is not None else 0
 		per_level_normalization = config.experimental.per_level_normalization if config is not None else True
 		use_segmented_attention_mask = config.experimental.use_segmented_attention_mask if config is not None else True
+		use_sliding_attention_mask = config.experimental.use_sliding_attention_mask if config is not None else True
 		parallel_attention_mask_dropout = config.experimental.parallel_attention_mask_dropout if config is not None else 0.0
 
 		n_vocab = 256
@@ -392,6 +393,7 @@ class Base_V2(nn.Module):
 		self.len_loss_factor = len_loss_factor
 		self.logit_normalization = False # this actually kills the model's demasking capabilities
 		self.use_segmented_attention_mask = use_segmented_attention_mask
+		self.use_sliding_attention_mask = use_sliding_attention_mask
 		self.parallel_attention_mask_dropout = parallel_attention_mask_dropout
 
 		self.sep = nn.Parameter(torch.randn(d_model))
@@ -1130,23 +1132,28 @@
 		# create special masks
 		# to-do, create it if mixed (although I expect this model to be purely non-causal)
-		aux_lens = torch.tensor([[2, 2, 0]] * batch_size, device=x.device, dtype=torch.int32)
+
+		text_window = 32 if self.use_sliding_attention_mask else 0
+		audio_window = self.audio_frames_per_second // 2 if self.use_sliding_attention_mask else 0
+
+		aux_lens = [[2, 0, 0]] * batch_size
+		aux_windows = [[text_window, audio_window, audio_window]] * batch_size
 		# fill aux lens
 		for batch_index, batch_input in enumerate( inputs ):
 			for name, input in batch_input:
 				if name in ["phn", "text"]:
-					aux_lens[batch_index][0] = input.shape[0]
+					aux_lens[batch_index][0] = input.shape[0] + 1
 				elif name == "lang":
 					aux_lens[batch_index][0] += 2
 				elif name == "prom":
-					aux_lens[batch_index][1] = input.shape[0]
+					aux_lens[batch_index][1] = input.shape[0] + 1
 				elif name == "tone":
 					aux_lens[batch_index][1] += 2
 				elif name == "resp":
 					aux_lens[batch_index][2] = input.shape[0]
 
 		if self.use_segmented_attention_mask and not any(is_causal):
-			mask = self.model._update_segmented_mask( mask, x, aux_lens )
+			mask = self.model._update_segmented_mask( mask, x, aux_lens, window_sizes=aux_windows )
 
 		output = self._forward(
 			inputs=x,
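Also not from the repo: a worked example of the `aux_lens` / `aux_windows` bookkeeping above for a single hypothetical sample. The frame rate and all lengths are invented for illustration; the `+ 1` and `+= 2` adjustments mirror the fill loop in the last hunk.

```python
# assumptions for the example only
audio_frames_per_second = 80
use_sliding_attention_mask = True

text_window = 32 if use_sliding_attention_mask else 0
audio_window = audio_frames_per_second // 2 if use_sliding_attention_mask else 0  # half a second of frames

# hypothetical sample: 24 phonemes + a language token, 160 prompt frames + a tone token, 320 output frames
phn_tokens, prom_frames, resp_frames = 24, 160, 320
aux_lens = [
	phn_tokens + 1 + 2,   # "phn" contributes shape[0] + 1, then "lang" adds 2
	prom_frames + 1 + 2,  # "prom" contributes shape[0] + 1, then "tone" adds 2
	resp_frames,          # "resp" contributes shape[0] as-is
]
aux_windows = [text_window, audio_window, audio_window]

print(aux_lens)     # [27, 163, 320]
print(aux_windows)  # [32, 40, 40]
```

Since `_apply_sliding_window` halves the window again, an `audio_window` of `audio_frames_per_second // 2` means each audio position attends to roughly a quarter of a second on either side within its own segment, while still seeing the full text and prompt segments before it.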