diff --git a/README.md b/README.md
index e4f0884..830a95f 100755
--- a/README.md
+++ b/README.md
@@ -132,6 +132,7 @@ And some experimental sampling flags you can use too (your mileage will ***defin
 * `--top-p`: limits the sampling pool to top sum of values that equal `P`% probability in the probability distribution.
 * `--top-k`: limits the sampling pool to the top `K` values in the probability distribution.
 * `--repetition-penalty`: modifies the probability of tokens if they have appeared before. In the context of audio generation, this is a very iffy parameter to use.
+* `--repetition-penalty-decay`: scales the above penalty factor based on how far back in the sequence a token last appeared.
 * `--length-penalty`: (AR only) modifies the probability of the stop token based on the current sequence length. This is ***very*** finnicky.
 
 ## To-Do
diff --git a/vall_e/__main__.py b/vall_e/__main__.py
index fab6a28..834afe7 100755
--- a/vall_e/__main__.py
+++ b/vall_e/__main__.py
@@ -14,20 +14,23 @@ def main():
     parser.add_argument("--yaml", type=Path, default=None)
     parser.add_argument("--ar-ckpt", type=Path, default=None)
     parser.add_argument("--nar-ckpt", type=Path, default=None)
+    parser.add_argument("--max-ar-steps", type=int, default=6 * 75)
+    parser.add_argument("--ar-temp", type=float, default=1.0)
     parser.add_argument("--nar-temp", type=float, default=1.0)
 
     parser.add_argument("--top-p", type=float, default=1.0)
     parser.add_argument("--top-k", type=int, default=0)
     parser.add_argument("--repetition-penalty", type=float, default=1.0)
+    parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
     parser.add_argument("--length-penalty", type=float, default=0.0)
 
     parser.add_argument("--device", default="cuda")
 
     args = parser.parse_args()
 
     tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device )
 
-    tts.inference( text=args.text, references=args.references, out_path=args.out_path, max_ar_steps=args.max_ar_steps, ar_temp=args.ar_temp, nar_temp=args.nar_temp, top_p=args.top_p, top_k=args.top_k, repetition_penalty=args.repetition_penalty, length_penalty=args.length_penalty )
+    tts.inference( text=args.text, references=args.references, out_path=args.out_path, max_ar_steps=args.max_ar_steps, ar_temp=args.ar_temp, nar_temp=args.nar_temp, top_p=args.top_p, top_k=args.top_k, repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay, length_penalty=args.length_penalty )
 
 if __name__ == "__main__":
     main()
diff --git a/vall_e/inference.py b/vall_e/inference.py
index 7d20ea1..707c889 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -125,7 +125,7 @@ class TTS():
         return res
 
     @torch.inference_mode()
-    def inference( self, text, references, max_ar_steps=6 * 75, ar_temp=0.95, nar_temp=0.5, top_p=1.0, top_k=0, repetition_penalty=1.0, length_penalty=0.0, out_path=None ):
+    def inference( self, text, references, max_ar_steps=6 * 75, ar_temp=0.95, nar_temp=0.5, top_p=1.0, top_k=0, repetition_penalty=1.0, repetition_penalty_decay=0.0, length_penalty=0.0, out_path=None ):
         if out_path is None:
             out_path = f"./data/{cfg.start_time}.wav"
 
@@ -136,9 +136,9 @@ class TTS():
         phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)
 
         with torch.autocast(self.device, dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
-            resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_length_penalty=length_penalty)
+            resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty)
             resps_list = [r.unsqueeze(-1) for r in resps_list]
-            resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, sampling_temperature=nar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_length_penalty=length_penalty)
+            resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, sampling_temperature=nar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty)
 
         wav, sr = qnt.decode_to_file(resps_list[0], out_path)
 
diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py
index 95adbea..8e56345 100755
--- a/vall_e/models/ar.py
+++ b/vall_e/models/ar.py
@@ -89,6 +89,7 @@ class AR(Base):
         sampling_top_k: int = -100,
         sampling_top_p: float = 1.0,
         sampling_repetition_penalty: float = 1.0,
+        sampling_repetition_penalty_decay: float = 0.0,
         sampling_length_penalty: float = 0.0,
     ):
         if resps_list is not None:
@@ -128,6 +129,7 @@ class AR(Base):
             sampling_top_p=sampling_top_p,
             sampling_top_k=sampling_top_k,
             sampling_repetition_penalty=sampling_repetition_penalty,
+            sampling_repetition_penalty_decay=sampling_repetition_penalty_decay,
             sampling_length_penalty=sampling_length_penalty,
             state=state
         )
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 3caa115..26183f0 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -74,6 +74,7 @@ class AR_NAR(Base):
         sampling_top_k: int = -100,
         sampling_top_p: float = 1.0,
         sampling_repetition_penalty: float = 1.0,
+        sampling_repetition_penalty_decay: float = 0.0,
         sampling_length_penalty: float = 0.0,
     ):
         device = text_list[0].device
@@ -127,6 +128,7 @@ class AR_NAR(Base):
                 sampling_top_p=sampling_top_p,
                 sampling_top_k=sampling_top_k,
                 sampling_repetition_penalty=sampling_repetition_penalty,
+                sampling_repetition_penalty_decay=sampling_repetition_penalty_decay,
                 sampling_length_penalty=sampling_length_penalty,
             )
 
@@ -157,6 +159,7 @@ class AR_NAR(Base):
                 sampling_top_p=sampling_top_p,
                 sampling_top_k=sampling_top_k,
                 sampling_repetition_penalty=sampling_repetition_penalty,
+                sampling_repetition_penalty_decay=sampling_repetition_penalty_decay,
                 sampling_length_penalty=sampling_length_penalty,
                 state=state
             )
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 9e6534c..fc8e646 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -50,19 +50,32 @@ def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
     return x, m
 
 # Simple filter to modify a token's probability if it shows up in the past
-# To-do: have its effect decay based on distance
-def reptition_penalize( logits, previous, factor=1.0 ):
+# `one_time` will only apply the penalty to a token once, at its most recent occurrence
+# `decay` is an exponent on a token's distance into the past; the penalty scales by `distance ** decay`
+def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=True ):
     if factor == 1.0:
         return logits
 
-    priors = set(previous.tolist())
+    unique = set()
+    priors = reversed(previous.tolist())
+    for distance, token in enumerate(priors):
+        # skip if we're only applying the penalty once per token
+        if one_time and token in unique:
+            continue
 
-    for token in priors:
-        logits[:, token] /= factor
+        distance += 1
+        logits[:, token] /= factor * (distance ** decay)
+
+        # track the token so it is not penalized again
+        if one_time:
+            unique.add(token)
 
     return logits
 
 # Simple "filter" that modifies the logit for the stop token, based on the sequence length
+# `length` is the current length of the sequence
+# `factor` is the power the length is raised to, so values > 0 will yield longer sequences, values < 0 will yield shorter sequences
+# `token` is the index of the stop token
 def length_penalize( logits, length, factor=0.0, token=-1 ):
     if factor == 0.0:
         return logits
@@ -325,6 +338,7 @@ class Base(nn.Module):
         sampling_top_k: int = -100,
         sampling_top_p: float = 1.0,
         sampling_repetition_penalty: float = 1.0,
+        sampling_repetition_penalty_decay: float = 0.0,
         sampling_length_penalty: float = 0.0,
 
         state: dict | None = None,
@@ -423,7 +437,7 @@ class Base(nn.Module):
             logits = [ logit[-1:] for logit in logits ]
 
             # perform repetition penalizing
-            logits = [ reptition_penalize(logit, previous=resps[:, 0], factor=sampling_repetition_penalty) for logit, resps in zip( logits, resps_list ) ]
+            logits = [ reptition_penalize(logit, previous=resps[:, 0], factor=sampling_repetition_penalty, decay=sampling_repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
 
             # (AR) perform length penalizing
             if quant_levels is None and self.causal:
diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index fde734d..37d06eb 100755
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -60,6 +60,7 @@ class NAR(Base):
         sampling_top_k: int = -100,
         sampling_top_p: float = 1.0,
         sampling_repetition_penalty: float = 1.0,
+        sampling_repetition_penalty_decay: float = 0.0,
         sampling_length_penalty: float = 0.0,
     ):
         """
@@ -119,6 +120,7 @@ class NAR(Base):
                 sampling_top_p=sampling_top_p,
                 sampling_top_k=sampling_top_k,
                 sampling_repetition_penalty=sampling_repetition_penalty,
+                sampling_repetition_penalty_decay=sampling_repetition_penalty_decay,
                 sampling_length_penalty=sampling_length_penalty,
             )
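
A minimal standalone sketch of the decayed penalty for reference, mirroring the logic of `reptition_penalize` in `vall_e/models/base.py` above; the toy vocabulary size, history, and parameter values are illustrative only:

    import torch

    # Each previously-seen token's logit is divided by `factor * (distance ** decay)`,
    # where distance 1 is the most recent token. decay == 0.0 reduces to the old
    # flat penalty; decay < 0 fades the penalty for tokens further in the past,
    # while decay > 0 strengthens it (for positive logits).
    def repetition_penalize(logits, previous, factor=1.0, decay=0.0, one_time=True):
        if factor == 1.0:
            return logits

        seen = set()
        for distance, token in enumerate(reversed(previous.tolist()), start=1):
            # with one_time, only a token's most recent occurrence is penalized
            if one_time and token in seen:
                continue
            logits[:, token] /= factor * (distance ** decay)
            if one_time:
                seen.add(token)
        return logits

    # Toy demo: 5-token vocabulary, all logits at 1.0, history [3, 1, 3].
    logits = torch.ones(1, 5)
    history = torch.tensor([3, 1, 3])
    out = repetition_penalize(logits, history, factor=1.5, decay=0.5)
    # token 3 (distance 1): divided by 1.5 * 1**0.5  = 1.50 -> ~0.667
    # token 1 (distance 2): divided by 1.5 * 2**0.5 ~= 2.12 -> ~0.471
    # token 3 at distance 3 is skipped, since one_time already covered it
    print(out)  # tensor([[1.0000, 0.4714, 1.0000, 0.6667, 1.0000]])

Note that with this formulation a positive decay penalizes older tokens more strongly, not less; to make the penalty fade with distance, pass a negative decay.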
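The `length_penalize` hunk only adds comments and shows the early return, not the body. Purely as an assumption consistent with those comments (values > 0 yield longer sequences), the penalization step might look like:

    import torch

    # Hypothetical body, not the file's actual code: divide the stop token's
    # logit by length ** factor. For a positive logit, factor > 0 suppresses
    # stopping (longer sequences) and factor < 0 boosts the relative chance of
    # stopping (shorter sequences); factor == 0.0 is a no-op.
    def length_penalize(logits, length, factor=0.0, token=-1):
        if factor == 0.0:
            return logits
        logits[:, token] /= (length ** factor)
        return logits

    # At length 150 with factor 0.25, the stop logit is divided by ~3.5.
    print(length_penalize(torch.ones(1, 5), length=150, factor=0.25))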
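Finally, a sketch of exercising the new parameter end-to-end through the Python API touched above; the config path and audio paths are placeholders, and `references` is assumed to take a list of reference clips (its parsing lives outside this diff):

    from vall_e.inference import TTS

    tts = TTS(config="./config.yaml", ar_ckpt=None, nar_ckpt=None, device="cuda")

    # decay=0.5 scales the penalty divisor by distance ** 0.5, so the flat
    # repetition_penalty factor of 1.25 lands hardest on older tokens.
    tts.inference(
        text="Hello world.",
        references=["./reference.wav"],
        out_path="./output.wav",
        repetition_penalty=1.25,
        repetition_penalty_decay=0.5,
    )

    # Roughly equivalent CLI after this change (the positional arguments are
    # defined above the __main__.py hunk shown):
    #   python -m vall_e "Hello world." ./reference.wav --out-path ./output.wav \
    #     --repetition-penalty 1.25 --repetition-penalty-decay 0.5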