From 02a8bcbe29fc9f9c8746f49d5a3996d9169fe49b Mon Sep 17 00:00:00 2001 From: mrq Date: Fri, 21 Mar 2025 23:41:34 -0500 Subject: [PATCH] fixed errant index error (although it makes me wonder if my segmented masking is still flawed) --- vall_e/models/arch/llama.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py index 89bba6e..1240738 100644 --- a/vall_e/models/arch/llama.py +++ b/vall_e/models/arch/llama.py @@ -603,15 +603,18 @@ class Model(LlamaPreTrainedModel): def _apply_sliding_window(self, mask, start_idx, end_idx, window_size): window_size = int(window_size // 2) # ick - for i in range(start_idx, end_idx): + seq_len = mask.size(-1) + for i in range(start_idx, min(end_idx, seq_len)): if not window_size: break - window_left = max(start_idx, i - window_size) - window_right = min(end_idx, i + window_size + 1) + window_start = max(start_idx, i - window_size) + window_end = min(end_idx, i + window_size + 1) - mask[..., i, start_idx:window_left] = 0.0 - mask[..., i, window_right:end_idx] = 0.0 + if window_start > start_idx: + mask[..., i, start_idx:window_start] = 0 + if window_end < end_idx: + mask[..., i, window_end:end_idx] = 0 return mask