From bef43a0c18176e784412512fb6d914e1ff97d495 Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 11 Oct 2024 21:18:26 -0500
Subject: [PATCH] added experimental entropix sampling support

---
 vall_e/config.py        |   2 +
 vall_e/data.py          |   2 +-
 vall_e/demo.py          |   2 +-
 vall_e/emb/process.py   |  38 ++++++---
 vall_e/models/ar.py     |  29 +++----
 vall_e/models/ar_nar.py |  43 ++++++----
 vall_e/models/base.py   | 177 ++++++++++++++++++++++++++++++++++------
 vall_e/models/nar.py    |   9 +-
 vall_e/plot.py          |  23 ++++++
 vall_e/samplers.py      |  86 ++++++++++++++++++-
 10 files changed, 337 insertions(+), 74 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index 5a1ff98..9aa48d6 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -237,6 +237,8 @@ class ModelExperimentalSettings:
 	p_len_train: float = 0.05 # odds of injecting a "len" task within the model for NAR-len
 	# to-do: just incorporate this as a task instead
 
+	entropix_sampling: bool = False # experimental sampling based on https://github.com/xjdr-alt/entropix, experimental flag because it requires using naive attention for output scores
+	# I really need to clean this up
 
 @dataclass()
 class Model:
diff --git a/vall_e/data.py b/vall_e/data.py
index bb8d23a..8fa6ae1 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -1458,7 +1458,7 @@ def process_artifact_metadata( artifact ):
 		metadata["similar"] = artifact["metadata"]["similar"]
 	# duration for use of culling / sorting dataset
 	if "duration" in artifact["metadata"]:
-		metadata["duration"] = duration
+		metadata["duration"] = float(artifact["metadata"]["duration"])
 	# derive duration from sample count / sample rate
 	elif "original_length" in artifact["metadata"] and "sample_rate" in artifact["metadata"]:
 		metadata["duration"] = artifact["metadata"]["original_length"] / artifact["metadata"]["sample_rate"]
diff --git a/vall_e/demo.py b/vall_e/demo.py
index f56201e..0011723 100644
--- a/vall_e/demo.py
+++ b/vall_e/demo.py
@@ -184,7 +184,7 @@ def main():
 			extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else ([ out_path_lora ] if args.lora else [])
 
-			if not args.random_prompts:
+			if not args.random_prompts or k == "librispeech":
 				extra_sources += [ reference ]
 
 			samples.append((
diff --git a/vall_e/emb/process.py b/vall_e/emb/process.py
index b49e746..dee037f 100644
--- a/vall_e/emb/process.py
+++ b/vall_e/emb/process.py
@@ -196,10 +196,13 @@ def process(
 		for filename in sorted(metadata.keys()):
 			inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{filename}')
+
+			"""
 			if not inpath.exists():
 				missing["audio"].append(str(inpath))
 				continue
-
+			"""
+
 			extension = os.path.splitext(filename)[-1][1:]
 			fname = filename.replace(f'.{extension}', "")
 
@@ -220,10 +223,19 @@ def process(
 				jobs.append(( outpath, waveform, sample_rate, text, language ))
 			else:
 				i = 0
+				presliced = not inpath.exists()
+
 				for segment in metadata[filename]["segments"]:
 					id = pad(i, 4)
 					i = i + 1
 
+					if presliced:
+						inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{fname}_{id}.{extension}')
+
+					if not inpath.exists():
+						missing["audio"].append(str(inpath))
+						continue
+
 					outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}_{id}.{extension}').with_suffix(audio_extension)
 					text = segment["text"]
 
@@ -234,18 +246,19 @@ def process(
 					if waveform is None:
 						waveform, sample_rate = load_audio( inpath )
 
-					start = int(segment['start'] * sample_rate)
-					end = int(segment['end'] * sample_rate)
+					start = int((segment['start']-0.05) * sample_rate)
+					end = int((segment['end']+0.5) * sample_rate)
 
-					if start < 0:
-						start = 0
-					if end >= waveform.shape[-1]:
-						end = waveform.shape[-1] - 1
+					if not presliced:
+						if start < 0:
+							start = 0
+						if end >= waveform.shape[-1]:
+							end = waveform.shape[-1] - 1
 
-					if end - start < 0:
-						continue
+						if end - start < 0:
+							continue
 
-					jobs.append(( outpath, waveform[:, start:end], sample_rate, text, language ))
+					jobs.append(( outpath, waveform if presliced else waveform[:, start:end], sample_rate, text, language ))
 
 	# processes audio files one at a time
 	if low_memory:
@@ -287,6 +300,11 @@ def main():
 		args.stride_offset = int(args.device)
 		args.device = f'cuda:{args.device}'
 
+	if args.slice == "true":
+		args.slice = True
+	elif args.slice == "false":
+		args.slice = False
+
 	process(
 		audio_backend=args.audio_backend,
 		input_audio=args.input_audio,
diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py
index 3d59da8..06d9897 100644
--- a/vall_e/models/ar.py
+++ b/vall_e/models/ar.py
@@ -188,18 +188,13 @@ class AR(Base):
 				quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ]
 			)
 
-			if state is not None:
-				logits, state = super().forward(
-					inputs=inputs,
-					state=state,
-				)
-			else:
-				logits = super().forward(
-					inputs=inputs,
-					state=state,
-				)
+			output = super().forward(
+				inputs=inputs,
+				state=state,
+			)
+			logits, state = output.logits, output.state
 
-			r = super().sample(
+			sampled = super().sample(
 				logits=logits,
 				prev_list=resps_list,
@@ -219,15 +214,13 @@ class AR(Base):
 				dry_allowed_length=sampling_dry_allowed_length,
 			)
 
+			r = sampled[0]
+
 			if mirostat is not None:
-				# r is the state
-				mirostat = r
-				# extract token from state
-				r = [ state["token"] for state in mirostat ]
-				# we do it here because the sampler will already expand our logits list
+				mirostat = sampled.scores
 			elif sampling_beam_width > 0:
-				# expand tuple
-				r, s = r
 				# first step, expand batch
 				if batch_size == 1:
 					batch_size = sampling_beam_width
@@ -236,7 +229,7 @@ class AR(Base):
 					sequence_list = sequence_list * sampling_beam_width
 					stopped = torch.zeros(batch_size, device=device).bool()
 
-				scores = [ scores[i] + score for i, score in enumerate(s) ]
+				scores = [ scores[i] + score for i, score in enumerate(sampled.scores) ]
 
 			# append tokens
 			for i, ri in enumerate(r):
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 5e08c3f..f95b764 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -64,6 +64,8 @@ class AR_NAR(Base):
 		sampling_dry_base=1.75,
 		sampling_dry_allowed_length=2,
 
+		sampling_entropix=None,
+
 		disable_tqdm=False,
 		use_lora=None,
 	):
@@ -222,11 +224,9 @@ class AR_NAR(Base):
 				inputs=inputs,
 				quant_levels=quant_levels,
 			)
-			if not isinstance( output, tuple ):
-				output = (output, None)
-			logits, state = output
+			logits, state = output.logits, output.state
 
-			resps_list = super().sample(
+			sampled = super().sample(
 				logits=logits,
 				prev_list=prev_list,
 				quant_levels=quant_levels,
@@ -242,6 +242,8 @@ class AR_NAR(Base):
 				#mirostat=mirostat,
 			)
 
+			resps_list = sampled[0]
+
 			prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device=device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
 
 		return prev_list
@@ -264,6 +266,10 @@ class AR_NAR(Base):
 		] * batch_size if sampling_mirostat_tau > 0.0 else None
 
 		scores = [ 1.0 ] * sampling_beam_width
+		entropies = []
+
+		if sampling_entropix is None:
+			sampling_entropix = self.config.experimental.entropix_sampling
 
 		for i, sequence in enumerate( sequence_list ):
 			# add to text for STT
@@ -296,13 +302,11 @@ class AR_NAR(Base):
 			output = super().forward(
 				inputs=inputs,
 				state=state,
+				output_attentions=sampling_entropix,
 			)
-			if not isinstance( output, tuple ):
-				output = (output, None)
-
-			logits, state = output
+			logits, state = output.logits, output.state
 
-			r = super().sample(
+			sampled = super().sample(
 				logits=logits,
 				prev_list=None if sampling_repetition_penalty == 1.0 and sampling_length_penalty == 0.0 else [ resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate( task_list ) ],
@@ -320,17 +324,20 @@ class AR_NAR(Base):
 				dry_multiplier=sampling_dry_multiplier,
 				dry_base=sampling_dry_base,
 				dry_allowed_length=sampling_dry_allowed_length,
+
+				attentions=output.attentions if sampling_entropix else None,
 			)
 
+			r = sampled[0]
+
+			if sampled.entropy:
+				entropies.append( sampled.entropy )
+
 			if mirostat is not None:
-				# r is the state
-				mirostat = r
-				# extract token from state
-				r = [ state["token"] for state in mirostat ]
-				# we do it here because the sampler will already expand our logits list
+				mirostat = sampled.scores
 			elif sampling_beam_width > 0:
-				# expand tuple
-				r, s = r
 				# first step, expand batch
 				if batch_size == 1:
 					batch_size = sampling_beam_width
@@ -339,7 +346,7 @@ class AR_NAR(Base):
 					sequence_list = sequence_list * sampling_beam_width
 					stopped = torch.zeros(batch_size, device=device).bool()
 
-				scores = [ scores[i] + score for i, score in enumerate(s) ]
+				scores = [ scores[i] + score for i, score in enumerate(sampled.scores) ]
 
 			# append tokens
 			for i, ri in enumerate(r):
@@ -354,6 +361,10 @@ class AR_NAR(Base):
 			if stopped.all().item():
 				break
 
+		if entropies:
+			from ..plot import plot_entropies
+			plot_entropies( entropies )
+
 		# pick the best scoring candidate
 		# desu this is always going to be candidate 0
 		if sampling_beam_width:
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 98de991..0603643 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -15,8 +15,9 @@ import torch.nn.functional as F
 import random
 import numpy as np
 import re
-from time import perf_counter
 
+from time import perf_counter
+from collections import namedtuple
 from typing import Literal, overload, Optional, Tuple
 from functools import partial
 from einops import rearrange
@@ -37,6 +38,9 @@ from ..emb.qnt import encode_as_embedding
 # yuck, kind of needed
 from ..data import get_task_symmap
 
+Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions'])
+Sampled = namedtuple('Sampled', ['out', 'scores', 'entropy']) # these seem more elegant than a dict
+
 """
 from ..utils.pattern import DelayedPatternProvider, VALLEPattern
 """
@@ -805,11 +809,15 @@ class Base(nn.Module):
 		inputs,
 		mask = None,
 		position_ids = None,
+
 		state = None,
+		output_attentions = False,
 	):
 		x = inputs
 		m = mask.squeeze(-1).int()
 
+		aux_loss = None
+		attentions = None
+
 		# HF transformer derived model
 		if self.arch_type in ["llama", "mistral", "mixtral"]:
@@ -819,22 +827,25 @@ class Base(nn.Module):
 				past_key_values=state,
 				position_ids=position_ids,
 				use_cache=not self.training,
-				# return_dict=True,
+				output_attentions=output_attentions,
+				return_dict=True,
 			)
 			if self.n_experts > 1 and self.training:
 				kwargs["output_router_logits"] = True
 
-			t = self.model(**kwargs)
-
-			x = t[0]
+			output = self.model(**kwargs)
+			x = output["last_hidden_state"]
 
 			# to-do: figure out why KV caching doesn't work
 			#if not self.training:
 			if state is not None:
-				state = t[1]
+				state = output["past_key_values"]
+
+			if output_attentions:
+				attentions = output["attentions"]
 
 			if self.n_experts > 1 and self.training:
-				router_logits = t[-1]
+				router_logits = output["router_logits"]
 				aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok )
 		elif self.arch_type == "transformer":
 			# ensures we specify a quant_level for the transformer implementation's AdaLN
@@ -895,7 +906,7 @@ class Base(nn.Module):
 		if self.classifier is not None:
 			x = self.classifier(x) * mask
 
-		return x, state, aux_loss
+		return Logits(x, state, aux_loss, attentions)
 
 	# takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
 	def inputs(
@@ -1390,6 +1401,7 @@ class Base(nn.Module):
 		quant_levels: int | list[int] | Tensor | None = None,
 
 		state: dict | list | None = None,
+		output_attentions = False,
 	):
 		x_list = self.inputs_to_embeddings( inputs, quant_levels )
 		x, m = list_to_tensor(x_list)
@@ -1420,32 +1432,36 @@ class Base(nn.Module):
 		# needs to be done here as we still have our raw inputs
 		position_ids = self.inputs_to_position_ids( inputs, mask=m.squeeze(-1).int() ) if not self.unified_position_ids else None
 
-		x, state, aux_loss = self._forward(
+		output = self._forward(
 			inputs=x,
 			mask=m,
 			state=state,
 			position_ids=position_ids,
+			output_attentions = output_attentions,
 		)
 
+		logits = output.logits
+
 		# to-do: piece-wise classification, now that there's a head for text
 		# although again, one single monolithic head would be preferable instead......
 		if self.classifiers is not None:
 			special_tasks = [ "len", "stt" ]
 			classifier_quant_levels = [ -1 if inputs[i][0][-1] in special_tasks else l for i, l in enumerate( quant_levels ) ]
-			x = self.classifiers(x, levels = classifier_quant_levels) * m
+			logits = self.classifiers(logits, levels = classifier_quant_levels) * m
 
 		# Remove padding
-		logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ]
+		logits = [ hi[:li] for hi, li in zip(logits, map(len, x_list)) ]
 
 		# compute loss if the target is given
 		if training:
 			self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
 
 			# include any additional losses (for example: MoE router)
-			if aux_loss is not None:
-				self.loss["aux_loss"] = aux_loss
+			if output.aux_loss is not None:
+				self.loss["aux_loss"] = output.aux_loss
 
-		return (logits, state) if state is not None else logits
+		# rewrap, because we're modifying the logits here
+		return Logits(logits, output.state, output.aux_loss, output.attentions)
 
 	def sample(
 		self,
@@ -1470,10 +1486,15 @@ class Base(nn.Module):
 		dry_multiplier=0.0,
 		dry_base=1.75,
 		dry_allowed_length=2,
+
+		# other
+		attentions=None,
 	):
 		if min_temperature < 0:
 			min_temperature = temperature
+
+		scores = None
+		entropy = None
 
 		# (NAR) return the entire generated response
 		# Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
 		if quant_levels is not None: # and "nar" in self.capabilities: # for when I get around to coping about dropping the NAR entirely
@@ -1482,9 +1503,114 @@ class Base(nn.Module):
 			logits = [ logit[-l:] for logit, l in zip(logits, map(len, prev_list)) ]
 		elif self.causal:
 			logits = [ logit[-self.causal_size:] for logit in logits ]
 
-		# this might actually slow things down a bit slightly-er?
-		#logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
-
+		# calculate entropies
+		# I would love to shove it in samplers.py but we modify our sampler settings
+		if attentions is not None:
+			entropy = [ calculate_entropix_metrics( logit, attn ) for logit, attn in zip(logits, attentions) ]
+
+			# this might actually slow things down a bit slightly-er?
+			logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
+
+			# to-do: not make it hardcoded to bsz=1
+			metrics = entropy[0]
+			logit = logits[0]
+
+			ent, vent = metrics["logits_entropy"], metrics["logits_varentropy"]
+			attn_ent, attn_vent = metrics["attn_entropy"], metrics["attn_varentropy"]
+			agreement = metrics["agreement"]
+			interaction_strength = metrics["interaction_strength"]
+
+			# adjust sample settings
+			cfg = EntropixSamplerConfig()
+
+			# Low Entropy, Low Varentropy: "flowing with unspoken intent"
+			if ent < cfg.low_ent_thresh and vent < cfg.low_vent_thresh:
+				entropy[0]["action"] = 0
+				temperature *= 0 # collapses to the argmax path below
+			# High Entropy, Low Varentropy: "treading carefully, asking clarifying questions"
+			elif ent > cfg.high_ent_thresh and vent < cfg.low_vent_thresh:
+				entropy[0]["action"] = 1
+				# sample with slightly higher temperature
+				temperature *= cfg.helv_attn_ent_offset + cfg.helv_attn_ent_coef * attn_ent # Increase temperature based on attention entropy
+			# Low Entropy, High Varentropy: "exploring forks in the path"
+			elif ent < cfg.high_ent_thresh and vent > cfg.high_vent_thresh:
+				entropy[0]["action"] = 2
+				temperature *= cfg.lehv_interaction_strength_offset + cfg.lehv_interaction_strength_coef * interaction_strength # Increase temperature based on interaction strength
+				top_k = max(5, int(top_k * (1 + 0.5 * (1 - agreement)))) # Increase top_k when agreement is low
+			# High Entropy, High Varentropy: "resampling in the mist"
+			elif ent > cfg.med_ent_thresh and vent > cfg.high_vent_thresh:
+				entropy[0]["action"] = 3
+				# Use high temperature and adjusted top_p based on attention metrics
+				temperature *= cfg.hehv_attn_vent_offset + cfg.hehv_attn_vent_coef * attn_vent # Increase temperature based on attention varentropy
+				top_p = max(0.5, top_p - cfg.hehv_attn_ent_coef * attn_ent) # Decrease top_p when attention entropy is high
+			# Middle ground: use adaptive sampling
+			else:
+				entropy[0]["action"] = 4
+
+				log_softmax = torch.nn.functional.log_softmax(logit, dim=-1)
+				logits_uncertainty = ent + vent
+				attn_uncertainty = attn_ent + attn_vent
+
+				temperature = temperature * float(1 + cfg.ada_temp_logits * logits_uncertainty + cfg.ada_temp_attn * attn_uncertainty - cfg.ada_temp_agree * agreement)
+				top_p = torch.clip(top_p * (1 + cfg.ada_top_p * attn_vent), min=0.1, max=1.0).item()
+				top_k = int(torch.clip(
+					torch.round(top_k * (1 + cfg.ada_top_k_int * interaction_strength - cfg.ada_top_k_agree * agreement)),
+					min=cfg.top_k_min,
+					max=cfg.top_k_max
+				))
+				min_p = torch.clip(cfg.min_p * (1 - cfg.ada_min_p * logits_uncertainty), 0.01, 0.5)
+
+				def _sample( logits ):
+					# perform repetition penalizing
+					if "len" not in self.capabilities and prev_list is not None and repetition_penalty != 1.0:
+						# to-do: figure out a faster way to handle tolist()
+						logits = [ reptition_penalize(logit, previous=prevs[:, -1].tolist() if prevs.dim() > 1 else prevs.tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
+
+					# (AR) perform length penalizing
+					if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0:
+						logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]
+
+					# perform top_k/top_p filtering of our logits
+					if top_k > 0 or top_p < 1.0:
+						logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ]
+
+					# trigger dynamic temperature sampling if the minimum temperature is not the same as the sampling temperature
+					# epsilon float comparison because I don't trust Python
+					if abs(temperature - min_temperature) >= 0.001:
+						logits = [ dynamic_temperature(logit, temperature=temperature, min_temperature=min_temperature) for logit in logits ]
+					else:
+						logits = [ logit / temperature for logit in logits ]
+
+					# do DRY sampling
+					if dry_multiplier > 0.0:
+						logits = [ dry_sampling(logit, previous=resps[:, -1].tolist(), factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, prev_list ) ]
+
+					return [ Categorical(logits=logit).sample() for logit in logits ]
+
+				samples = [ _sample([ logit.clone() for logit in logits ]) for _ in range(cfg.n_adaptive_samples) ]
+
+				def score_sample(sample):
+					one_hot = torch.nn.functional.one_hot(sample[0], logit.shape[-1])
+					log_prob = torch.sum(log_softmax * one_hot)
+
+					confidence_score = (
+						(1 - ent) * cfg.ada_score_logits_ent +
+						(1 - attn_ent) * cfg.ada_score_attn_ent +
+						(1 - vent) * cfg.ada_score_logits_vent +
+						(1 - attn_vent) * cfg.ada_score_attn_vent +
+						agreement * cfg.ada_score_agree +
+						interaction_strength * cfg.ada_score_int
+					)
+					return log_prob + confidence_score
+
+				sample_scores = [ score_sample(sample) for sample in samples ]
+				best_sample_idx = torch.argmax(torch.asarray(sample_scores))
+
+				res = samples[best_sample_idx]
+				scores = sample_scores
+				return Sampled(res, scores, entropy)
+
+			temperature = min(1.5, float(temperature))
+
 		# (NAR) disable stop token
 		if quant_levels is not None and "ar" in self.capabilities:
 			logits = [ ban_tokens(logit, tokens=[self.stop_token]) for logit, l in zip( logits, map(len, prev_list) ) ]
@@ -1494,7 +1620,9 @@ class Base(nn.Module):
 		# argmax instead
 		if temperature <= 0.0:
-			return [ logit.argmax(dim=1) for logit in logits ]
+			res = [ logit.argmax(dim=1) for logit in logits ]
+			scores = None
+			return Sampled(res, scores, entropy)
 
 		# perform repetition penalizing
 		if "len" not in self.capabilities and prev_list is not None and repetition_penalty != 1.0:
@@ -1524,17 +1652,18 @@ class Base(nn.Module):
 		# currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work
 		if mirostat is not None:
 			# mirostat sampling
-			return [ mirostat_sample(logit, state=state) for logit, state in zip(logits, mirostat) ]
-
+			scores = [ mirostat_sample(logit, state=state) for logit, state in zip(logits, mirostat) ]
+			res = [ state["token"] for state in scores ]
 		# do beam search (naive implementation)
 		# picks the top-k across all batches, and re-batches those resultant tokens
 		# returns the logit scores as well to be P-concatted with the previous scores
 		# to-do: not naively implement beam searching
-		if beam_width > 1:
+		elif beam_width > 1:
 			candidates = top_k_logits_list( logits, beam_width )
 			res = [ torch.tensor(token, dtype=torch.int16).unsqueeze(dim=-1) for batch, token in candidates ]
 			scores = [ logits[batch].flatten()[token] for batch, token in candidates ]
-			return res, scores
+		# basic sampling
+		else:
+			res = [ Categorical(logits=logit).sample() for logit in logits ]
 
-		# and sample
-		return [ Categorical(logits=logit).sample() for logit in logits ]
\ No newline at end of file
+		return Sampled(res, scores, entropy)
\ No newline at end of file
diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index 5ba87ce..6e07bbb 100644
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -172,16 +172,17 @@ class NAR(Base):
 				quant_levels=quant_levels,
 			)
 
-			logits = super().forward(
+			output = super().forward(
 				inputs=inputs,
 				quant_levels=quant_levels,
 			)
+			logits = output.logits
 
 			"""
 			resps_list = [ logit[-l:].argmax(dim=1) for logit, l in zip(logits, len_list) ]
 			"""
 
-			resps_list = super().sample(
+			sampled = super().sample(
 				logits=logits,
 				prev_list=prev_list,
 				quant_levels=quant_levels,
@@ -196,6 +197,7 @@ class NAR(Base):
 				#beam_width=sampling_beam_width,
 				#mirostat=mirostat,
 			)
+			resps_list = sampled[0]
 
 			if n == 0:
 				prev_list = [ r.unsqueeze(-1).to(device) for r in resps_list ]
@@ -225,9 +227,10 @@ class NAR(Base):
 				quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ]
 			)
 
-			logits = super().forward(
+			output = super().forward(
 				inputs=inputs,
 			)
+			logits = output.logits
 
 			r = [ logit[-1:].argmax(dim=1) for logit in logits ]
 
 			# sanitize
diff --git a/vall_e/plot.py b/vall_e/plot.py
index 72f1080..ee7ca01 100644
--- a/vall_e/plot.py
+++ b/vall_e/plot.py
@@ -93,6 +93,29 @@ def plot(paths, args):
 		#bbox_to_anchor=(1.04, 0.5),
 	)
 
+def plot_entropies( entropies ):
+	"""
+	fig = plt.figure()
+	fig.set_figwidth( 16 * len(entropies) // cfg.dataset.frames_per_second )
+	"""
+
+	data = {}
+
+	for key in entropies[0][0].keys():
+		data[key] = [ e[0][key].item() if hasattr( e[0][key], "item" ) else e[0][key] for e in entropies ]
+
+	df = pd.DataFrame(data)
+	df.plot()
+
+	plt.gca().legend(
+		#loc="center left",
+		fancybox=True,
+		shadow=True,
+		#bbox_to_anchor=(1.04, 0.5),
+	)
+
+	out_path = cfg.rel_path / "metrics.png"
+	plt.savefig(out_path, bbox_inches="tight")
+
 if __name__ == "__main__":
 	parser = argparse.ArgumentParser()
diff --git a/vall_e/samplers.py b/vall_e/samplers.py
index 74ef9f0..f99bee2 100644
--- a/vall_e/samplers.py
+++ b/vall_e/samplers.py
@@ -5,6 +5,8 @@ import numpy as np
 
 from torch import Tensor, einsum, nn
 
+from dataclasses import asdict, dataclass, field
+
 # Simple filter to modify a token's probability if it shows up in the past
 # `one_time` will only apply the penalty once
 # `decay` is a factor that will exponentially apply to how far away it is
@@ -201,4 +203,86 @@ def dry_sampling( logits, previous=None, factor=0.0, base=1.75, allowed_length=2 ):
 			break
 		logits[:, token] -= factor * base ** (length - allowed_length)
 
-	return logits
\ No newline at end of file
+	return logits
+
+
+LN_2 = 0.69314718056 # ln(2) = 1.0 / LOG2_E
+
+# Grabbed from https://github.com/xjdr-alt/entropix/blob/main/entropix/sampler.py
+# Right now I only care about quantifying these two, I'll figure out how to best apply this to the model
+def calculate_entropix_metrics( logits, attention_scores=None, dim=-1 ):
+	"""Calculate the entropy and varentropy of the probability distribution using logsoftmax."""
+	log_probs = torch.nn.functional.log_softmax(logits, dim=dim)
+	probs = torch.exp(log_probs)
+	entropy = -torch.sum(probs * log_probs, dim=dim) / LN_2 # Convert to base-2
+	varentropy = torch.sum(probs * (log_probs / LN_2 + entropy[..., None])**2, dim=dim)
+
+	if attention_scores is None:
+		return {
+			"logits_entropy": torch.mean(entropy).item(),
+			"logits_varentropy": torch.mean(varentropy).item(),
+		}
+
+	attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
+	attn_entropy = -torch.sum(attention_probs * torch.log2(torch.clip(attention_probs, 1e-10, 1.0)), dim=-1)
+	attn_varentropy = torch.var(attn_entropy, dim=1)
+
+	mean_attention = torch.mean(attention_probs, dim=1)
+	agreement = torch.mean(torch.abs(attention_probs - mean_attention[:, None, :]), dim=(1, 2))
+
+	interaction_strength = torch.mean(torch.abs(attention_scores), dim=(1, 2, 3))
+	return {
+		"logits_entropy": torch.mean(entropy),
+		"logits_varentropy": torch.mean(varentropy),
+		"attn_entropy": torch.mean(attn_entropy),
+		"attn_varentropy": torch.mean(attn_varentropy),
+		"agreement": torch.mean(agreement),
+		"interaction_strength": interaction_strength,
+	}
+
+# to-do: play around with these values
+@dataclass()
+class EntropixSamplerConfig:
+	temp: float = 0.999
+	top_p: float = 0.90
+	top_k: int = 32
+	min_p: float = 0.01 # was 0.03 # Turn this down to 0.01 to reduce the shoggoth
+
+	low_ent_thresh: float = 0.1
+	low_vent_thresh: float = 0.1
+	med_ent_thresh: float = 3.0
+	high_ent_thresh: float = 5.0
+	high_vent_thresh: float = 5.0
+
+	# TODO this is a bit of a nasty mess, but also makes all the hyperparameters visible
+	helv_attn_ent_offset: float = 1.3
+	helv_attn_ent_coef: float = 0.2
+
+	lehv_interaction_strength_offset: float = 1.2
+	lehv_interaction_strength_coef: float = 0.3
+
+	hehv_attn_ent_coef: float = 0.2
+	hehv_attn_vent_offset: float = 2.0
+	hehv_attn_vent_coef: float = 0.5
+
+	# TODO not convinced this should
+	n_adaptive_samples: int = 5
+
+	# Adaptive sampling parameters
+	ada_temp_logits: float = 0.3
+	ada_temp_attn: float = 0.2
+	ada_temp_agree: float = 0.2
+	ada_top_p: float = 0.1
+	ada_top_k_int: float = 0.3
+	ada_top_k_agree: float = 0.2
+	ada_min_p: float = 0.5
+	ada_score_logits_ent: float = 0.1
+	ada_score_attn_ent: float = 0.2
+	ada_score_logits_vent: float = 0.3
+	ada_score_attn_vent: float = 0.4
+	ada_score_agree: float = 0.5
+	ada_score_int: float = 0.6
+
+	# extra stuff
+	top_k_min: int = 32
+	top_k_max: int = 128
\ No newline at end of file
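
For reference (not part of the patch above): a minimal standalone sketch of the logits entropy/varentropy computation that calculate_entropix_metrics performs and that sample() branches on. The helper name and the demo input are hypothetical; the math mirrors the patch.

    import torch

    LN_2 = 0.69314718056  # ln(2); dividing by it converts entropy from nats to bits

    def logits_entropy_varentropy(logits: torch.Tensor, dim: int = -1):
        # entropy H = -sum(p * log2 p); varentropy = E[(-log2 p - H)^2]
        log_probs = torch.nn.functional.log_softmax(logits, dim=dim)
        probs = torch.exp(log_probs)
        entropy = -torch.sum(probs * log_probs, dim=dim) / LN_2
        varentropy = torch.sum(probs * (log_probs / LN_2 + entropy[..., None]) ** 2, dim=dim)
        return entropy, varentropy

    # hypothetical single decode step over a small vocabulary
    demo_logits = torch.randn(1, 1025)
    ent, vent = logits_entropy_varentropy(demo_logits)
    # low entropy + low varentropy steers sample() to the argmax path (temperature *= 0);
    # high values raise temperature or fall through to the adaptive multi-sample scoring path
    print(ent.item(), vent.item())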