cleaned up parallel NAR; I think it's slightly faster, but even the smallest model is still slower than ar+nar-len-llama-8...
commit 8068f24e35
parent 9a7458cf17
@@ -263,7 +263,7 @@ class AR_NAR_V2(Base_V2):
         # fill with masked tokens (even though they get masked anyways)
         resps_list = [ torch.ones((seq_len, self.n_resp_levels), dtype=torch.int16, device=device) * self.mask_token for seq_len in len_list ]
         # fill scores
-        scores = [ torch.ones((seq_len), dtype=torch.float32, device=device) for seq_len in len_list ]
+        scores = [ torch.ones((seq_len, self.n_resp_levels), dtype=torch.float32, device=device) for seq_len in len_list ]
 
         quant_levels = [ level for _ in range(batch_size) ]
         null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
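Note on the first change: `scores` now holds one confidence value per (position, RVQ level) pair instead of one per position, which is what lets the later top-k and scatter steps operate on all levels at once. A minimal standalone sketch of the shapes involved; the sizes, `mask_token` id, and lengths below are made-up stand-ins for the model's own attributes:

```python
import torch

# hypothetical stand-ins for the model's attributes, for illustration only
n_resp_levels = 8          # number of RVQ levels decoded in parallel
mask_token    = 1024       # id marking "not yet decoded" positions
len_list      = [3, 5]     # per-utterance target lengths
device        = "cpu"

# every position starts fully masked across every level
resps_list = [ torch.full((seq_len, n_resp_levels), mask_token, dtype=torch.int16, device=device)
               for seq_len in len_list ]
# one confidence score per (position, level) pair instead of per position
scores = [ torch.ones((seq_len, n_resp_levels), dtype=torch.float32, device=device)
           for seq_len in len_list ]

print(resps_list[0].shape, scores[0].shape)  # torch.Size([3, 8]) torch.Size([3, 8])
```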
@@ -280,10 +280,11 @@ class AR_NAR_V2(Base_V2):
             # proportion of tokens to remask
             remask_p = 1.0 / (max_steps * 2) if remasking else 0
             # pick the worst scoring tokens to mask off
-            masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
+            masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=0 ).indices for score, seq_len in zip(scores, len_list) ]
 
+            # normal masking
             # mask off inputs
-            resps_list = [ torch.stack([resp[:, l].scatter(0, indices, self.mask_token) for l in range(self.n_resp_levels)], dim=-1) for resp, indices in zip( resps_list, masked_indices ) ]
+            resps_list = [ torch.stack([resp[:, l].scatter(0, indices.t()[l], self.mask_token) for l in range(self.n_resp_levels)], dim=-1) for resp, indices in zip( resps_list, masked_indices ) ]
             # boolean mask
             is_masked = [ resps == self.mask_token for resps in resps_list ]
             # timestep inputs
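The `dim=-1` → `dim=0` switch matters because each score tensor is now 2-D: top-k along dim 0 picks the k worst positions independently for every level, and the transposed index tensor is then scattered column by column. Scores are stored as 1 − confidence (see the later update), so the largest entries are the least confident tokens. A rough, self-contained illustration of that indexing with toy shapes (not the repo's code):

```python
import torch

seq_len, n_resp_levels, mask_token = 6, 4, 1024   # toy sizes, hypothetical mask id
score = torch.rand(seq_len, n_resp_levels)        # stored as 1 - confidence, larger = worse
resp  = torch.randint(0, 1024, (seq_len, n_resp_levels), dtype=torch.int16)

k = 3                                             # how many positions to re-mask this step
# topk over dim=0: the k worst positions, chosen independently per level -> (k, n_resp_levels)
indices = score.topk(k, dim=0).indices

# scatter the mask token into each level's column using that level's own indices
resp = torch.stack(
    [ resp[:, l].scatter(0, indices.t()[l], mask_token) for l in range(n_resp_levels) ],
    dim=-1,
)
is_masked = resp == mask_token
print(is_masked.sum().item())  # k * n_resp_levels positions re-masked
```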
@@ -327,53 +328,20 @@ class AR_NAR_V2(Base_V2):
 
             logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=[ l for l in len_list ] )
 
-            l_scores = []
-            l_resps_list = []
-            # cringe hack because we're able to sample multiple levels at once
-            for l in range(self.n_resp_levels):
-                # sample with sampler settings
-                filtered_sampled = super().sample(
-                    logits=[ logit[l] for logit in logits ],
-                    prev_list=[ resp[..., l] for resp in prev_list ],
-                    quant_levels=quant_levels,
+            # sample with sampler settings
+            sampled = super().sample(
+                logits=logits,
+                prev_list=resps_list,
+                quant_levels=quant_levels,
 
-                    temperature=sampling_temperature,
-                    **sampling_kwargs,
-                )
+                temperature=sampling_temperature,
+                **sampling_kwargs,
+            )
 
-                # retrieves unfiltered logits
-                unfiltered_sampled = super().sample(
-                    logits=[ logit[l] for logit in logits ],
-                    prev_list=[ resp[..., l] for resp in prev_list ],
-                    quant_levels=quant_levels,
-
-                    temperature=0.0,
-                    **sampling_kwargs,
-                )
-
-                # get sampled tokens
-                sampled_ids = filtered_sampled.ids
-                # keep unmasked tokens
-                l_resps_list.append([ torch.where( masked[..., l], input_ids, resps[..., l] ).to(torch.int16) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ])
-                # get probability scores
-                l_scores.append([
-                    # conjugate to have worse scoring tokens picked for topk
-                    1.0 -
-                        # only keep scores of tokens we are predicting (and ignore the tokens previously finalized)
-                        torch.where( masked[..., l], torch.tensor([score for index, score in enumerate(scores)], device=device), torch.ones(masked[..., l].shape, device=device) )
-                    # use unmodified logit scores for this, as it offers better stability
-                    for scores, masked in zip( unfiltered_sampled.scores, is_masked )
-                ])
-
-            resps_list = []
-            scores = []
-
-            for batch_index in range(batch_size):
-                score = sum([ l_scores[level][batch_index] for level in range(self.n_resp_levels) ]) / self.n_resp_levels
-                resp = torch.stack([ l_resps_list[level][batch_index] for level in range(self.n_resp_levels) ], dim=-1)
-
-                scores.append( score )
-                resps_list.append( resp )
+            # update resps, filling in the masked tokens with the new tokens
+            resps_list = [ torch.where( masked, ids.t(), resps ).to(torch.int16) for masked, ids, resps in zip( is_masked, sampled.ids, resps_list ) ]
+            # update scores, filling in the
+            scores = [ 1.0 - torch.where( masked, scores.t(), 1 ) for masked, scores in zip( is_masked, sampled.scores ) ]
 
         return resps_list
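The removed per-level loop called `super().sample()` twice per level (once filtered, once at temperature 0 for scores); the replacement samples every level in a single call and merges the result back with `torch.where`, leaving already-finalized tokens untouched. A small sketch of that merge for one batch element, assuming `sampled.ids` / `sampled.scores` come back level-major (which is what the `.t()` suggests); the tensors are toy stand-ins:

```python
import torch

seq_len, n_resp_levels, mask_token = 6, 4, 1024    # toy sizes, hypothetical mask id
resps     = torch.randint(0, 1024, (seq_len, n_resp_levels), dtype=torch.int16)
is_masked = torch.rand(seq_len, n_resp_levels) < 0.5           # slots masked this step

# stand-ins for one element of sampled.ids / sampled.scores, level-major: (n_resp_levels, seq_len)
sampled_ids    = torch.randint(0, 1024, (n_resp_levels, seq_len), dtype=torch.int16)
sampled_scores = torch.rand(n_resp_levels, seq_len)

# only the masked positions take the freshly sampled tokens; finalized tokens are kept
resps  = torch.where(is_masked, sampled_ids.t(), resps).to(torch.int16)
# store 1 - confidence so the next step's topk re-masks the least confident tokens first
scores = 1.0 - torch.where(is_masked, sampled_scores.t(), torch.ones_like(sampled_scores.t()))
```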
@@ -1236,8 +1236,9 @@ class Base_V2(nn.Module):
     def sample(
         self,
-        logits: list[Tensor], # logit scores
-        prev_list: list[Tensor] | None = None, # logit scores
+        logits: Tensor, # logit scores
+        prev_list: Tensor | None = None,
+        len_list: Tensor | None = None,
         **sampling_kwargs,
     ):
         # yikes
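With this signature change, `sample()` takes one stacked logits tensor per batch element rather than a list of per-level tensors, plus an optional `len_list` for when there is no `prev_list` to measure. A rough sketch of how a caller's inputs could line up with the new parameters; the shapes are assumptions inferred from the surrounding hunks, not a documented contract:

```python
import torch

# assumed shapes, inferred from the diff rather than documented:
#   logits    : one tensor per batch element, roughly (n_resp_levels, seq_len, vocab)
#   prev_list : the current resps per batch element, (seq_len, n_resp_levels)
#   len_list  : plain target lengths, used when there is no prev_list to measure
n_resp_levels, seq_len, vocab = 4, 6, 1025

logits    = [ torch.randn(n_resp_levels, seq_len, vocab) ]      # batch of one
prev_list = [ torch.zeros(seq_len, n_resp_levels, dtype=torch.int16) ]
len_list  = [ seq_len ]

# the new body derives sequence lengths from whichever hint is present
seq_lens = [ prev.shape[0] for prev in prev_list ] if prev_list is not None else len_list
# and trims each logit tensor to its last seq_len steps along the time axis
logits = [ logit[..., -l:, :] for l, logit in zip(seq_lens, logits) ]
```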
@@ -1265,6 +1266,7 @@ class Base_V2(nn.Module):
         attentions = sampling_kwargs.get("attentions", None)
 
         batch_size = len( logits )
+        device = logits[0].device
 
         if min_temperature < 0:
             min_temperature = temperature
@@ -1273,14 +1275,16 @@ class Base_V2(nn.Module):
         entropy = None
 
         if prev_list is not None:
-            seq_lens = map(len, prev_list)
-            logits = [ logit[-l:] for logit, l in zip(logits, seq_lens) ]
         # (AR chunkwise) return the last chunkwise piece
+            seq_lens = [ prev.shape[0] for prev in prev_list ]
+        elif len_list is not None:
+            seq_lens = len_list
         elif self.causal:
-            seq_lens = [ logit.shape[0] - self.causal_size for logit in logits ]
-            logits = [ logit[-self.causal_size:] for logit in logits ]
+            seq_lens = [ self.causal_size for _ in range( batch_size) ]
 
+        logits = [ logit[..., -l:, :] for l, logit in zip(seq_lens, logits) ]
+
         # perform min_p filtering of our logits
+        """
         if min_p > 0.0:
             logits = [ min_p_filtering(logit, min_p=min_p) for logit in logits ]
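The min-p / top-k / top-no processing is fenced off inside a string literal here, i.e. effectively commented out for the parallel path. For reference, min-p filtering in general keeps only tokens whose probability is at least `min_p` times that of the most likely token; the sketch below is that generic idea, not the repo's `min_p_filtering` helper, which may differ:

```python
import torch
import torch.nn.functional as F

def min_p_filter(logits: torch.Tensor, min_p: float, filter_value: float = -float("inf")) -> torch.Tensor:
    """Mask out tokens whose probability falls below min_p * max_probability."""
    probs = F.softmax(logits, dim=-1)
    threshold = probs.max(dim=-1, keepdim=True).values * min_p
    return logits.masked_fill(probs < threshold, filter_value)

# toy usage: (seq_len, vocab) logits
logits = torch.randn(6, 1025)
filtered = min_p_filter(logits, min_p=0.1)
```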
@@ -1291,17 +1295,19 @@ class Base_V2(nn.Module):
         # do top-no logit processing
         if top_no > 0.0:
             logits = [ top_no_logits_processing(logit) for logit in logits ]
+        """
 
+        probabilities = [ F.softmax(logit, dim=-1) for logit in logits ]
+        scores = [ torch.max(prob, -1)[0] for prob in probabilities ]
+
         # argmax instead
         if temperature <= 0.0:
-            res = [ logit.argmax(dim=-1) for logit in logits ]
+            res = [ prob.argmax(dim=-1) for prob in probabilities]
         else:
             res = [ Categorical(logits=logit / temperature).sample() for logit in logits ]
 
         # calculate token probabilities
         scores = [
-            [ F.softmax(logit[i, :], dim=-1)[token].item() for i, token in enumerate(tokens) ]
-            for logit, tokens in zip(logits, res)
+            torch.tensor([ [ prob[b, i, token].item() for i, token in enumerate(tokens[b]) ] for b in range(prob.size(0)) ], device=device)
+            for prob, tokens in zip(probabilities, res)
         ]
 
         return Sampled(res, logits, scores, entropy)
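The rewritten scoring path softmaxes each logit tensor once, then reads off the probability of every chosen token with a nested comprehension over rows. A vectorized `gather` performs the same computation; the snippet below checks the equivalence on toy tensors (illustration only, not the repo's code):

```python
import torch
import torch.nn.functional as F

n_resp_levels, seq_len, vocab = 4, 6, 1025            # toy sizes
logit = torch.randn(n_resp_levels, seq_len, vocab)    # one batch element, level-major

prob   = F.softmax(logit, dim=-1)
tokens = prob.argmax(dim=-1)                          # (n_resp_levels, seq_len), greedy pick

# probability assigned to each chosen token, equivalent to the per-row loop in the diff
token_prob = prob.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)   # (n_resp_levels, seq_len)

# reference: the loop form used in the diff
loop_prob = torch.tensor([[ prob[b, i, t].item() for i, t in enumerate(tokens[b]) ]
                          for b in range(prob.size(0)) ])
assert torch.allclose(token_prob, loop_prob)
```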