cleaned up parallel nar, i think it's slightly faster but even the smallest model is still slower than ar+nar-len-llama-8...

mrq 2025-03-20 15:56:15 -05:00
parent 9a7458cf17
commit 8068f24e35
2 changed files with 34 additions and 60 deletions


@@ -263,7 +263,7 @@ class AR_NAR_V2(Base_V2):
# fill with masked tokens (even though they get masked anyways)
resps_list = [ torch.ones((seq_len, self.n_resp_levels), dtype=torch.int16, device=device) * self.mask_token for seq_len in len_list ]
# fill scores
scores = [ torch.ones((seq_len), dtype=torch.float32, device=device) for seq_len in len_list ]
scores = [ torch.ones((seq_len, self.n_resp_levels), dtype=torch.float32, device=device) for seq_len in len_list ]
quant_levels = [ level for _ in range(batch_size) ]
null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
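The hunk above widens the per-token confidence scores from a flat (seq_len,) tensor per batch item to (seq_len, self.n_resp_levels), so every RVQ level tracks its own score alongside the already level-wise resps_list. A minimal standalone sketch of that initialization, with made-up sizes (n_resp_levels, mask_token, and len_list are illustrative values, not the model's real configuration):

import torch

n_resp_levels = 8      # illustrative number of RVQ levels sampled in parallel
mask_token    = 1024   # illustrative mask token id
len_list      = [5, 7] # target lengths for a toy batch of two
device        = "cpu"

# sequences start fully masked across every level
resps_list = [
    torch.full((seq_len, n_resp_levels), mask_token, dtype=torch.int16, device=device)
    for seq_len in len_list
]
# one confidence score per (token, level) pair instead of one per token
scores = [
    torch.ones((seq_len, n_resp_levels), dtype=torch.float32, device=device)
    for seq_len in len_list
]
print(resps_list[0].shape, scores[0].shape)  # torch.Size([5, 8]) torch.Size([5, 8])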
@@ -280,10 +280,11 @@ class AR_NAR_V2(Base_V2):
# proportion of tokens to remask
remask_p = 1.0 / (max_steps * 2) if remasking else 0
# pick the worst scoring tokens to mask off
masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=0 ).indices for score, seq_len in zip(scores, len_list) ]
# normal masking
# mask off inputs
resps_list = [ torch.stack([resp[:, l].scatter(0, indices, self.mask_token) for l in range(self.n_resp_levels)], dim=-1) for resp, indices in zip( resps_list, masked_indices ) ]
resps_list = [ torch.stack([resp[:, l].scatter(0, indices.t()[l], self.mask_token) for l in range(self.n_resp_levels)], dim=-1) for resp, indices in zip( resps_list, masked_indices ) ]
# boolean mask
is_masked = [ resps == self.mask_token for resps in resps_list ]
# timestep inputs
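With per-level scores, topk now runs along dim=0 (the sequence dimension) independently for every level, and the resulting (k, n_resp_levels) index tensor is transposed so each level scatters its own set of mask positions. A rough illustration of that indexing with fabricated values (sizes and the score tensor are made up; only the topk/scatter pattern mirrors the code above):

import torch

n_resp_levels = 4
seq_len       = 6
mask_token    = 1024
n_to_mask     = 2  # stand-in for clamp(int(noise_p * seq_len + remask_p * seq_len), 1, seq_len)

# per-(token, level) scores; under the 1 - p convention, higher means less confident
score = torch.rand(seq_len, n_resp_levels)
resp  = torch.randint(0, 1024, (seq_len, n_resp_levels), dtype=torch.int16)

# worst-scoring positions per level; indices has shape (n_to_mask, n_resp_levels)
indices = score.topk(n_to_mask, dim=0).indices

# mask each level's own worst positions; indices.t()[l] holds level l's rows
masked = torch.stack(
    [resp[:, l].scatter(0, indices.t()[l], mask_token) for l in range(n_resp_levels)],
    dim=-1,
)
print((masked == mask_token).sum(dim=0))  # n_to_mask masked entries per level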
@@ -327,53 +328,20 @@ class AR_NAR_V2(Base_V2):
logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=[ l for l in len_list ] )
l_scores = []
l_resps_list = []
# cringe hack because we're able to sample multiple levels at once
for l in range(self.n_resp_levels):
# sample with sampler settings
filtered_sampled = super().sample(
logits=[ logit[l] for logit in logits ],
prev_list=[ resp[..., l] for resp in prev_list ],
quant_levels=quant_levels,
# sample with sampler settings
sampled = super().sample(
logits=logits,
prev_list=resps_list,
quant_levels=quant_levels,
temperature=sampling_temperature,
**sampling_kwargs,
)
temperature=sampling_temperature,
**sampling_kwargs,
)
# retrieves unfiltered logits
unfiltered_sampled = super().sample(
logits=[ logit[l] for logit in logits ],
prev_list=[ resp[..., l] for resp in prev_list ],
quant_levels=quant_levels,
temperature=0.0,
**sampling_kwargs,
)
# get sampled tokens
sampled_ids = filtered_sampled.ids
# keep unmasked tokens
l_resps_list.append([ torch.where( masked[..., l], input_ids, resps[..., l] ).to(torch.int16) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ])
# get probability scores
l_scores.append([
# conjugate to have worse scoring tokens picked for topk
1.0 -
# only keep scores of tokens we are predicting (and ignore the tokens previously finalized)
torch.where( masked[..., l], torch.tensor([score for index, score in enumerate(scores)], device=device), torch.ones(masked[..., l].shape, device=device) )
# use unmodified logit scores for this, as it offers better stability
for scores, masked in zip( unfiltered_sampled.scores, is_masked )
])
resps_list = []
scores = []
for batch_index in range(batch_size):
score = sum([ l_scores[level][batch_index] for level in range(self.n_resp_levels) ]) / self.n_resp_levels
resp = torch.stack([ l_resps_list[level][batch_index] for level in range(self.n_resp_levels) ], dim=-1)
scores.append( score )
resps_list.append( resp )
# update resps, filling in the masked tokens with the new tokens
resps_list = [ torch.where( masked, ids.t(), resps ).to(torch.int16) for masked, ids, resps in zip( is_masked, sampled.ids, resps_list ) ]
# update scores, filling in the scores for the masked tokens (stored as 1 - p so the least-confident tokens get remasked by topk)
scores = [ 1.0 - torch.where( masked, scores.t(), 1 ) for masked, scores in zip( is_masked, sampled.scores ) ]
return resps_list
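The net effect of this hunk: the per-level Python loop around super().sample() collapses into one call whose ids and scores come back shaped (n_resp_levels, seq_len) per batch item, so a transpose lines them up with the (seq_len, n_resp_levels) response tensors, torch.where leaves already-finalized tokens untouched, and storing 1 - p keeps the next step's topk pointed at the least-confident positions. A small standalone sketch of that update for one batch item (the sampled ids and scores here are fabricated stand-ins for what super().sample() would return):

import torch

n_resp_levels, seq_len, mask_token = 4, 6, 1024

resps = torch.randint(0, 1024, (seq_len, n_resp_levels), dtype=torch.int16)
resps[:2] = mask_token          # pretend the first two timesteps are still masked
is_masked = resps == mask_token

# stand-ins for one element of sampled.ids / sampled.scores, shaped (levels, seq_len)
sampled_ids    = torch.randint(0, 1024, (n_resp_levels, seq_len))
sampled_scores = torch.rand(n_resp_levels, seq_len)

# fill masked slots with freshly sampled tokens, keep finalized tokens as-is
resps = torch.where(is_masked, sampled_ids.t(), resps).to(torch.int16)

# store 1 - p so topk() picks the least-confident tokens to remask next step;
# finalized slots get 1 - 1 = 0 and are never re-picked
scores = 1.0 - torch.where(is_masked, sampled_scores.t(), torch.ones_like(sampled_scores.t()))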


@@ -1236,8 +1236,9 @@ class Base_V2(nn.Module):
def sample(
self,
logits: list[Tensor], # logit scores
prev_list: list[Tensor] | None = None, # logit scores
logits: Tensor, # logit scores
prev_list: Tensor | None = None,
len_list: Tensor | None = None,
**sampling_kwargs,
):
# yikes
@@ -1265,6 +1266,7 @@ class Base_V2(nn.Module):
attentions = sampling_kwargs.get("attentions", None)
batch_size = len( logits )
device = logits[0].device
if min_temperature < 0:
min_temperature = temperature
@@ -1273,14 +1275,16 @@ class Base_V2(nn.Module):
entropy = None
if prev_list is not None:
seq_lens = map(len, prev_list)
logits = [ logit[-l:] for logit, l in zip(logits, seq_lens) ]
# (AR chunkwise) return the last chunkwise piece
seq_lens = [ prev.shape[0] for prev in prev_list ]
elif len_list is not None:
seq_lens = len_list
elif self.causal:
seq_lens = [ logit.shape[0] - self.causal_size for logit in logits ]
logits = [ logit[-self.causal_size:] for logit in logits ]
seq_lens = [ self.causal_size for _ in range( batch_size) ]
logits = [ logit[..., -l:, :] for l, logit in zip(seq_lens, logits) ]
# perform min_p filtering of our logits
"""
if min_p > 0.0:
logits = [ min_p_filtering(logit, min_p=min_p) for logit in logits ]
@@ -1291,17 +1295,19 @@ class Base_V2(nn.Module):
# do top-no logit processing
if top_no > 0.0:
logits = [ top_no_logits_processing(logit) for logit in logits ]
"""
probabilities = [ F.softmax(logit, dim=-1) for logit in logits ]
scores = [ torch.max(prob, -1)[0] for prob in probabilities ]
# argmax instead
if temperature <= 0.0:
res = [ logit.argmax(dim=-1) for logit in logits ]
res = [ prob.argmax(dim=-1) for prob in probabilities]
else:
res = [ Categorical(logits=logit / temperature).sample() for logit in logits ]
# calculate token probabilities
scores = [
[ F.softmax(logit[i, :], dim=-1)[token].item() for i, token in enumerate(tokens) ]
for logit, tokens in zip(logits, res)
torch.tensor([ [ prob[b, i, token].item() for i, token in enumerate(tokens[b]) ] for b in range(prob.size(0)) ], device=device)
for prob, tokens in zip(probabilities, res)
]
return Sampled(res, logits, scores, entropy)
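With the reworked sample(), each batch item's logits appear to carry a leading level dimension, so the softmax, argmax, and probability lookup run over every RVQ level at once; the nested loop above just reads prob[b, i, token] for each sampled token. A hedged sketch of the same lookup done with torch.gather for a single (levels, seq_len, vocab) logit tensor (shapes are illustrative assumptions, and this is an equivalent formulation rather than the code the commit uses):

import torch
import torch.nn.functional as F

n_resp_levels, seq_len, vocab = 4, 6, 1025

logit = torch.randn(n_resp_levels, seq_len, vocab)
prob  = F.softmax(logit, dim=-1)

# greedy path (temperature <= 0.0): most likely token per (level, position)
ids = prob.argmax(dim=-1)                                # (levels, seq_len)

# probability the model assigned to each chosen token, matching the nested loop above
scores = prob.gather(-1, ids.unsqueeze(-1)).squeeze(-1)  # (levels, seq_len)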