import math
import torch
import torch.nn.functional as F
import numpy as np
import time

from torch import Tensor, einsum, nn
from einops import rearrange
from dataclasses import asdict, dataclass, field
from typing import NamedTuple

from .utils import clamp

# Simple filter to modify a token's probability if it shows up in the past
# `one_time` will only apply the penalty once
# `decay` is a factor that will exponentially apply to how far away it is
def repetition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=True, limit=None ):
	if factor == 1.0 or previous is None:
		return logits

	unique = set()
	# NAR case: the logits cover the same sequence positions as `previous`
	is_nar = previous.shape[0] == logits.shape[0]
	for i, token in enumerate( previous ):
		distance = previous.shape[0] - i

		# rep-pen range
		if limit and distance >= limit:
			continue
		# skip if we're only applying the decay once
		if one_time and token in unique:
			continue

		start = None
		end = None
		# apply only to future tokens
		if is_nar and i < logits.shape[0]:
			start = i + 1
		if limit:
			end = i + limit

		logits[start:end, token] /= factor * (distance ** decay)

		# add to set if we care about it
		if one_time:
			unique.add( token )

	return logits
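
# A minimal usage sketch (not from the original file; the vocab size and token
# history below are made up):
def _example_repetition_penalize():
	logits = torch.randn(1, 256)            # ( batch, vocab )
	previous = torch.tensor([5, 9, 5, 42])  # prior tokens, oldest first
	# divides the logits of previously-seen tokens by `factor`, decaying with distance
	return repetition_penalize( logits, previous=previous, factor=1.2, decay=0.5 )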

# Simple "filter" that modifies the logit for the stop token, based on the sequence length
# `length` is the current length of the sequence
# `factor` is the power the length is raised to, so values > 0 will yield longer sequences, values < 0 will yield shorter sequences
# `token` is the stop token.
def length_penalize( logits, length, factor=0.0, token=-1 ):
	if factor == 0.0:
		return logits

	logits[:, token] /= (length ** factor)
	return logits
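
# A minimal usage sketch (illustrative; the stop-token index 0 is an assumption):
def _example_length_penalize():
	logits = torch.randn(1, 256)
	# divides the stop-token logit by length**factor; per the comment above,
	# factor > 0 biases toward longer sequences
	return length_penalize( logits, length=12, factor=0.5, token=0 )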

# Simple way to ban tokens
def ban_tokens( logits, tokens ):
	for token in tokens:
		# skip tokens that are outside the vocabulary
		if token >= logits.shape[-1]:
			continue
		logits[:, token] = -float("inf")
	return logits
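
# A minimal usage sketch (illustrative; the banned indices are made up):
def _example_ban_tokens():
	logits = torch.randn(1, 256)
	return ban_tokens( logits, tokens=[0, 1024] )  # 1024 is out of range and gets skipped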

# Performs min_p filtering
# From https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/generation/logits_process.py#L537
def min_p_filtering( logits, min_p=0.0, min_tokens_to_keep=32 ):
	if min_p <= 0.0:
		return logits

	# Convert logits to probabilities
	probs = torch.softmax(logits, dim=-1)
	# Get the probability of the top token for each sequence in the batch
	top_probs, _ = probs.max(dim=-1, keepdim=True)
	# Calculate the actual min_p threshold by scaling min_p with the top token's probability
	scaled_min_p = min_p * top_probs

	sorted_indices = torch.argsort(logits, descending=True, dim=-1)
	sorted_indices_to_remove = torch.gather(probs < scaled_min_p, dim=-1, index=sorted_indices)
	sorted_indices_to_remove[..., :min_tokens_to_keep] = False

	indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
	return logits.masked_fill(indices_to_remove, -float("inf"))
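
# A minimal usage sketch (illustrative): tokens whose probability falls below
# min_p times the top token's probability get masked to -inf.
def _example_min_p_filtering():
	logits = torch.randn(1, 256)
	return min_p_filtering( logits, min_p=0.05 )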

# Credit to https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py#L1145 / https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens=1 ):
	"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
	Args:
		logits: logits distribution shape (batch size, vocabulary size)
		if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
		if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
			Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
	Make sure we keep at least min_tokens per batch example in the output
	"""
	if top_k > 0:
		top_k = min(max(top_k, min_tokens), logits.size(-1))  # Safety check
		# Remove all tokens with a probability less than the last token of the top-k
		indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
		logits[indices_to_remove] = filter_value

	if top_p < 1.0:
		sorted_logits, sorted_indices = torch.sort(logits, descending=True)
		cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

		# Remove tokens with cumulative probability above the threshold (tokens with 0 are kept)
		sorted_indices_to_remove = cumulative_probs > top_p
		if min_tokens > 1:
			# Keep at least min_tokens (set to min_tokens-1 because we add the first one below)
			sorted_indices_to_remove[..., :min_tokens] = 0
		# Shift the indices to the right to keep also the first token above the threshold
		sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
		sorted_indices_to_remove[..., 0] = 0

		# scatter sorted tensors to original indexing
		indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
		logits[indices_to_remove] = filter_value

	return logits
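
# A minimal usage sketch (illustrative): keep the 50 highest logits, then the
# nucleus covering 90% of the probability mass.
def _example_top_k_top_p_filtering():
	logits = torch.randn(1, 256)
	return top_k_top_p_filtering( logits, top_k=50, top_p=0.9 )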

# credit to https://github.com/LostRuins/koboldcpp/pull/464 // https://github.com/kalomaze/koboldcpp/tree/dynamic-temp
def dynamic_temperature( logits, temperature=1.0, min_temperature=0.0, k=10, sigmoidCenterPoint=0.5 ):
	# loop over logits[:], as the NAR will have logits.shape[0] > 1
	for i in range(logits.shape[0]):
		sum_exp = 0.0
		maximum = torch.max(logits[i])
		for logit in logits[i]:
			sum_exp += math.exp(logit - maximum)

		prob_max_token_before_temp = 1.0 / sum_exp
		# interpolate between `temperature` and `min_temperature` based on how peaked the distribution already is
		dyn_temperature = temperature - (temperature - min_temperature) / (1 + math.exp(-k * (prob_max_token_before_temp - sigmoidCenterPoint)))

		logits[i] /= dyn_temperature
	return logits
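
# A minimal usage sketch (illustrative): the effective temperature slides between
# `min_temperature` and `temperature` depending on the distribution's confidence.
def _example_dynamic_temperature():
	logits = torch.randn(1, 256)
	return dynamic_temperature( logits, temperature=1.0, min_temperature=0.5 )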

# picks the top K tokens amongst a batch of logits
# logits_list: [Tensor] list of logits
# returns: [(batch, token)] list, where batch indicates the index of the logits the given token is from
def top_k_logits_list( logits_list, k ):
	# ( batch, tokens ) => ( batch x tokens )
	logits = torch.cat( logits_list )
	candidates = list(torch.topk(logits.flatten(), k).indices.tolist())  # perform top-k across all logits
	for i, index in enumerate(candidates):
		t = []
		N = np.prod(logits.size())
		# unravel the flat index back into (batch, token) coordinates
		for n in logits.size():
			N //= n
			t.append(index // N)
			index %= N
		candidates[i] = tuple(t)
	return candidates
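
# A minimal usage sketch (illustrative): pick the 4 best tokens across two batches' logits.
def _example_top_k_logits_list():
	logits_list = [ torch.randn(1, 256), torch.randn(1, 256) ]
	return top_k_logits_list( logits_list, k=4 )  # [(batch, token), ...]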

# top-nσ logit processing
# from https://arxiv.org/abs/2411.07641
def top_no_logits_processing( logits, n=1.0 ):
	M = torch.max(logits, dim=-1, keepdim=True).values
	σ = torch.std(logits, dim=-1, keepdim=True)
	mask = logits >= M - n * σ
	n_inf = torch.full_like(logits, -float("inf"))
	logits = torch.where(mask, logits, n_inf)

	return logits
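
# A minimal usage sketch (illustrative): only tokens within n standard deviations
# of the max logit survive.
def _example_top_no_logits_processing():
	logits = torch.randn(1, 256)
	return top_no_logits_processing( logits, n=1.0 )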

# perform classifier-free guidance given positive logits and negative/null logits
# some funny nonsense with needing to operate on slices since this is performed before sampling, where the logits are the entire sequence
# (and because the null logits have a shorter input sequence compared to the positive logits)
def cfg_logits( logits, null, strength, lens, rescale=0.0 ):
	for i, seq_len in enumerate( lens ):
		pos = logits[i][-seq_len:]
		neg = null[i][-seq_len:]
		summed = neg + (pos - neg) * strength
		if rescale <= 0:
			logits[i][-seq_len:] = summed
		else:
			dims = tuple(range(1, summed.ndim - 1))
			factor = rescale * (pos.std(dims, keepdim=True) / summed.std(dims, keepdim=True)) + (1 - rescale)
			logits[i][-seq_len:] = summed * factor

	return logits
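
# A minimal usage sketch (illustrative; shapes are made up). The null logits come
# from a shorter input, so only the trailing `lens[i]` positions are guided.
def _example_cfg_logits():
	logits = torch.randn(2, 10, 256)  # positive/conditioned logits
	null = torch.randn(2, 8, 256)     # negative/null logits
	return cfg_logits( logits, null, strength=3.0, lens=[8, 8] )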

# Credit to: https://github.com/basusourya/mirostat/
# performs mirostat-based sampling
# logits: Tensor of logits
# state: the mirostat state
def mirostat_sample( logits, state=None ):
	def compute_k(prob, n, tau):
		num = 0
		den = 0
		for i in range(100):
			b = prob[i] / prob[i + 1]
			t = (i + 2) / (i + 1)
			num += math.log(b) * math.log(t)
			den += math.log(t) ** 2
		s = num / den
		eps = s - 1
		k = ((eps * (2 ** tau)) / (1 - n ** (-eps))) ** (1 / s)
		k = round(k)
		return k

	if "max_surprise" not in state:
		state["max_surprise"] = state["tau"] * 2
	if "error_surprise" not in state:
		state["error_surprise"] = 0
	if "running_total_surprise" not in state:
		state["running_total_surprise"] = 0

	sorted_logits, sorted_indices = torch.sort( logits[-1, :], descending=True )
	prob_original = torch.softmax( sorted_logits, dim=-1 ).tolist()

	k = compute_k( prob_original, state["n"], state["max_surprise"] ) + 1

	sorted_logits = sorted_logits[0:k]
	sorted_indices = sorted_indices[0:k]

	prob_topk = torch.softmax( sorted_logits, dim=0 )
	prev_i = torch.multinomial( prob_topk, num_samples=1, replacement=True )

	state["index_surprise"] = math.log2( 1 / prob_original[prev_i] )
	state["running_total_surprise"] += state["index_surprise"]
	state["error_surprise"] = state["index_surprise"] - state["tau"]
	state["max_surprise"] -= state["eta"] * state["error_surprise"]
	state["token"] = sorted_indices[prev_i]

	return state
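
# A minimal usage sketch (illustrative); `tau`/`eta`/`n` are typical mirostat
# hyperparameters, not values taken from this file. The returned state carries
# `token` plus the updated surprise bookkeeping for the next step.
def _example_mirostat_sample():
	logits = torch.randn(1, 256)
	state = { "tau": 3.0, "eta": 0.1, "n": 256 }
	state = mirostat_sample( logits, state )
	return state["token"]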

# Credits to: https://github.com/oobabooga/text-generation-webui/pull/5677
# performs DRY sampling
# * (honestly it looks close to rep pen anyways but what do I know)
# `logits` are the scores used to sample against
# `previous` are the prior tokens to penalize with
# `factor` is the scalar multiplier
# `base` is the base number to raise to the (length - allowed_length)th power
# `allowed_length` limits the range to apply DRY to
def dry_sampling( logits, previous=None, factor=0.0, base=1.75, allowed_length=2 ):
	if factor == 0.0 or previous is None:
		return logits

	lengths = {}
	for i, token in enumerate( previous ):
		length = 1
		while length < max( allowed_length, 50 ):
			j = i - length

			# Start of input reached.
			if j < 0:
				break

			# Start of match reached.
			if previous[j] != previous[-length - 1]:
				break

			length += 1

		lengths[token] = max( length, lengths[token] ) if token in lengths else length

	for token, length in lengths.items():
		if length < allowed_length:
			continue
		logits[:, token] -= factor * base ** (length - allowed_length)

	return logits
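
# A minimal usage sketch (illustrative): the history ends with a repeated bigram,
# so the repeated sequence gets penalized.
def _example_dry_sampling():
	logits = torch.randn(1, 256)
	previous = [5, 9, 42, 5, 9]
	return dry_sampling( logits, previous=previous, factor=0.8 )
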
LN_2 = 0.69314718056 # ln(2) = 1.0 / LOG2_E
# Grabbed from https://github.com/xjdr-alt/entropix/blob/main/entropix/sampler.py
def calculate_entropix_metrics( logits, attentions=None, dim=-1, use_stats=False ):
	"""Calculate the entropy and varentropy of the probability distribution using logsoftmax."""
	log_probs = F.log_softmax( logits, dim=dim )
	probs = torch.exp( log_probs )
	entropy = -torch.sum( probs * log_probs, dim=dim ) / LN_2  # Convert to base-2
	varentropy = torch.sum( probs * (log_probs / LN_2 + entropy.unsqueeze(-1)) ** 2, dim=dim )

	if attentions is None:
		return {
			"logits_entropy": torch.mean(entropy).item(),
			"logits_varentropy": torch.mean(varentropy).item(),
		}

	last_attention_scores = attentions[-1].unsqueeze(0)  # ( bsz, heads, seq_len, seq_len )
	attention_probs = F.softmax( last_attention_scores, dim=-1 )
	if use_stats:
		attn_stats = AttnStats.new( 1, attentions.shape[0], attentions.shape[1], logits.device )
		for idx, attn in enumerate( attentions ):
			attn_stats.update( attn.unsqueeze(0)[:, :, -1, :], idx )  # ( bsz, heads, seq_len ) — the last token's attention row
		attn_entropy = attn_stats.entropy
		attn_varentropy = attn_stats.varentropy
	else:
		attn_entropy = -torch.sum( attention_probs * torch.log2( torch.clamp( attention_probs, 1e-10, 1.0 ) ), dim=-1 )
		attn_varentropy = torch.var( attn_entropy, dim=1 )

	# Replace NaNs (which arise when all values are the same) with zeros
	attn_varentropy = torch.where( torch.isnan( attn_varentropy ), torch.zeros_like( attn_varentropy ), attn_varentropy )

	mean_attention = torch.mean( attention_probs, dim=1 )
	agreement = torch.mean( torch.abs( attention_probs - mean_attention.unsqueeze(1) ), dim=(1, 2) )
	interaction_strength = torch.mean( torch.abs( last_attention_scores ), dim=(1, 2, 3) )

	return {
		"logits_entropy": torch.mean(entropy).item(),
		"logits_varentropy": torch.mean(varentropy).item(),
		"attn_entropy": torch.mean(attn_entropy).item(),
		"attn_varentropy": torch.mean(attn_varentropy).item(),
		"agreement": torch.mean(agreement).item(),
		"interaction_strength": interaction_strength.item(),  # torch.mean(interaction_strength).item(),
		"action": -1,
	}
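
# A minimal usage sketch (illustrative; the layer/head/sequence sizes are made up,
# and the attention slice mirrors how sample_entropix calls this below).
def _example_calculate_entropix_metrics():
	logits = torch.randn(1, 256)           # ( seq_len=1, vocab )
	attentions = torch.randn(4, 8, 1, 16)  # ( layers, heads, 1, seq_len )
	return calculate_entropix_metrics( logits, attentions )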

class AttnStats(NamedTuple):
	entropy: torch.Tensor  # (bsz, n_layers, num_heads)
	varentropy: torch.Tensor  # (bsz, n_layers, num_heads)
	n_layers: int
	n_heads: int

	@classmethod
	def new(cls, bsz: int, n_layers: int, n_heads: int, device="cuda") -> 'AttnStats':
		return cls(
			entropy=torch.zeros((bsz, n_layers, n_heads), dtype=torch.float32, device=device),
			varentropy=torch.zeros((bsz, n_layers, n_heads), dtype=torch.float32, device=device),
			n_layers=n_layers,
			n_heads=n_heads
		)

	@property
	def avg_entropy(self):
		return self.entropy.sum(dim=-1, keepdim=False)  # Sum across heads

	@property
	def avg_varentropy(self):
		return self.varentropy.sum(dim=-1, keepdim=False)  # Sum across heads

	@property
	def std_error(self):
		return torch.sqrt(torch.mean(self.varentropy)) / (self.n_heads * self.n_layers)

	def update(self, scores: torch.Tensor, layer_idx: int):
		# scores shape: (bsz, n_heads, n_words) — the last token's attention row per head
		probs = torch.nn.functional.softmax(scores, dim=-1)
		new_entropy = -torch.sum(torch.where(probs > 0, probs * torch.log(probs), torch.tensor(0.0)), dim=-1)
		new_varentropy = torch.sum(probs * (torch.log(probs) + new_entropy.unsqueeze(-1)) ** 2, dim=-1)

		# Update entropy and varentropy tensors in place
		self.entropy[:, layer_idx, :] = new_entropy
		self.varentropy[:, layer_idx, :] = new_varentropy

		return self
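
# A minimal usage sketch (illustrative): accumulate per-layer attention entropy
# stats from the last token's attention rows.
def _example_attn_stats():
	stats = AttnStats.new( bsz=1, n_layers=2, n_heads=8, device="cpu" )
	for layer_idx in range(2):
		scores = torch.randn(1, 8, 16)  # ( bsz, heads, seq_len )
		stats = stats.update( scores, layer_idx )
	return stats.avg_entropy, stats.std_error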

# to-do: play around with these values
@dataclass()
class EntropixSamplerConfig:
	temp: float = 0.666
	top_p: float = 0.90
	top_k: int = 27
	min_p: float = 0.01  # was 0.03 # Turn this down to 0.01 to reduce the shoggoth

	low_ent_thresh: float = 0.1  # 3.0
	low_vent_thresh: float = 0.1  # 3.0
	med_ent_thresh: float = 3.0  # 6.0
	high_ent_thresh: float = 5.0  # 9.0
	high_vent_thresh: float = 5.0  # 9.0

	# TODO this is a bit of a nasty mess, but also makes all the hyperparameters visible
	helv_attn_ent_offset: float = 1.3
	helv_attn_ent_coef: float = 0.2

	lehv_interaction_strength_offset: float = 1.2
	lehv_interaction_strength_coef: float = 0.3

	hehv_attn_ent_coef: float = 0.2
	hehv_attn_vent_offset: float = 2.0
	hehv_attn_vent_coef: float = 0.5

	# TODO not convinced this should
	n_adaptive_samples: int = 5

	# Adaptive sampling parameters
	ada_temp_logits: float = 0.3
	ada_temp_attn: float = 0.2
	ada_temp_agree: float = 0.2
	ada_top_p: float = 0.1
	ada_top_k_int: float = 0.3
	ada_top_k_agree: float = 0.2
	ada_min_p: float = 0.5
	ada_score_logits_ent: float = 0.1
	ada_score_attn_ent: float = 0.2
	ada_score_logits_vent: float = 0.3
	ada_score_attn_vent: float = 0.4
	ada_score_agree: float = 0.5
	ada_score_int: float = 0.6

	# extra stuff
	temperature_max: float = 1.25
	temperature_min: float = 0.5
	top_k_min: int = 1
	top_k_max: int = 1024
	top_p_min: float = 0.1
	top_p_max: float = 1.0
	min_p_min: float = 0.01
	min_p_max: float = 0.5

Exponential = torch.distributions.exponential.Exponential(1.0)

# Stays as close to the original sampling method as possible, just to reduce variance
def _sample_entropix(
	logits,
	temperature=1.0,
	top_k=0,
	top_p=1.0,
	min_p=0.0,
	cfg=EntropixSamplerConfig(),
):
	if top_k == 0:
		top_k = logits.shape[-1]

	logit = logits[-1, :]

	temperature = clamp( float(temperature), cfg.temperature_min, cfg.temperature_max )
	top_p = clamp( float(top_p), cfg.top_p_min, cfg.top_p_max )
	top_k = clamp( int(top_k), cfg.top_k_min, cfg.top_k_max )
	min_p = clamp( float(min_p), cfg.min_p_min, cfg.min_p_max )

	probs = F.softmax( logit / temperature, dim=-1 )

	# Apply min_p sampling
	if min_p > 0.0:
		p_max = float( torch.max( probs, dim=-1, keepdim=True ).values )
		indices_to_remove = probs < (min_p * p_max)
		logit = torch.where( indices_to_remove, torch.full_like( logit, float('-inf') ), logit )

	# Apply top-k sampling
	top_k_probs, top_k_indices = torch.topk( probs, k=min(top_k, probs.shape[-1]) )
	probs_sort = torch.flip( top_k_probs, dims=[-1] )
	probs_idx = torch.flip( top_k_indices, dims=[-1] )
	probs_sum = torch.cumsum( probs_sort, dim=-1 )

	# Apply top-p sampling
	mask = torch.where( probs_sum - probs_sort > top_p, torch.tensor(1.0, device=logit.device), torch.tensor(0.0, device=logit.device) )
	probs_sort = probs_sort * (1 - mask)
	probs_sort = probs_sort / torch.sum( probs_sort, dim=-1, keepdim=True )

	# sample proportionally to probs_sort via an exponential race
	q = Exponential.sample( probs_sort.shape )
	# q = torch.rand(probs_sort.shape, generator=generator, device=probs_sort.device)
	next_token = torch.argmax( probs_sort / q, dim=-1, keepdim=True )
	next_token_g = torch.take_along_dim( probs_idx, next_token, dim=-1 )

	return next_token_g

def sample_entropix(
	logits,
	attentions,
	temperature=1.0,
	top_k=27,
	top_p=1.0,
	min_p=0.0,
	cfg=EntropixSamplerConfig(),
	metrics_only=False,
):
	"""
	temperature = cfg.temp
	top_k = cfg.top_k
	top_p = cfg.top_p
	"""

	# logits: ( seq_len, vocab )
	# attentions: ( layer, heads, seq_len, seq_len )
	metrics = calculate_entropix_metrics( logits[-1:, :], attentions[:, :, -1:, :] )

	ent, vent = metrics["logits_entropy"], metrics["logits_varentropy"]
	attn_ent, attn_vent = metrics["attn_entropy"], metrics["attn_varentropy"]
	agreement = metrics["agreement"]
	interaction_strength = metrics["interaction_strength"]

	# Low Entropy, Low Varentropy: "flowing with unspoken intent"
	if ent < cfg.low_ent_thresh and vent < cfg.low_vent_thresh:
		metrics["action"] = 0
		res = logits[-1, :].argmax(dim=-1)
	# High Entropy, Low Varentropy: "treading carefully, asking clarifying questions"
	elif ent > cfg.high_ent_thresh and vent < cfg.low_vent_thresh:
		metrics["action"] = 1
		# sample with slightly higher temperature
		temperature *= cfg.helv_attn_ent_offset + cfg.helv_attn_ent_coef * attn_ent  # Increase temperature based on attention entropy
		res = _sample_entropix( logits, temperature, top_k, top_p, min_p, cfg=cfg )
	# Low Entropy, High Varentropy: "exploring forks in the path"
	elif ent < cfg.high_ent_thresh and vent > cfg.high_vent_thresh:
		metrics["action"] = 2
		temperature *= cfg.lehv_interaction_strength_offset + cfg.lehv_interaction_strength_coef * interaction_strength  # Increase temperature based on interaction strength
		top_k = max( 5, int( top_k * (1 + 0.5 * (1 - agreement)) ) )  # Increase top_k when agreement is low
		res = _sample_entropix( logits, temperature, top_k, top_p, min_p, cfg=cfg )
	# High Entropy, High Varentropy: "resampling in the mist"
	elif ent > cfg.med_ent_thresh and vent > cfg.high_vent_thresh:
		metrics["action"] = 3
		# Use high temperature and adjusted top_p based on attention metrics
		temperature *= cfg.hehv_attn_vent_offset + cfg.hehv_attn_vent_coef * attn_vent  # Increase temperature based on attention varentropy
		top_p = max( 0.5, top_p - cfg.hehv_attn_ent_coef * attn_ent )  # Decrease top_p when attention entropy is high
		res = _sample_entropix( logits, temperature, top_k, top_p, min_p, cfg=cfg )
	# Middle ground: use adaptive sampling
	else:
		metrics["action"] = 4

		log_softmax = F.log_softmax( logits, dim=-1 )
		logits_uncertainty = ent + vent
		attn_uncertainty = attn_ent + attn_vent

		temperature *= 1 + cfg.ada_temp_logits * logits_uncertainty + cfg.ada_temp_attn * attn_uncertainty - cfg.ada_temp_agree * agreement
		top_p = top_p * (1 + cfg.ada_top_p * attn_vent)
		top_k = round( float( top_k * (1 + cfg.ada_top_k_int * interaction_strength - cfg.ada_top_k_agree * agreement) ) )
		min_p = cfg.min_p * (1 - cfg.ada_min_p * logits_uncertainty)

		samples = [ _sample_entropix( logits.clone(), temperature, top_k, top_p, min_p, cfg=cfg ) for _ in range(cfg.n_adaptive_samples) ]

		def score_sample( sample ):
			one_hot = F.one_hot( sample, logits.shape[-1] )
			log_prob = torch.sum( log_softmax * one_hot )

			confidence_score = (
				(1 - ent) * cfg.ada_score_logits_ent +
				(1 - attn_ent) * cfg.ada_score_attn_ent +
				(1 - vent) * cfg.ada_score_logits_vent +
				(1 - attn_vent) * cfg.ada_score_attn_vent +
				agreement * cfg.ada_score_agree +
				interaction_strength * cfg.ada_score_int
			)
			"""
			if 1024 in sample:
				return 1000
			"""
			return log_prob + confidence_score

		sample_scores = [ score_sample( sample ) for sample in samples ]
		best_sample_idx = torch.argmax( torch.asarray( sample_scores ) )
		res = samples[best_sample_idx]

	"""
	metrics = {
		"attn_entropy": metrics["attn_entropy"],
		"attn_varentropy": metrics["attn_varentropy"],
	}
	"""
	"""
	metrics["temperature"] = temperature
	metrics["top_k"] = top_k
	metrics["top_p"] = top_p
	metrics["min_p"] = min_p
	"""
	return res, metrics
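
# A minimal end-to-end sketch (illustrative; the model shapes are made up):
def _example_sample_entropix():
	logits = torch.randn(12, 256)           # ( seq_len, vocab )
	attentions = torch.randn(4, 8, 12, 12)  # ( layer, heads, seq_len, seq_len )
	token, metrics = sample_entropix( logits, attentions, temperature=1.0 )
	return token, metrics["action"]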