fixed errant index error (although it makes me wonder if my segmented masking is still flawed)
This commit is contained in:
parent
d1d91295b3
commit
02a8bcbe29
|
@ -603,15 +603,18 @@ class Model(LlamaPreTrainedModel):
|
|||
def _apply_sliding_window(self, mask, start_idx, end_idx, window_size):
|
||||
window_size = int(window_size // 2) # ick
|
||||
|
||||
for i in range(start_idx, end_idx):
|
||||
seq_len = mask.size(-1)
|
||||
for i in range(start_idx, min(end_idx, seq_len)):
|
||||
if not window_size:
|
||||
break
|
||||
|
||||
window_left = max(start_idx, i - window_size)
|
||||
window_right = min(end_idx, i + window_size + 1)
|
||||
window_start = max(start_idx, i - window_size)
|
||||
window_end = min(end_idx, i + window_size + 1)
|
||||
|
||||
mask[..., i, start_idx:window_left] = 0.0
|
||||
mask[..., i, window_right:end_idx] = 0.0
|
||||
if window_start > start_idx:
|
||||
mask[..., i, start_idx:window_start] = 0
|
||||
if window_end < end_idx:
|
||||
mask[..., i, window_end:end_idx] = 0
|
||||
|
||||
return mask
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user