using torch.max fixes things, somehow, for dynamic temp sampling

mrq 2023-10-10 16:42:24 -05:00
parent 87db03dd93
commit ec25f56bd9


@@ -120,23 +120,20 @@ def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"
 	return logits

 # credit to https://github.com/LostRuins/koboldcpp/pull/464
-def dynamic_temperature( logits, temperature=1.0, min_temperature = 0.00390625, k = 10, sigmoidCenterPoint = 0.5 ):
+def dynamic_temperature( logits, temperature=1.0, min_temperature = 0.0, k = 10, sigmoidCenterPoint = 0.5 ):
 	# loop over logits[:], as the NAR will have logits.shape[0] > 1
 	for i in range(logits.shape[0]):
-		maximum = 0.0
-		for logit in logits[i]:
-			maximum = max( maximum, logit )
 		sum_exp = 0.0
+		maximum = torch.max( logits[i] )
 		for logit in logits[i]:
 			sum_exp += math.exp( logit - maximum )
 		prob_max_token_before_temp = 1.0 / sum_exp
 		dynamic_temperature = temperature - (temperature - min_temperature) / (1 + math.exp(-k * (prob_max_token_before_temp - sigmoidCenterPoint)))
-		#print( "sum_exp:", sum_exp )
-		#print( "prob_max_token_before_temp:", prob_max_token_before_temp )
-		#print( "dynamic temperature:", dynamic_temperature )
+		#print( i, "sum_exp:", sum_exp )
+		#print( i, "prob_max_token_before_temp:", prob_max_token_before_temp )
+		#print( i, "dynamic temperature:", dynamic_temperature )
 		logits[i] /= dynamic_temperature
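For context on what this change likely fixes: prob_max_token_before_temp = 1.0 / sum_exp is the softmax probability of the most likely token, computed with the usual max-subtraction trick (the numerator is exp(max - max) = 1). That trick only works if maximum really is the row maximum. The old loop seeded it with 0.0, so any row whose logits are all negative kept maximum = 0.0, every math.exp( logit - maximum ) underflowed toward zero, and the "probability" fed into the sigmoid could blow up well past 1. A minimal, standalone sketch of the failure mode (illustrative only, not code from this repo):

# Sketch: why seeding the running maximum with 0.0 breaks the
# softmax max-trick on all-negative logit rows.
import math
import torch

row = torch.tensor([-12.0, -10.0, -11.5])  # every logit negative

def prob_max_token(row, maximum):
	# Softmax probability of the argmax token via the max-trick:
	# numerator is exp(max - max) = 1, so p = 1 / sum(exp(logit - max)).
	sum_exp = sum(math.exp(float(logit) - float(maximum)) for logit in row)
	return 1.0 / sum_exp

buggy_max = max(0.0, *(float(l) for l in row))  # old seed: stays 0.0 here
true_max  = float(torch.max(row))               # the fix

print(prob_max_token(row, buggy_max))  # ~16000: not a probability at all
print(prob_max_token(row, true_max))   # ~0.74: a genuine probability in (0, 1]

With the true maximum, the sigmoid in dynamic_temperature interpolates as intended: when the model is uncertain (p near 0) it stays close to temperature, and when it is confident (p near 1) it falls toward min_temperature, which appears to be the behaviour the linked koboldcpp PR is after.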