From 84005c5b001d05c0628686278ba4df43082a489c Mon Sep 17 00:00:00 2001
From: mrq
Date: Sun, 13 Oct 2024 12:01:12 -0500
Subject: [PATCH] entropix apparently processes the entire sequence of logits but it falls apart when doing that

---
 vall_e/models/arch/llama.py |  4 ++++
 vall_e/models/base.py       | 38 +++++++++++++++++++-----------------
 vall_e/samplers.py          | 39 ++++++++++++++++++-------------------
 3 files changed, 43 insertions(+), 38 deletions(-)

diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index 0c4c056..7a587ae 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -292,6 +292,10 @@ class LlamaAttention_Adapted(LlamaAttention):
 			is_causal=is_causal,
 		)

+		# cringe
+		if attn_scores is None and output_attentions:
+			attn_scores = attn_output
+
 		attn_output = attn_output.transpose(1, 2).contiguous()
 		attn_output = attn_output.view(bsz, q_len, -1)

diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 778e208..cb08f2e 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -1498,6 +1498,26 @@ class Base(nn.Module):
 		scores = None
 		entropy = None

+		# (AR) entropix sampling
+		# we do it before everything to retain logits for the entire sequence (even though it's still better to pass only the last token)
+		if attentions is not None and quant_levels is None:
+			# move to CPU for speedups
+			seq_lens = [ logit.shape[0] for logit in logits ]
+			logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
+			attentions = torch.stack(attentions, dim=1).to(device="cpu") # ( batch, layer, heads, seq_len, seq_len )
+
+			res = [ sample_entropix(
+				logit[:seq_lens[batch], :], # ( seq_len, vocab )
+				attentions[batch, :, :, :seq_lens[batch], :seq_lens[batch]], # (layer, heads, seq_len, seq_len )
+				temperature,
+				top_k,
+				top_p,
+				min_p,
+			) for batch, logit in enumerate(logits) ]
+
+			if res:
+				return Sampled([ r[0] for r in res ], scores, [ r[1] for r in res ])
+
 		# (NAR) return the entire generated response
 		# Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
 		if quant_levels is not None: # and "nar" in self.capabilities: # for when I get around to coping about dropping the NAR entirely
@@ -1530,24 +1550,6 @@ class Base(nn.Module):
 		if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0:
 			logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]

-		# (AR) entropix sampling
-		# we do it after the penalizers because entropix's internal sampling doesn't account for them (but does do top_k/top_p/min_p)
-		if attentions is not None and quant_levels is None:
-			# move to CPU for speedups
-			logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
-
-			res = [ sample_entropix(
-				logit,
-				torch.stack(attentions, dim=1)[batch, :, :, :seq_lens[batch], :seq_lens[batch]], # (layer, heads, seq_len, ? ), our attention scores might be padded
-				temperature,
-				top_k,
-				top_p,
-				min_p,
-			) for batch, logit in enumerate(logits) ]
-
-			if res:
-				return Sampled([ r[0] for r in res], scores, [ r[1] for r in res])
-
 		# perform min_p filtering of our logits
 		if min_p > 0.0:
 			logits = [ min_p_filtering(logit, min_p=min_p) for logit in logits ]
diff --git a/vall_e/samplers.py b/vall_e/samplers.py
index f6a40ab..419325a 100644
--- a/vall_e/samplers.py
+++ b/vall_e/samplers.py
@@ -243,12 +243,12 @@ def calculate_entropix_metrics( logits, attentions=None, dim=-1, use_stats=False
 		"logits_varentropy": torch.mean(varentropy).item(),
 	}

-	last_attention_scores = attentions[-1].unsqueeze(0) # ( bsz, heads, seq_len, ? )
+	last_attention_scores = attentions[-1].unsqueeze(0) # ( bsz, heads, seq_len, seq_len )
 	attention_probs = F.softmax(last_attention_scores, dim=-1)
 	if use_stats:
 		attn_stats = AttnStats.new( 1, attentions.shape[0], attentions.shape[1], logits.device )
 		for idx, attn in enumerate( attentions ):
-			attn_stats.update( attn.unsqueeze(0)[:, :, -1, :], idx ) # (bsz, heads, last_token, ?)
+			attn_stats.update( attn.unsqueeze(0)[:, :, -1, :], idx ) # (bsz, heads, last_token, seq_len)
 		attn_entropy = attn_stats.entropy
 		attn_varentropy = attn_stats.varentropy
 	else:
@@ -382,18 +382,20 @@ def _sample_entropix(
 	if top_k == 0:
 		top_k = logits.shape[-1]

+	logit = logits[-1, :]
+
 	temperature = clamp( float(temperature), cfg.temperature_min, cfg.temperature_max )
 	top_p = clamp( float(top_p), cfg.top_p_min, cfg.top_p_max )
 	top_k = clamp( int(top_k), cfg.top_k_min, cfg.top_k_max )
 	min_p = clamp( float(min_p), cfg.min_p_min, cfg.min_p_max )

-	probs = F.softmax(logits / temperature, dim=-1)
+	probs = F.softmax(logit / temperature, dim=-1)

 	# Apply min_p sampling
 	if min_p > 0.0:
 		p_max = float(torch.max(probs, dim=-1, keepdim=True).values)
 		indices_to_remove = probs < (min_p * p_max)
-		logits = torch.where(indices_to_remove, torch.full_like(logits, float('-inf')), logits)
+		logit = torch.where(indices_to_remove, torch.full_like(logit, float('-inf')), logit)

 	# Apply top-k sampling
 	top_k_probs, top_k_indices = torch.topk(probs, k=min(top_k, probs.shape[-1]))
@@ -401,21 +403,16 @@
 	probs_idx = torch.flip(top_k_indices, dims=[-1])
 	probs_sum = torch.cumsum(probs_sort, dim=-1)
 	# Apply top-p sampling
-	mask = torch.where(probs_sum - probs_sort > top_p, torch.tensor(1.0, device=logits.device), torch.tensor(0.0, device=logits.device))
+	mask = torch.where(probs_sum - probs_sort > top_p, torch.tensor(1.0, device=logit.device), torch.tensor(0.0, device=logit.device))
 	probs_sort = probs_sort * (1 - mask)
 	probs_sort = probs_sort / torch.sum(probs_sort, dim=-1, keepdim=True)

 	q = Exponential.sample(probs_sort.shape)
-	"""
-	next_token = torch.argmax(probs_sort / q, dim=-1, keepdim=True)
-	return torch.take_along_dim(probs_idx, next_token, dim=-1)[0]
-	"""
-
 	"""
 	# q = torch.rand(probs_sort.shape, generator=generator, device=probs_sort.device)
 	"""
 	next_token = torch.argmax(probs_sort / q, dim=-1, keepdim=True)
-	next_token_g = torch.take_along_dim(probs_idx, next_token, dim=-1)[0]
+	next_token_g = torch.take_along_dim(probs_idx, next_token, dim=-1)
 	return next_token_g

 def sample_entropix(
@@ -433,14 +430,9 @@
 	top_p = cfg.top_p
 	"""

-	"""
-	if generator is None:
-		generator = torch.Generator(device=logits.device).manual_seed(int(time.time()))
-	"""
-
-	# logits: ( bsz, vocab )
-	# attentions: ( bsz, layer, heads, seq_len, ? )
-	metrics = calculate_entropix_metrics( logits, attentions )
+	# logits: ( seq_len, vocab )
+	# attentions: ( layer, heads, seq_len, seq_len )
+	metrics = calculate_entropix_metrics( logits[-1:, :], attentions[:, :, -1:, :] )

 	ent, vent = metrics["logits_entropy"], metrics["logits_varentropy"]
 	attn_ent, attn_vent = metrics["attn_entropy"], metrics["attn_varentropy"]
@@ -450,7 +442,7 @@
 	# Low Entropy, Low Varentropy: "flowing with unspoken intent"
 	if ent < cfg.low_ent_thresh and vent < cfg.low_vent_thresh:
 		metrics["action"] = 0
-		res = logits.argmax(dim=1)
+		res = logits[-1, :].argmax(dim=1)
 	# High Entropy, Low Varentropy: "treading carefully, asking clarifying questions"
 	elif ent > cfg.high_ent_thresh and vent < cfg.low_vent_thresh:
 		metrics["action"] = 1
@@ -510,6 +502,13 @@

 	res = samples[best_sample_idx]

+	"""
+	metrics = {
+		"attn_entropy": metrics["attn_entropy"],
+		"attn_varentropy": metrics["attn_varentropy"],
+	}
+	"""
+
 	"""
 	metrics["temperature"] = temperature
 	metrics["top_k"] = top_k
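
For reference, the calling contract the base.py hunk above establishes: sample_entropix() now receives the unpadded logits for the whole sequence, ( seq_len, vocab ), plus the per-sample stack of attention scores, ( layer, heads, seq_len, seq_len ), and reduces both to the last position internally before computing its metrics. A minimal sketch of that contract with dummy tensors; everything here except sample_entropix and the vall_e.samplers module path is a made-up stand-in rather than code from this repo:

	import torch
	from vall_e.samplers import sample_entropix

	seq_len, vocab, n_layers, n_heads = 124, 1025, 12, 16          # arbitrary toy sizes
	logits = torch.randn(seq_len, vocab)                           # ( seq_len, vocab ), already sliced to the unpadded length
	attentions = torch.randn(n_layers, n_heads, seq_len, seq_len)  # ( layer, heads, seq_len, seq_len )

	# temperature / top_k / top_p / min_p are passed positionally, as in the base.py hunk;
	# per that hunk, res[0] is the sampled token and res[1] is the metrics dict forwarded into Sampled(...)
	res = sample_entropix( logits, attentions, 1.0, 0, 1.0, 0.0 )
	token, metrics = res[0], res[1]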
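
The action branches above key off the entropy and varentropy of the last position's logits (and of the last attention row). In the entropix sense these are the mean and the variance of the surprisal; a self-contained version of that standard definition, which calculate_entropix_metrics is assumed to follow (possibly up to the choice of log base):

	import math
	import torch
	import torch.nn.functional as F

	logits = torch.randn(1, 1025)                        # last-position logits, ( 1, vocab ), dummy values
	log_probs = F.log_softmax(logits, dim=-1)
	probs = log_probs.exp()

	surprisal = -log_probs / math.log(2)                 # per-candidate surprisal, in bits
	entropy = (probs * surprisal).sum(dim=-1)            # mean surprisal
	varentropy = (probs * (surprisal - entropy.unsqueeze(-1)) ** 2).sum(dim=-1)  # variance of surprisal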
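
_sample_entropix draws the token with entropix's exponential-race trick: divide the filtered, renormalized probabilities by i.i.d. Exponential(1) noise and take the argmax, which picks index i with probability proportional to p_i, i.e. an ordinary categorical draw. A quick check of that equivalence (building Exponential(1.0) explicitly here is an assumption about what samplers.py's Exponential object is; the toy distribution is made up):

	import torch

	Exponential = torch.distributions.Exponential(torch.tensor(1.0))
	probs = torch.tensor([0.6, 0.3, 0.1])

	q = Exponential.sample((50_000,) + probs.shape)      # i.i.d. Exp(1) noise, one per draw and per candidate
	draws = torch.argmax(probs / q, dim=-1)              # equivalently argmin of q_i / p_i: the fastest "clock" wins
	print(torch.bincount(draws, minlength=probs.numel()) / draws.numel())  # approaches tensor([0.6, 0.3, 0.1])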