diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index d7ec30b..ea7807f 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -135,7 +135,7 @@ class LlamaAttention_Adapted(LlamaAttention):
 			self.mode = kwargs['mode']
 			kwargs.pop("mode")
 		else:
-			self.mode = "math"
+			self.mode = "sdpa"
 
 		if self.mode == "math":
 			self.mode = torch.nn.attention.SDPBackend.MATH
@@ -145,8 +145,6 @@ class LlamaAttention_Adapted(LlamaAttention):
 			self.mode = torch.nn.attention.SDPBackend.FLASH_ATTENTION
 		elif self.mode == "cudnn":
 			self.mode = torch.nn.attention.SDPBackend.CUDNN_ATTENTION
-		else:
-			self.mode = None
 
 		super().__init__(*args, **kwargs)
 
@@ -330,6 +328,8 @@ class LlamaAttention_Adapted(LlamaAttention):
 		key_states = repeat_kv(key_states, self.num_key_value_groups)
 		value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+		# to-do: actually find what our attention scores are, since these seem to not vary at all
 		attn_scores = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if output_attentions else None
 
 		causal_mask = attention_mask
 
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index fa2ce42..0780e26 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -1502,9 +1502,11 @@ class Base(nn.Module):
 		# (NAR) return the entire generated response
 		# Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
 		if quant_levels is not None: # and "nar" in self.capabilities: # for when I get around to coping about dropping the NAR entirely
-			logits = [ logit[-l:] for logit, l in zip(logits, map(len, prev_list)) ]
+			seq_lens = [ len(prev) for prev in prev_list ] # must be a list, not map(): it gets indexed per-batch below
+			logits = [ logit[-l:] for logit, l in zip(logits, seq_lens) ]
 		# (AR chunkwise) return the last chunkwise piece
 		elif self.causal:
+			seq_lens = [ logit.shape[0] - self.causal_size for logit in logits ]
 			logits = [ logit[-self.causal_size:] for logit in logits ]
 
 		# (NAR) disable stop token
@@ -1537,12 +1539,12 @@ class Base(nn.Module):
 			res = [ sample_entropix(
 				logit,
-				attentions[-1], # original code just uses the last attention scores
+				torch.stack(attentions, dim=1)[batch, :, :, :seq_lens[batch], :seq_lens[batch]], # ( layer, heads, seq_len, ? ), sliced since our attention scores might be padded
 				temperature,
 				top_k,
 				top_p,
 				min_p,
-			) for logit in logits ]
+			) for batch, logit in enumerate(logits) ]
 
 			if res:
 				return Sampled([ r[0] for r in res], scores, [ r[1] for r in res])
 
diff --git a/vall_e/samplers.py b/vall_e/samplers.py
index ece2892..f6a40ab 100644
--- a/vall_e/samplers.py
+++ b/vall_e/samplers.py
@@ -2,6 +2,7 @@
 import math
 import torch
 import torch.nn.functional as F
 import numpy as np
+import time
 
 from torch import Tensor, einsum, nn
 
@@ -229,38 +230,88 @@ def dry_sampling( logits, previous=None, factor=0.0, base=1.75, allowed_length=2
 LN_2 = 0.69314718056 # ln(2) = 1.0 / LOG2_E
 
 # Grabbed from https://github.com/xjdr-alt/entropix/blob/main/entropix/sampler.py
-# Right now I only care about quantifying these two, I'll figure out how to best apply this to the model
-def calculate_entropix_metrics( logits, attention_scores=None, dim=-1 ):
+def calculate_entropix_metrics( logits, attentions=None, dim=-1, use_stats=False ):
 	"""Calculate the entropy and varentropy of the probability distribution using logsoftmax."""
-	log_probs = torch.nn.functional.log_softmax(logits, dim=dim)
+	log_probs = F.log_softmax(logits, dim=dim)
 	probs = torch.exp(log_probs)
 	entropy = -torch.sum(probs * log_probs, dim=dim) / LN_2 # Convert to base-2
-	varentropy = torch.sum(probs * (log_probs / LN_2 + entropy[..., None])**2, dim=dim)
+	varentropy = torch.sum(probs * (log_probs / LN_2 + entropy.unsqueeze(-1))**2, dim=dim)
 
-	if attention_scores is None:
+	if attentions is None:
 		return {
 			"logits_entropy": torch.mean(entropy).item(),
 			"logits_varentropy": torch.mean(varentropy).item(),
 		}
 
-	attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
-	attn_entropy = -torch.sum(attention_probs * torch.log2(torch.clip(attention_probs, 1e-10, 1.0)), dim=-1)
-	attn_varentropy = torch.var(attn_entropy, dim=1)
+	last_attention_scores = attentions[-1].unsqueeze(0) # ( bsz, heads, seq_len, ? )
+	attention_probs = F.softmax(last_attention_scores, dim=-1)
+	if use_stats:
+		attn_stats = AttnStats.new( 1, attentions.shape[0], attentions.shape[1], logits.device )
+		for idx, attn in enumerate( attentions ):
+			attn_stats.update( attn.unsqueeze(0)[:, :, -1, :], idx ) # (bsz, heads, last_token, ?)
+		attn_entropy = attn_stats.entropy
+		attn_varentropy = attn_stats.varentropy
+	else:
+		attn_entropy = -torch.sum(attention_probs * torch.log2(torch.clamp(attention_probs, 1e-10, 1.0)), dim=-1)
+		attn_varentropy = torch.var(attn_entropy, dim=1)
+
+	# Replace the NaNs that arise when all values are the same with zeros
+	attn_varentropy = torch.where(torch.isnan(attn_varentropy), torch.zeros_like(attn_varentropy), attn_varentropy)
 	mean_attention = torch.mean(attention_probs, dim=1)
-	agreement = torch.mean(torch.abs(attention_probs - mean_attention[:, None, :]), dim=(1, 2))
+	agreement = torch.mean(torch.abs(attention_probs - mean_attention.unsqueeze(1)), dim=(1, 2))
+
+	interaction_strength = torch.mean(torch.abs(last_attention_scores), dim=(1, 2, 3))
 
-	interaction_strength = torch.mean(torch.abs(attention_scores), dim=(1, 2, 3))
 	return {
 		"logits_entropy": torch.mean(entropy).item(),
 		"logits_varentropy": torch.mean(varentropy).item(),
 		"attn_entropy": torch.mean(attn_entropy).item(),
 		"attn_varentropy": torch.mean(attn_varentropy).item(),
 		"agreement": torch.mean(agreement).item(),
-		"interaction_strength": torch.mean(torch.abs(attention_scores), dim=(1, 2, 3)).item(),
+		"interaction_strength": interaction_strength.item(), # torch.mean(interaction_strength).item(),
 		"action": -1
 	}
 
+from typing import NamedTuple
+class AttnStats(NamedTuple):
+	entropy: torch.Tensor # (bsz, n_layers, num_heads)
+	varentropy: torch.Tensor # (bsz, n_layers, num_heads)
+	n_layers: int
+	n_heads: int
+
+	@classmethod
+	def new(cls, bsz: int, n_layers: int, n_heads: int, device = "cuda") -> 'AttnStats':
+		return cls(
+			entropy=torch.zeros((bsz, n_layers, n_heads), dtype=torch.float32, device=device),
+			varentropy=torch.zeros((bsz, n_layers, n_heads), dtype=torch.float32, device=device),
+			n_layers=n_layers,
+			n_heads=n_heads
+		)
+
+	@property
+	def avg_entropy(self):
+		return self.entropy.sum(dim=-1, keepdim=False) # Sum across heads
+
+	@property
+	def avg_varentropy(self):
+		return self.varentropy.sum(dim=-1, keepdim=False) # Sum across heads
+
+	@property
+	def std_error(self):
+		return torch.sqrt(torch.mean(self.varentropy)) / (self.n_heads * self.n_layers)
+
+	def update(self, scores: torch.Tensor, layer_idx: int):
+		# scores shape: (bsz, n_heads, seqlen, n_words)
+		probs = torch.nn.functional.softmax(scores, dim=-1)
+		new_entropy = -torch.sum(torch.where(probs > 0, probs * torch.log(probs), torch.tensor(0.0, device=probs.device)), dim=-1)
+		new_varentropy = torch.sum(probs * (torch.log(torch.clamp(probs, 1e-10, 1.0)) + new_entropy.unsqueeze(-1))**2, dim=-1) # clamp avoids 0 * -inf = NaN
+
+		# Update entropy and varentropy tensors
+		self.entropy[:, layer_idx, :] = new_entropy
+		self.varentropy[:, layer_idx, :] = new_varentropy
+
+		return self
+
 # to-do: play around with these values
 @dataclass()
 class EntropixSamplerConfig:
@@ -336,27 +387,36 @@ def _sample_entropix(
 	top_k = clamp( int(top_k), cfg.top_k_min, cfg.top_k_max )
 	min_p = clamp( float(min_p), cfg.min_p_min, cfg.min_p_max )
 
-	probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
+	probs = F.softmax(logits / temperature, dim=-1)
 
 	# Apply min_p sampling
 	if min_p > 0.0:
-		p_max = float(torch.max(probs, dim=-1, keepdims=True).values)
+		p_max = float(torch.max(probs, dim=-1, keepdim=True).values)
 		indices_to_remove = probs < (min_p * p_max)
 		logits = torch.where(indices_to_remove, torch.full_like(logits, float('-inf')), logits)
 
 	# Apply top-k sampling
-	top_k_probs, top_k_indices = torch.topk(probs, k=top_k)
+	top_k_probs, top_k_indices = torch.topk(probs, k=min(top_k, probs.shape[-1]))
 	probs_sort = torch.flip(top_k_probs, dims=[-1])
 	probs_idx = torch.flip(top_k_indices, dims=[-1])
 	probs_sum = torch.cumsum(probs_sort, dim=-1)
 
-	# Apply top-p sampling
-	mask = torch.where(probs_sum - probs_sort > top_p, 1.0, 0.0)
+	mask = torch.where(probs_sum - probs_sort > top_p, torch.tensor(1.0, device=logits.device), torch.tensor(0.0, device=logits.device))
 	probs_sort = probs_sort * (1 - mask)
-	probs_sort = probs_sort / torch.sum(probs_sort, dim=-1, keepdims=True)
+	probs_sort = probs_sort / torch.sum(probs_sort, dim=-1, keepdim=True)
 
-	next_token = torch.argmax(probs_sort / Exponential.sample(probs_sort.shape), dim=-1, keepdim=True)
+	q = Exponential.sample(probs_sort.shape)
+	"""
+	next_token = torch.argmax(probs_sort / q, dim=-1, keepdim=True)
 	return torch.take_along_dim(probs_idx, next_token, dim=-1)[0]
+	"""
+
+	"""
+	# q = torch.rand(probs_sort.shape, generator=generator, device=probs_sort.device)
+	"""
+	next_token = torch.argmax(probs_sort / q, dim=-1, keepdim=True)
+	next_token_g = torch.take_along_dim(probs_idx, next_token, dim=-1)[0]
+	return next_token_g
 
 def sample_entropix(
 	logits,
@@ -373,6 +433,13 @@ def sample_entropix(
 	top_p = cfg.top_p
 	"""
 
+	"""
+	if generator is None:
+		generator = torch.Generator(device=logits.device).manual_seed(int(time.time()))
+	"""
+
+	# logits: ( bsz, vocab )
+	# attentions: ( bsz, layer, heads, seq_len, ? )
 	metrics = calculate_entropix_metrics( logits, attentions )
 
 	ent, vent = metrics["logits_entropy"], metrics["logits_varentropy"]
@@ -407,7 +474,7 @@ def sample_entropix(
 	else:
 		metrics["action"] = 4
 
-	log_softmax = torch.nn.functional.log_softmax(logits, dim=-1)
+	log_softmax = F.log_softmax(logits, dim=-1)
 	logits_uncertainty = ent + vent
 	attn_uncertainty = attn_ent + attn_vent
 
@@ -419,7 +486,7 @@ def sample_entropix(
 
 		samples = [ _sample_entropix( logits.clone(), temperature, top_k, top_p, min_p, cfg=cfg ) for _ in range(cfg.n_adaptive_samples) ]
 
 		def score_sample(sample):
-			one_hot = torch.nn.functional.one_hot( sample, logits.shape[-1] )
+			one_hot = F.one_hot( sample, logits.shape[-1] )
 			log_prob = torch.sum(log_softmax * one_hot)
 			confidence_score = (
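
Reviewer note, not part of the patch: the two logit-side quantities `calculate_entropix_metrics` reports are the entropy and the varentropy (variance of surprisal) of the next-token distribution, both in bits. A minimal standalone sketch of just that computation, for anyone checking the math (the function name and example values below are illustrative, not from the patch):

```python
import torch
import torch.nn.functional as F

LN_2 = 0.69314718056  # ln(2), converts nats to bits

def entropy_varentropy(logits: torch.Tensor, dim: int = -1):
    # H = -sum(p * log2 p); varentropy = sum(p * (log2 p + H)^2) = Var[-log2 p]
    log_probs = F.log_softmax(logits, dim=dim)
    probs = log_probs.exp()
    entropy = -(probs * log_probs).sum(dim) / LN_2
    varentropy = (probs * (log_probs / LN_2 + entropy.unsqueeze(-1)) ** 2).sum(dim)
    return entropy, varentropy

# A uniform distribution over 1024 tokens: entropy = log2(1024) = 10 bits,
# varentropy = 0 because every token has identical surprisal.
print(entropy_varentropy(torch.zeros(1, 1024)))  # (tensor([10.]), tensor([0.]))
```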
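The `q = Exponential.sample(...)` path kept in `_sample_entropix` is the Gumbel-max-style "exponential race" from upstream entropix: dividing each renormalized probability by an i.i.d. Exponential(1) draw and taking the argmax selects index i with probability p_i, i.e. it is equivalent to `torch.multinomial` while remaining a pure argmax. A quick empirical check, assuming `Exponential` here behaves like `torch.distributions.Exponential(1.0)`:

```python
import torch

probs = torch.tensor([0.7, 0.2, 0.1])
exponential = torch.distributions.Exponential(torch.tensor(1.0))

counts = torch.zeros(3)
for _ in range(20000):
    q = exponential.sample(probs.shape)   # i.i.d. Exponential(1) noise
    # argmax(p / E) == argmin(E / p), a race between Exponential(rate=p_i) clocks,
    # which index i wins with probability p_i
    counts[torch.argmax(probs / q)] += 1

print(counts / counts.sum())  # ~ tensor([0.7, 0.2, 0.1])
```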