Reverted the repetition-penalty sampler due to a regression
This commit is contained in:
parent
b1f4db39c8
commit
b1df6a7bed
|
@ -15,6 +15,31 @@ from .utils import clamp
|
|||
# `decay` is the exponent applied to a token's distance from the end of the sequence: each penalized logit is divided by `factor * (distance ** decay)`
|
||||
|
||||
# this is split between applying autoregressively (applying to the last token, starting from the end), and applying non-autoregressively (starting from the beginning, and applying to tokens in the future)
|
||||
def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=False, limit=75 ):
    """Apply a repetition penalty to `logits` in place and return them.

    Walks `previous` from its last element backwards and divides the logit
    column of every visited token by `factor * (distance ** decay)`, where
    `distance` is 1 for the most recent token, 2 for the one before it, etc.

    Args:
        logits: array-like indexed as `logits[:, token]`; modified in place.
        previous: sequence of already-seen token ids (must support
            `reversed`). `None` disables the penalty.
        factor: base divisor for a penalized logit. `1.0` disables the penalty.
        decay: exponent applied to the distance. With `0.0` every visited
            token is divided by `factor` alone.
        one_time: when true, each distinct token is penalized only once, at
            its most recent occurrence.
        limit: number of most recent tokens to consider. A falsy value
            (`0` / `None`) removes the range limit.

    Returns:
        The same `logits` object that was passed in.
    """
    if factor == 1.0 or previous is None:
        return logits

    # NOTE(review): dividing a negative logit by a value > 1 moves it towards
    # zero (i.e. raises it) -- confirm this is the intended behaviour.
    seen = set()
    for index, token in enumerate(reversed(previous)):
        # rep-pen range: every later index is also out of range, so stop
        # scanning instead of skipping the rest of the history one by one
        if limit and index >= limit:
            break

        # skip if we're only applying the penalty once per token
        if one_time and token in seen:
            continue

        distance = index + 1
        logits[:, token] /= factor * (distance ** decay)

        # remember the token only if we care about it
        if one_time:
            seen.add(token)

    return logits
|
||||
|
||||
"""
|
||||
# I do not know why this is a regression...
|
||||
def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=False, limit=75 ):
|
||||
if factor == 1.0 or previous is None:
|
||||
return logits
|
||||
|
@ -55,6 +80,7 @@ def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=F
|
|||
|
||||
|
||||
return logits
|
||||
"""
|
||||
|
||||
# Simple "filter" that modifies the logit for the stop token, based on the sequence length
|
||||
# `length` is the length of the sequence currently
|
||||
|
|
Loading…
Reference in New Issue
Block a user