From d1d91295b3c920c7a762277dfbbb029e8e94cb28 Mon Sep 17 00:00:00 2001
From: mrq <mrq@ecker.tech>
Date: Fri, 21 Mar 2025 19:05:49 -0500
Subject: [PATCH] add segmented sliding attention, also found a bug with
 prom-less segments in the attention mask generation

---
 vall_e/config.py            |  2 ++
 vall_e/models/arch/llama.py | 39 +++++++++++++++++++++++++++++++------
 vall_e/models/base_v2.py    | 15 ++++++++++----
 3 files changed, 46 insertions(+), 10 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index 743a03e..3c07c8a 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -313,6 +313,8 @@ class ModelExperimentalSettings:
 	# list of floats to manually set
 	use_segmented_attention_mask: bool = False # instead of naively using a full attention mask, use one where each segment cannot attend after itself
 	# this is a flag since I am cautious
+	use_sliding_attention_mask: bool = False # when used with the segmented mask above, additionally applies a sliding attention window within each segment
+	# this is a flag since I am cautious
 	use_streamlined_calc_loss: bool = False # explicitly request the faster pathway for loss calc, in case doing loss one by one instead of one batch is a bottleneck
 
 	# performs token dropout to compensate for errors
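
Not part of the patch itself, but a minimal sketch of how the two flags are meant to be combined (assuming ModelExperimentalSettings can be constructed with defaults for every other field; the import path follows the file header above):

    from vall_e.config import ModelExperimentalSettings

    # the sliding window only takes effect on top of the segmented mask,
    # so both flags are enabled together here
    settings = ModelExperimentalSettings(
        use_segmented_attention_mask=True,
        use_sliding_attention_mask=True,
    )
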
diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index 821fdc4..89bba6e 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -600,12 +600,28 @@ class Model(LlamaPreTrainedModel):
 		inverted_mask = 1.0 - expanded_mask
 		return inverted_mask.masked_fill( inverted_mask.to(dtype=torch.bool), torch.finfo(inputs_embeds.dtype).min )
 
+	def _apply_sliding_window(self, mask, start_idx, end_idx, window_size):
+		window_size = int(window_size // 2) # ick; the given window_size is the full width, so use half per side
+
+		if not window_size: # no window for this segment, leave it fully visible
+			return mask
+		for i in range(start_idx, end_idx):
+
+			window_left = max(start_idx, i - window_size)
+			window_right = min(end_idx, i + window_size + 1)
+
+			mask[..., i, start_idx:window_left] = 0.0
+			mask[..., i, window_right:end_idx] = 0.0
+
+		return mask
+
 	# some funky segmented-attention mask because my gut says to do this
 	def _update_segmented_mask(
 		self,
 		attention_mask,
 		inputs_embeds,
 		aux_lens, # (bsz, lens), where [batch_index, 0] = text_len, and [batch_index, 1] = prom_len
+		window_sizes = None, # (bsz, lens), per-segment sliding window sizes (text, prom, output); None disables windowing
 		past_key_values_length=0,
 	):
 		# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
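
Not part of the patch: a toy reproduction of the windowing rule in _apply_sliding_window above, to show what it does to a single dense segment (assumes a plain 2D float mask):

    import torch

    def sliding_band(mask, start_idx, end_idx, window_size):
        # same rule as the helper above: halve the total width, then for each
        # query position i keep only columns i - half .. i + half inside the segment
        half = int(window_size // 2)
        if not half:
            return mask
        for i in range(start_idx, end_idx):
            mask[..., i, start_idx:max(start_idx, i - half)] = 0.0
            mask[..., i, min(end_idx, i + half + 1):end_idx] = 0.0
        return mask

    print(sliding_band(torch.ones(8, 8), 0, 8, 4))
    # each row keeps a band of at most 5 ones (i-2 .. i+2); the rest of the segment is zeroed
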
@@ -621,17 +637,28 @@ class Model(LlamaPreTrainedModel):
 		)
 
 		for batch_index, aux_len in enumerate( aux_lens ):
-			text_start, text_end = 0, aux_len[0]
+			window_size = window_sizes[batch_index] if window_sizes is not None else None
+			text_len = aux_len[0]
+			prom_len = aux_len[1]
+			output_len = aux_len[2]
 			
-			prom_start, prom_end = text_end, text_end + aux_len[1]
-			output_start, output_end = prom_end, prom_end + aux_len[2]
+			text_window = window_size[0] if window_size is not None else 0
+			prom_window = window_size[1] if window_size is not None else 0
+			output_window = window_size[2] if window_size is not None else 0
+			
+			text_start, text_end = 0, text_len
+			prom_start, prom_end = text_end, text_end + prom_len
+			output_start, output_end = prom_end, prom_end + output_len
 
-			if aux_len[0]:
+			if text_len:
 				expanded_mask[batch_index, 0, text_start:text_end, text_start:text_end] = 1.0
-			if aux_len[1]:
+				expanded_mask[batch_index, 0] = self._apply_sliding_window( expanded_mask[batch_index, 0], text_start, text_end, text_window )
+			if prom_len:
 				expanded_mask[batch_index, 0, prom_start:prom_end, text_start:prom_end] = 1.0
-			if aux_len[2]:
+				expanded_mask[batch_index, 0] = self._apply_sliding_window( expanded_mask[batch_index, 0], prom_start, prom_end, prom_window )
+			if output_len:
 				expanded_mask[batch_index, 0, output_start:output_end, text_start:output_end] = 1.0
+				expanded_mask[batch_index, 0] = self._apply_sliding_window( expanded_mask[batch_index, 0], output_start, output_end, output_window )
 
 		# apply the original attention mask
 		expanded_mask = expanded_mask * attention_mask[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
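
Not part of the patch: ignoring the sliding windows, the segmented mask built above is a block staircase where text only sees text, the prompt sees text + prompt, and the output sees text + prompt + output. A toy version for one batch item:

    import torch

    text_len, prom_len, output_len = 3, 4, 5
    seq_len = text_len + prom_len + output_len

    text_end = text_len
    prom_end = text_end + prom_len
    output_end = prom_end + output_len

    mask = torch.zeros(seq_len, seq_len)
    mask[:text_end, :text_end] = 1.0                # text attends only within text
    mask[text_end:prom_end, :prom_end] = 1.0        # prom attends to text + prom
    mask[prom_end:output_end, :output_end] = 1.0    # output attends to everything up to and including itself

    print(mask)  # no segment can attend to tokens that come after it
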
diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index 1c5d6b5..258205f 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -300,6 +300,7 @@ class Base_V2(nn.Module):
 		logit_normalization = config.experimental.logit_normalization if config is not None else 0
 		per_level_normalization = config.experimental.per_level_normalization if config is not None else True
 		use_segmented_attention_mask = config.experimental.use_segmented_attention_mask if config is not None else True
+		use_sliding_attention_mask = config.experimental.use_sliding_attention_mask if config is not None else True
 		parallel_attention_mask_dropout = config.experimental.parallel_attention_mask_dropout if config is not None else 0.0
 
 		n_vocab = 256
@@ -392,6 +393,7 @@ class Base_V2(nn.Module):
 		self.len_loss_factor = len_loss_factor
 		self.logit_normalization = False # this actually kills the model's demasking capabilities
 		self.use_segmented_attention_mask = use_segmented_attention_mask
+		self.use_sliding_attention_mask = use_sliding_attention_mask
 		self.parallel_attention_mask_dropout = parallel_attention_mask_dropout
 		
 		self.sep = nn.Parameter(torch.randn(d_model))
@@ -1130,23 +1132,28 @@ class Base_V2(nn.Module):
 
 		# create special masks
 		# to-do, create it if mixed (although I expect this model to be purely non-causal)
-		aux_lens = torch.tensor([[2, 2, 0]] * batch_size, device=x.device, dtype=torch.int32)
+
+		text_window = 32 if self.use_sliding_attention_mask else 0 # sliding window of 32 text tokens
+		audio_window = self.audio_frames_per_second // 2 if self.use_sliding_attention_mask else 0 # half a second of audio frames
+
+		aux_lens = [ [2, 0, 0] for _ in range(batch_size) ] # independent list per batch item; [[...]] * batch_size would alias a single list
+		aux_windows = [ [text_window, audio_window, audio_window] for _ in range(batch_size) ]
 		# fill aux lens
 		for batch_index, batch_input in enumerate( inputs ):
 			for name, input in batch_input:
 				if name in ["phn", "text"]:
-					aux_lens[batch_index][0] = input.shape[0]
+					aux_lens[batch_index][0] = input.shape[0] + 1
 				elif name == "lang":
 					aux_lens[batch_index][0] += 2
 				elif name == "prom":
-					aux_lens[batch_index][1] = input.shape[0]
+					aux_lens[batch_index][1] = input.shape[0] + 1
 				elif name == "tone":
 					aux_lens[batch_index][1] += 2
 				elif name == "resp":
 					aux_lens[batch_index][2] = input.shape[0]
 
 		if self.use_segmented_attention_mask and not any(is_causal):
-			mask = self.model._update_segmented_mask( mask, x, aux_lens )
+			mask = self.model._update_segmented_mask( mask, x, aux_lens, window_sizes=aux_windows )
 
 		output = self._forward(
 			inputs=x,
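
Not part of the patch: a hypothetical walk-through of the aux_lens / aux_windows bookkeeping above for a single item with a 20-token phoneme sequence, a language token, a 150-frame prompt and a 200-frame response (the frame rate and the reading of the +1/+2 offsets as separator/extra tokens are assumptions here):

    audio_frames_per_second = 25  # assumed frame rate for this sketch

    text_window = 32                             # sliding window over text tokens
    audio_window = audio_frames_per_second // 2  # about half a second of audio frames

    aux_lens = [20 + 1 + 2, 150 + 1, 200]        # phn (+1), lang (+2) | prom (+1) | resp
    aux_windows = [text_window, audio_window, audio_window]

    # mask = model._update_segmented_mask(mask, x, [aux_lens], window_sizes=[aux_windows])
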