added a length-based decay factor for repetition penalty

mrq 2023-09-08 21:02:00 -05:00
parent b922f35b6b
commit 10c34c5b98
7 changed files with 35 additions and 10 deletions

View File

@@ -132,6 +132,7 @@ And some experimental sampling flags you can use too (your mileage will ***defin
* `--top-p`: limits the sampling pool to the smallest set of tokens whose cumulative probability reaches `P` (nucleus sampling).
* `--top-k`: limits the sampling pool to the top `K` values in the probability distribution.
* `--repetition-penalty`: penalizes the probability of tokens that have already appeared in the sequence. In the context of audio generation, this is a very iffy parameter to use.
* `--repetition-penalty-decay`: scales the above penalty factor based on how far back in the sequence a token last appeared (see the sketch after this list).
* `--length-penalty`: (AR only) modifies the probability of the stop token based on the current sequence length. This is ***very*** finicky.
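To make the decay concrete, here is a minimal sketch (the values `factor=1.25`, `decay=0.5` are illustrative, not recommendations) of the divisor applied to a repeated token's logit under the `factor * (distance ** decay)` form this commit introduces:

```python
# Illustrative only: how the penalty divisor changes with distance
# under the decayed form `factor * (distance ** decay)`.
factor, decay = 1.25, 0.5

for distance in (1, 2, 4, 8):
	divisor = factor * (distance ** decay)
	print(f"distance={distance}: logit divided by {divisor:.3f}")
# distance=1 -> 1.250, distance=8 -> ~3.536; with decay > 0 the divisor
# grows the further back the repeated token occurred.
```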
## To-Do

View File

@@ -14,20 +14,23 @@ def main():
	parser.add_argument("--yaml", type=Path, default=None)
	parser.add_argument("--ar-ckpt", type=Path, default=None)
	parser.add_argument("--nar-ckpt", type=Path, default=None)
	parser.add_argument("--max-ar-steps", type=int, default=6 * 75)
	parser.add_argument("--ar-temp", type=float, default=1.0)
	parser.add_argument("--nar-temp", type=float, default=1.0)
	parser.add_argument("--top-p", type=float, default=1.0)
	parser.add_argument("--top-k", type=int, default=0)
	parser.add_argument("--repetition-penalty", type=float, default=1.0)
	parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
	parser.add_argument("--length-penalty", type=float, default=0.0)
	parser.add_argument("--device", default="cuda")
	args = parser.parse_args()
	tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device )
	tts.inference( text=args.text, references=args.references, out_path=args.out_path, max_ar_steps=args.max_ar_steps, ar_temp=args.ar_temp, nar_temp=args.nar_temp, top_p=args.top_p, top_k=args.top_k, repetition_penalty=args.repetition_penalty, length_penalty=args.length_penalty )
	tts.inference( text=args.text, references=args.references, out_path=args.out_path, max_ar_steps=args.max_ar_steps, ar_temp=args.ar_temp, nar_temp=args.nar_temp, top_p=args.top_p, top_k=args.top_k, repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay, length_penalty=args.length_penalty )
if __name__ == "__main__":
	main()

View File

@@ -125,7 +125,7 @@ class TTS():
		return res
	@torch.inference_mode()
	def inference( self, text, references, max_ar_steps=6 * 75, ar_temp=0.95, nar_temp=0.5, top_p=1.0, top_k=0, repetition_penalty=1.0, length_penalty=0.0, out_path=None ):
	def inference( self, text, references, max_ar_steps=6 * 75, ar_temp=0.95, nar_temp=0.5, top_p=1.0, top_k=0, repetition_penalty=1.0, repetition_penalty_decay=0.0, length_penalty=0.0, out_path=None ):
		if out_path is None:
			out_path = f"./data/{cfg.start_time}.wav"
@@ -136,9 +136,9 @@ class TTS():
		phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)
		with torch.autocast(self.device, dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
			resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_length_penalty=length_penalty)
			resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty)
			resps_list = [r.unsqueeze(-1) for r in resps_list]
			resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, sampling_temperature=nar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_length_penalty=length_penalty)
			resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, sampling_temperature=nar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty)
		wav, sr = qnt.decode_to_file(resps_list[0], out_path)
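As a usage sketch of the new keyword argument (the import path, file paths, and values here are illustrative assumptions, not taken from the commit):

```python
# A minimal usage sketch; paths and values are illustrative assumptions.
from vall_e.inference import TTS

tts = TTS( config="./config.yaml", device="cuda" )
tts.inference(
	text="Hello world.",
	references=["./reference.wav"],
	out_path="./out.wav",
	repetition_penalty=1.25,       # base penalty for repeated tokens
	repetition_penalty_decay=0.5,  # scale the penalty by distance into the past
)
```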

View File

@@ -89,6 +89,7 @@ class AR(Base):
		sampling_top_k: int = -100,
		sampling_top_p: float = 1.0,
		sampling_repetition_penalty: float = 1.0,
		sampling_repetition_penalty_decay: float = 0.0,
		sampling_length_penalty: float = 0.0,
	):
		if resps_list is not None:
@@ -128,6 +129,7 @@ class AR(Base):
			sampling_top_p=sampling_top_p,
			sampling_top_k=sampling_top_k,
			sampling_repetition_penalty=sampling_repetition_penalty,
			sampling_repetition_penalty_decay=sampling_repetition_penalty_decay,
			sampling_length_penalty=sampling_length_penalty,
			state=state
		)

View File

@@ -74,6 +74,7 @@ class AR_NAR(Base):
		sampling_top_k: int = -100,
		sampling_top_p: float = 1.0,
		sampling_repetition_penalty: float = 1.0,
		sampling_repetition_penalty_decay: float = 0.0,
		sampling_length_penalty: float = 0.0,
	):
		device = text_list[0].device
@@ -127,6 +128,7 @@ class AR_NAR(Base):
			sampling_top_p=sampling_top_p,
			sampling_top_k=sampling_top_k,
			sampling_repetition_penalty=sampling_repetition_penalty,
			sampling_repetition_penalty_decay=sampling_repetition_penalty_decay,
			sampling_length_penalty=sampling_length_penalty,
		)
@@ -157,6 +159,7 @@ class AR_NAR(Base):
			sampling_top_p=sampling_top_p,
			sampling_top_k=sampling_top_k,
			sampling_repetition_penalty=sampling_repetition_penalty,
			sampling_repetition_penalty_decay=sampling_repetition_penalty_decay,
			sampling_length_penalty=sampling_length_penalty,
			state=state
		)

View File

@@ -50,19 +50,32 @@ def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
	return x, m
# Simple filter to modify a token's probability if it shows up in the past
# To-do: have its effect decay based on distance
def reptition_penalize( logits, previous, factor=1.0 ):
# `one_time` will only apply the penalty once
# `decay` is a factor that will exponentially apply to how far away it is
def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=True ):
	if factor == 1.0:
		return logits
	priors = set(previous.tolist())
	for token in priors:
		logits[:, token] /= factor
	unique = set()
	priors = reversed(previous.tolist())
	for distance, token in enumerate(priors):
		# skip if we're only applying the penalty once
		if one_time and token in unique:
			continue
		distance += 1
		logits[:, token] /= factor * (distance ** decay)
		# add to set if we care about it
		if one_time:
			unique.add(token)
	return logits
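A quick illustrative check of the new behavior, assuming the function above is in scope; with the default `one_time=True`, a token that repeats is only penalized at its most recent occurrence:

```python
import torch

logits = torch.full((1, 8), 2.0)             # fake positive logits over 8 tokens
previous = torch.tensor([3, 0, 1, 2, 4, 3])  # token 3 repeats; its last occurrence is most recent

out = reptition_penalize( logits.clone(), previous, factor=1.25, decay=0.5 )
# token 3 is penalized once, at distance 1: 2.0 / (1.25 * 1 ** 0.5) = 1.6
# token 4 sits at distance 2:               2.0 / (1.25 * 2 ** 0.5) ≈ 1.131
print(out[0, 3], out[0, 4])
```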
# Simple "filter" that modifies the logit for the stop token, based on the sequence length
# `length` is the length of the sequence currently
# `factor` is the power the length is raised to, so values > 0 will yield longer sequences, values < 0 will yield shorter sequences
# `token` is the stop token.
def length_penalize( logits, length, factor=0.0, token=-1 ):
if factor == 0.0:
return logits
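The body of `length_penalize` is truncated in this diff beyond the early return, so the following is only a hedged sketch of an implementation consistent with the comments above, mirroring the division used in `reptition_penalize`; it is not necessarily the repo's actual code:

```python
# Hypothetical sketch, not the commit's implementation.
def length_penalize_sketch( logits, length, factor=0.0, token=-1 ):
	if factor == 0.0:
		return logits
	# dividing the stop token's logit by length ** factor means that, for a
	# positive stop logit, factor > 0 suppresses stopping as the sequence
	# grows (longer outputs) and factor < 0 amplifies it (shorter outputs)
	logits[:, token] /= (length ** factor)
	return logits
```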
@@ -325,6 +338,7 @@ class Base(nn.Module):
		sampling_top_k: int = -100,
		sampling_top_p: float = 1.0,
		sampling_repetition_penalty: float = 1.0,
		sampling_repetition_penalty_decay: float = 0.0,
		sampling_length_penalty: float = 0.0,
		state: dict | None = None,
@@ -423,7 +437,7 @@ class Base(nn.Module):
		logits = [ logit[-1:] for logit in logits ]
		# perform repetition penalizing
		logits = [ reptition_penalize(logit, previous=resps[:, 0], factor=sampling_repetition_penalty) for logit, resps in zip( logits, resps_list ) ]
		logits = [ reptition_penalize(logit, previous=resps[:, 0], factor=sampling_repetition_penalty, decay=sampling_repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
		# (AR) perform length penalizing
		if quant_levels is None and self.causal:
View File

@@ -60,6 +60,7 @@ class NAR(Base):
		sampling_top_k: int = -100,
		sampling_top_p: float = 1.0,
		sampling_repetition_penalty: float = 1.0,
		sampling_repetition_penalty_decay: float = 0.0,
		sampling_length_penalty: float = 0.0,
	):
		"""
@@ -119,6 +120,7 @@ class NAR(Base):
			sampling_top_p=sampling_top_p,
			sampling_top_k=sampling_top_k,
			sampling_repetition_penalty=sampling_repetition_penalty,
			sampling_repetition_penalty_decay=sampling_repetition_penalty_decay,
			sampling_length_penalty=sampling_length_penalty,
		)