fixed errant index error (although it makes me wonder if my segmented masking is still flawed)

This commit is contained in:
mrq 2025-03-21 23:41:34 -05:00
parent d1d91295b3
commit 02a8bcbe29

View File

@ -603,15 +603,18 @@ class Model(LlamaPreTrainedModel):
def _apply_sliding_window(self, mask, start_idx, end_idx, window_size):
window_size = int(window_size // 2) # ick
for i in range(start_idx, end_idx):
seq_len = mask.size(-1)
for i in range(start_idx, min(end_idx, seq_len)):
if not window_size:
break
window_left = max(start_idx, i - window_size)
window_right = min(end_idx, i + window_size + 1)
window_start = max(start_idx, i - window_size)
window_end = min(end_idx, i + window_size + 1)
mask[..., i, start_idx:window_left] = 0.0
mask[..., i, window_right:end_idx] = 0.0
if window_start > start_idx:
mask[..., i, start_idx:window_start] = 0
if window_end < end_idx:
mask[..., i, window_end:end_idx] = 0
return mask