reverted rep pen sampler due to a regression

This commit is contained in:
mrq 2024-11-11 20:35:08 -06:00
parent b1f4db39c8
commit b1df6a7bed

View File

@ -15,6 +15,31 @@ from .utils import clamp
# `decay` is a factor that will exponentially apply to how far away it is # `decay` is a factor that will exponentially apply to how far away it is
# this is split between applying autoregressively (applying to the last token, starting from the end), and applying non-autoregressively (starting from the beginning, and applying to tokens in the future) # this is split between applying autoregressively (applying to the last token, starting from the end), and applying non-autoregressively (starting from the beginning, and applying to tokens in the future)
def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=False, limit=75 ):
if factor == 1.0 or previous is None:
return logits
unique = set()
priors = reversed(previous)
for distance, token in enumerate(priors):
# rep-pen range
if limit and distance >= limit:
continue
# skip if we're only applying the decay once
if one_time and token in unique:
continue
distance += 1
logits[:, token] /= factor * (distance ** decay)
# add to set if we care about it
if one_time:
unique.add(token)
return logits
"""
# I do not know why this is a regression...
def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=False, limit=75 ): def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=False, limit=75 ):
if factor == 1.0 or previous is None: if factor == 1.0 or previous is None:
return logits return logits
@ -55,6 +80,7 @@ def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=F
return logits return logits
"""
# Simple "filter" that modifies the logit for the stop token, based on the sequence length # Simple "filter" that modifies the logit for the stop token, based on the sequence length
# `length` is the length of the sequence currently # `length` is the length of the sequence currently