From d6f7c86a5c28bf5e20a41878f790d7907832e765 Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 12 Oct 2024 09:46:18 -0500 Subject: [PATCH] entropix tweaks (it doesn't output garbage but it loves to go for silence) --- vall_e/models/base.py | 137 ++++-------------------------------------- vall_e/samplers.py | 133 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 143 insertions(+), 127 deletions(-) diff --git a/vall_e/models/base.py b/vall_e/models/base.py index c2d976f..8ff356a 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -1507,135 +1507,22 @@ class Base(nn.Module): elif self.causal: logits = [ logit[-self.causal_size:] for logit in logits ] - # calculate entropies - # I would love to shove it in samplers.py but we modify our sampler settings + # entropix sampling if attentions is not None: - entropy = [ calculate_entropix_metrics( logit, attn ) for logit, attn in zip(logits, attentions) ] - - if attentions is not None: - entropix_enabled = True - - # this might actually slow things down a bit slightly-er? + # move to CPU for speedups logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ] - # to-do: not make it hardcoded to bsz=1 - metrics = entropy[0] - logit = logits[0] + res = [ sample_entropix( + logit, + attentions[-1], #torch.stack(attentions, dim=1), + temperature, + top_k, + top_p, + min_p, + ) for logit in logits ] - ent, vent = metrics["logits_entropy"], metrics["logits_varentropy"] - attn_ent, attn_vent = metrics["attn_entropy"], metrics["attn_varentropy"] - agreement = metrics["agreement"] - interaction_strength = metrics["interaction_strength"] - - # adjust sample settings - cfg = EntropixSamplerConfig() - - entropy[0]["action"] = -1 - # Low Entropy, Low Varentropy: "flowing with unspoken intent" - if ent < cfg.low_ent_thresh and vent < cfg.low_vent_thresh: - entropy[0]["action"] = 0 - temperature *= 0 - # High Entropy, Low Varentropy: "treading carefully, asking clarifying questions" - elif ent > cfg.high_ent_thresh and vent < cfg.low_vent_thresh: - entropy[0]["action"] = 1 - # sample with slightly higher temperature - temperature *= cfg.helv_attn_ent_offset + cfg.helv_attn_ent_coef * attn_ent # Increase temperature based on attention entropy - # Low Entropy, High Varentropy: "exploring forks in the path" - elif ent < cfg.high_ent_thresh and vent > cfg.high_vent_thresh: - entropy[0]["action"] = 2 - temperature *= cfg.lehv_interaction_strength_offset + cfg.lehv_interaction_strength_coef * interaction_strength # Increase temperature based on interaction strength - top_k = max(5, int(top_k * (1 + 0.5 * (1 - agreement)))) # Increase top_k when agreement is low - # High Entropy, High Varentropy: "resampling in the mist" - elif ent > cfg.med_ent_thresh and vent > cfg.high_vent_thresh: - entropy[0]["action"] = 3 - # Use high temperature and adjusted top_p based on attention metrics - temperature *= cfg.hehv_attn_vent_offset + cfg.hehv_attn_vent_coef * attn_vent # Increase temperature based on attention varentropy - top_p = max(0.5, top_p - cfg.hehv_attn_ent_coef * attn_ent) # Decrease top_p when attention entropy is high - # Middle ground: use adaptive sampling - else: - entropy[0]["action"] = 4 - log_softmax = torch.nn.functional.log_softmax(logit) - logits_uncertainty = ent + vent - attn_uncertainty = attn_ent + attn_vent - - temperature = temperature * float(1 + cfg.ada_temp_logits * logits_uncertainty + cfg.ada_temp_attn * attn_uncertainty - cfg.ada_temp_agree * agreement) - top_p = float(torch.clip(top_p * (1 + cfg.ada_top_p * attn_vent), min=0.1, max=1.0)) - top_k = int(torch.clip( - torch.round(top_k * (1 + cfg.ada_top_k_int * interaction_strength - cfg.ada_top_k_agree * agreement)), - min=cfg.top_k_min, - max=cfg.top_k_max - )) - min_p = float(torch.clip(cfg.min_p * (1 - cfg.ada_min_p * logits_uncertainty), 0.01, 0.5)) - temperature = clamp( temperature, cfg.temperature_min, cfg.temperature_max ) - - def _sample( logits ): - # perform repetition penalizing - if "len" not in self.capabilities and prev_list is not None and repetition_penalty != 1.0: - # to-do: figure out a faster way to handle tolist() - logits = [ reptition_penalize(logit, previous=prevs[:, -1].tolist() if prevs.dim() > 1 else prevs.tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ] - - # (AR) perform length penalizing - if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0: - logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ] - - if min_p > 0.0: - logits = [ min_p_filtering(logit, min_p=min_p) for logit in logits ] - - # perform top_k/top_p filtering of our logits - if top_k > 0 or top_p < 1.0: - logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ] - - # trigger dynamic temperature sampling if the minimum temperature is not the same as the sampling temperature - # epsilon float comparison because I don't trust Python - if abs(temperature - min_temperature) >= 0.001: - logits = [ dynamic_temperature(logit, temperature=temperature, min_temperature=min_temperature) for logit in logits ] - else: - logits = [ logit / temperature for logit in logits ] - - # do DRY sampling - if dry_multiplier > 0.0: - logits = [ dry_sampling(logit, previous=resps[:, -1].tolist(), factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, prev_list ) ] - - return [ Categorical(logits=logit).sample() for logit in logits ] - - if entropix_enabled: - samples = [ _sample([ logit.clone() for logit in logits ]) for _ in range(cfg.n_adaptive_samples) ] - - def score_sample(sample): - one_hot = torch.nn.functional.one_hot(sample[0], logit.shape[-1]) - log_prob = torch.sum(log_softmax * one_hot) - - confidence_score = ( - (1 - ent) * cfg.ada_score_logits_ent + - (1 - attn_ent) * cfg.ada_score_attn_ent + - (1 - vent) * cfg.ada_score_logits_vent + - (1 - attn_vent) * cfg.ada_score_attn_vent + - agreement * cfg.ada_score_agree + - interaction_strength * cfg.ada_score_int - ) - return log_prob + confidence_score - - sample_scores = [ score_sample(sample) for sample in samples ] - best_sample_idx = torch.argmax(torch.asarray(sample_scores)) - - res = samples[best_sample_idx] - scores = sample_scores - return Sampled(res, scores, entropy) - - temperature = clamp( float(temperature), cfg.temperature_min, cfg.temperature_max ) - min_temperature = temperature - - entropy[0]["temperature"] = temperature - entropy[0]["top_k"] = top_k - entropy[0]["top_p"] = top_p - entropy[0]["min_p"] = min_p - - if not entropix_enabled: - temperature = 1.0 - min_temperature = 1.0 - top_k = 0 - top_p = 1.0 - min_p = 0.0 + if res: + return Sampled([ r[0] for r in res], scores, [ r[1] for r in res]) # (NAR) disable stop token if quant_levels is not None and "ar" in self.capabilities: diff --git a/vall_e/samplers.py b/vall_e/samplers.py index b5f887e..0a2563c 100644 --- a/vall_e/samplers.py +++ b/vall_e/samplers.py @@ -258,6 +258,7 @@ def calculate_entropix_metrics( logits, attention_scores=None, dim=-1 ): "attn_varentropy": torch.mean(attn_varentropy), "agreement": torch.mean(agreement), "interaction_strength": torch.mean(torch.abs(attention_scores), dim=(1, 2, 3)), + "action": -1 } # to-do: play around with these values @@ -304,7 +305,135 @@ class EntropixSamplerConfig: ada_score_int: float = 0.6 # extra stuff + temperature_max: float = 1.25 + temperature_min: float = 0.5 top_k_min: int = 1 top_k_max: int = 1024 - temperature_max: float = 1.25 - temperature_min: float = 0.5 \ No newline at end of file + top_p_min: int = 0.1 + top_p_max: int = 1.0 + min_p_min: int = 0.01 + min_p_max: int = 0.5 + +Exponential = torch.distributions.exponential.Exponential(1.0) +def _sample_entropix( + logits, + temperature=1.0, + top_k=0, + top_p=1.0, + min_p=0.0, + cfg=EntropixSamplerConfig(), +): + def clamp(n, lo, hi): + return max(lo, min(n, hi)) + + if top_k == 0: + top_k = logits.shape[-1] + + temperature = clamp( float(temperature), cfg.temperature_min, cfg.temperature_max ) + top_p = clamp( float(top_p), cfg.top_p_min, cfg.top_p_max ) + top_k = clamp( int(top_k), cfg.top_k_min, cfg.top_k_max ) + min_p = clamp( float(min_p), cfg.min_p_min, cfg.min_p_max ) + + probs = torch.nn.functional.softmax(logits / temperature, dim=-1) + + # Apply min_p sampling + if min_p > 0.0: + p_max = float(torch.max(probs, dim=-1, keepdims=True).values) + indices_to_remove = probs < (min_p * p_max) + logits = torch.where(indices_to_remove, torch.full_like(logits, float('-inf')), logits) + + # Apply top-k sampling + top_k_probs, top_k_indices = torch.topk(probs, k=top_k) + probs_sort = torch.flip(top_k_probs, dims=[-1]) + probs_idx = torch.flip(top_k_indices, dims=[-1]) + probs_sum = torch.cumsum(probs_sort, dim=-1) + + # Apply top-p sampling + mask = torch.where(probs_sum - probs_sort > top_p, 1.0, 0.0) + probs_sort = probs_sort * (1 - mask) + probs_sort = probs_sort / torch.sum(probs_sort, dim=-1, keepdims=True) + + next_token = torch.argmax(probs_sort / Exponential.sample(probs_sort.shape), dim=-1, keepdim=True) + return torch.take_along_dim(probs_idx, next_token, dim=-1) + +def sample_entropix( + logits, + attentions, + temperature=1.0, + top_k=32, + top_p=1.0, + min_p=0.0, + cfg=EntropixSamplerConfig(), +): + metrics = calculate_entropix_metrics( logits, attentions ) + + ent, vent = metrics["logits_entropy"], metrics["logits_varentropy"] + attn_ent, attn_vent = metrics["attn_entropy"], metrics["attn_varentropy"] + agreement = metrics["agreement"] + interaction_strength = metrics["interaction_strength"] + + # Low Entropy, Low Varentropy: "flowing with unspoken intent" + if ent < cfg.low_ent_thresh and vent < cfg.low_vent_thresh: + metrics["action"] = 0 + res = logits.argmax(dim=1) + # High Entropy, Low Varentropy: "treading carefully, asking clarifying questions" + elif ent > cfg.high_ent_thresh and vent < cfg.low_vent_thresh: + metrics["action"] = 1 + # sample with slightly higher temperature + temperature *= cfg.helv_attn_ent_offset + cfg.helv_attn_ent_coef * attn_ent # Increase temperature based on attention entropy + res = _sample_entropix( logits, temperature, top_k, top_p, min_p, cfg=cfg ) + # Low Entropy, High Varentropy: "exploring forks in the path" + elif ent < cfg.high_ent_thresh and vent > cfg.high_vent_thresh: + metrics["action"] = 2 + temperature *= cfg.lehv_interaction_strength_offset + cfg.lehv_interaction_strength_coef * interaction_strength # Increase temperature based on interaction strength + top_k = max(5, int(top_k * (1 + 0.5 * (1 - agreement)))) # Increase top_k when agreement is low + res = _sample_entropix( logits, temperature, top_k, top_p, min_p, cfg=cfg ) + # High Entropy, High Varentropy: "resampling in the mist" + elif ent > cfg.med_ent_thresh and vent > cfg.high_vent_thresh: + metrics["action"] = 3 + # Use high temperature and adjusted top_p based on attention metrics + temperature *= cfg.hehv_attn_vent_offset + cfg.hehv_attn_vent_coef * attn_vent # Increase temperature based on attention varentropy + top_p = max(0.5, top_p - cfg.hehv_attn_ent_coef * attn_ent) # Decrease top_p when attention entropy is high + res = _sample_entropix( logits, temperature, top_k, top_p, min_p, cfg=cfg ) + # Middle ground: use adaptive sampling + else: + metrics["action"] = 4 + + log_softmax = torch.nn.functional.log_softmax(logits) + logits_uncertainty = ent + vent + attn_uncertainty = attn_ent + attn_vent + + temperature *= float(1 + cfg.ada_temp_logits * logits_uncertainty + cfg.ada_temp_attn * attn_uncertainty - cfg.ada_temp_agree * agreement) + top_p = float(top_p * (1 + cfg.ada_top_p * attn_vent)) + top_k = int(round(float(top_k * (1 + cfg.ada_top_k_int * interaction_strength - cfg.ada_top_k_agree * agreement)))) + min_p = float(cfg.min_p * (1 - cfg.ada_min_p * logits_uncertainty)) + + samples = [ _sample_entropix( logits.clone(), temperature, top_k, top_p, min_p, cfg=cfg ) for _ in range(cfg.n_adaptive_samples) ] + + def score_sample(sample): + one_hot = torch.nn.functional.one_hot( sample, logits.shape[-1] ) + log_prob = torch.sum(log_softmax * one_hot) + + confidence_score = ( + (1 - ent) * cfg.ada_score_logits_ent + + (1 - attn_ent) * cfg.ada_score_attn_ent + + (1 - vent) * cfg.ada_score_logits_vent + + (1 - attn_vent) * cfg.ada_score_attn_vent + + agreement * cfg.ada_score_agree + + interaction_strength * cfg.ada_score_int + ) + return log_prob + confidence_score + + sample_scores = [ score_sample(sample) for sample in samples ] + best_sample_idx = torch.argmax(torch.asarray(sample_scores)) + + res = samples[best_sample_idx] + + """ + metrics["temperature"] = temperature + metrics["top_k"] = top_k + metrics["top_p"] = top_p + metrics["min_p"] = min_p + """ + + return res, metrics \ No newline at end of file