From b1df6a7bed23f76a5653241107a200b4723c8ccb Mon Sep 17 00:00:00 2001
From: mrq
Date: Mon, 11 Nov 2024 20:35:08 -0600
Subject: [PATCH] reverted rep pen sampler due to a regression

---
 vall_e/samplers.py | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/vall_e/samplers.py b/vall_e/samplers.py
index ae4ef84..1d06d1a 100644
--- a/vall_e/samplers.py
+++ b/vall_e/samplers.py
@@ -15,6 +15,31 @@ from .utils import clamp
 # `decay` is a factor that will exponentially apply to how far away it is
 # this is split between applying autoregressively (applying to the last token, starting from the end), and applying non-autoregressively (starting from the beginning, and applying to tokens in the future)
+def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=False, limit=75 ):
+	if factor == 1.0 or previous is None:
+		return logits
+
+	unique = set()
+	priors = reversed(previous)
+	for distance, token in enumerate(priors):
+		# rep-pen range
+		if limit and distance >= limit:
+			continue
+		# skip if we're only applying the decay once
+		if one_time and token in unique:
+			continue
+
+		distance += 1
+		logits[:, token] /= factor * (distance ** decay)
+
+		# add to set if we care about it
+		if one_time:
+			unique.add(token)
+
+	return logits
+
+"""
+# I do not know why this is a regression...
 def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=False, limit=75 ):
 	if factor == 1.0 or previous is None:
 		return logits
@@ -55,6 +80,7 @@ def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=F
 	return logits
 
+"""
 
 # Simple "filter" that modifies the logit for the stop token, based on the sequence length
 # `length` is the length of the sequence currently
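
For reference, a minimal usage sketch of the reverted reptition_penalize (not part of the patch itself), assuming the vall_e package is importable and that logits is a [batch, vocab] torch tensor; the token IDs and hyperparameter values below are made up for illustration:

import torch

from vall_e.samplers import reptition_penalize  # identifier spelled as in the repo

vocab_size = 8
logits = torch.ones(1, vocab_size)   # [batch, vocab], all logits start equal
previous = [3, 5, 7, 3]              # token 3 was generated twice, tokens 5 and 7 once

penalized = reptition_penalize(
	logits.clone(),
	previous=previous,
	factor=1.3,   # > 1.0 shrinks the logit of any previously seen token
	decay=0.5,    # each occurrence is divided by factor * (distance ** decay)
	limit=75,     # only the last `limit` tokens of `previous` are considered
)

# token 3 is divided once per occurrence (the penalty stacks when one_time=False),
# tokens 5 and 7 are divided once, token 0 is left untouched
print(penalized[0, 3], penalized[0, 5], penalized[0, 0])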