Reverted the repetition-penalty sampler to the previous implementation due to a regression
This commit is contained in:
parent
b1f4db39c8
commit
b1df6a7bed
|
@ -15,6 +15,31 @@ from .utils import clamp
|
||||||
# `decay` is a factor that will exponentially apply to how far away it is
|
# `decay` is a factor that will exponentially apply to how far away it is
|
||||||
|
|
||||||
# this is split between applying autoregressively (applying to the last token, starting from the end), and applying non-autoregressively (starting from the beginning, and applying to tokens in the future)
|
# this is split between applying autoregressively (applying to the last token, starting from the end), and applying non-autoregressively (starting from the beginning, and applying to tokens in the future)
|
||||||
|
def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=False, limit=75 ):
    """Penalize the logits of tokens that already appear in `previous`.

    `previous` is walked from its last element backwards. For every visited
    token, `logits[:, token]` is divided in place by
    `factor * (distance ** decay)`, where `distance` is the 1-based position
    counted from the end of `previous` (the most recent token has distance 1).

    Args:
        logits: array/tensor indexed as `logits[:, token]`; modified in place.
        previous: sequence of already-emitted token ids, oldest first.
            `None` disables the penalty.
        factor: base divisor; `1.0` disables the penalty.
        decay: exponent applied to the distance; `0.0` makes the divisor
            independent of distance.
        one_time: when true, each distinct token is penalized only once
            (at its most recent occurrence).
        limit: number of most recent tokens to consider; a falsy value
            (`0` / `None`) removes the range limit.

    Returns:
        The same `logits` object that was passed in.
    """
    if factor == 1.0 or previous is None:
        return logits

    seen = set()
    for distance, token in enumerate(reversed(previous), start=1):
        # rep-pen range: distance only grows, so once we are past the window
        # there is nothing left to do (the original kept iterating with `continue`)
        if limit and distance > limit:
            break

        # normalize to a plain int so set membership compares by value;
        # tensor scalars hash by identity and would never match
        token = int(token)

        # skip if we're only applying the penalty once per token
        if one_time and token in seen:
            continue

        # NOTE(review): dividing a negative logit by factor > 1 moves it toward
        # zero (i.e. makes the token *more* likely); other implementations
        # multiply negative logits instead. Left unchanged here to avoid
        # altering sampler output, but worth confirming this is intended.
        logits[:, token] /= factor * (distance ** decay)

        # add to set if we care about it
        if one_time:
            seen.add(token)

    return logits
|
||||||
|
|
||||||
|
"""
|
||||||
|
# I do not know why this is a regression...
|
||||||
def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=False, limit=75 ):
|
def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=False, limit=75 ):
|
||||||
if factor == 1.0 or previous is None:
|
if factor == 1.0 or previous is None:
|
||||||
return logits
|
return logits
|
||||||
|
@ -55,6 +80,7 @@ def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=F
|
||||||
|
|
||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
"""
|
||||||
|
|
||||||
# Simple "filter" that modifies the logit for the stop token, based on the sequence length
|
# Simple "filter" that modifies the logit for the stop token, based on the sequence length
|
||||||
# `length` is the length of the sequence currently
|
# `length` is the length of the sequence currently
|
||||||
|
|
Loading…
Reference in New Issue
Block a user