added a length-based decay factor for repetition penalty

commit 10c34c5b98 (parent b922f35b6b)
@@ -132,6 +132,7 @@ And some experimental sampling flags you can use too (your mileage will ***defin
 * `--top-p`: limits the sampling pool to the smallest set of top tokens whose probabilities sum to `P`% of the distribution (a sketch of the two filters follows this hunk).
 * `--top-k`: limits the sampling pool to the top `K` values in the probability distribution.
 * `--repetition-penalty`: modifies the probability of tokens that have already appeared. In the context of audio generation, this is a very iffy parameter to use.
+* `--repetition-penalty-decay`: scales the above penalty according to how far back in the past sequence the repeated token appeared.
 * `--length-penalty`: (AR only) modifies the probability of the stop token based on the current sequence length. This is ***very*** finicky.

 ## To-Do
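Since `--top-p` and `--top-k` are the two flags most easily confused, here is a small, generic sketch of what the two filters do to a probability distribution. This is illustration only: the probabilities are made up and this is not the repository's sampling code.

```python
# Generic top-k / top-p (nucleus) filtering sketch -- illustration only,
# not this repository's sampling code; the probabilities are made up.
import torch

probs = torch.tensor([0.40, 0.25, 0.15, 0.10, 0.06, 0.04])

# --top-k: keep only the K most probable tokens.
k = 3
topk = torch.zeros_like(probs)
idx = probs.topk(k).indices
topk[idx] = probs[idx]                       # -> [0.40, 0.25, 0.15, 0, 0, 0]

# --top-p: keep the smallest set of top tokens whose probabilities sum to >= P.
p = 0.80
sorted_probs, order = probs.sort(descending=True)
mass_before = sorted_probs.cumsum(0) - sorted_probs   # probability mass ranked above each token
keep = mass_before < p
topp = torch.zeros_like(probs)
topp[order[keep]] = probs[order[keep]]       # -> [0.40, 0.25, 0.15, 0, 0, 0] for P = 0.80
```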
@@ -14,20 +14,23 @@ def main():
 	parser.add_argument("--yaml", type=Path, default=None)
 	parser.add_argument("--ar-ckpt", type=Path, default=None)
 	parser.add_argument("--nar-ckpt", type=Path, default=None)

 	parser.add_argument("--max-ar-steps", type=int, default=6 * 75)

 	parser.add_argument("--ar-temp", type=float, default=1.0)
 	parser.add_argument("--nar-temp", type=float, default=1.0)

 	parser.add_argument("--top-p", type=float, default=1.0)
 	parser.add_argument("--top-k", type=int, default=0)
 	parser.add_argument("--repetition-penalty", type=float, default=1.0)
+	parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
 	parser.add_argument("--length-penalty", type=float, default=0.0)

 	parser.add_argument("--device", default="cuda")
 	args = parser.parse_args()

 	tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device )
-	tts.inference( text=args.text, references=args.references, out_path=args.out_path, max_ar_steps=args.max_ar_steps, ar_temp=args.ar_temp, nar_temp=args.nar_temp, top_p=args.top_p, top_k=args.top_k, repetition_penalty=args.repetition_penalty, length_penalty=args.length_penalty )
+	tts.inference( text=args.text, references=args.references, out_path=args.out_path, max_ar_steps=args.max_ar_steps, ar_temp=args.ar_temp, nar_temp=args.nar_temp, top_p=args.top_p, top_k=args.top_k, repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay, length_penalty=args.length_penalty )

 if __name__ == "__main__":
 	main()
@@ -125,7 +125,7 @@ class TTS():
 		return res

 	@torch.inference_mode()
-	def inference( self, text, references, max_ar_steps=6 * 75, ar_temp=0.95, nar_temp=0.5, top_p=1.0, top_k=0, repetition_penalty=1.0, length_penalty=0.0, out_path=None ):
+	def inference( self, text, references, max_ar_steps=6 * 75, ar_temp=0.95, nar_temp=0.5, top_p=1.0, top_k=0, repetition_penalty=1.0, repetition_penalty_decay=0.0, length_penalty=0.0, out_path=None ):
 		if out_path is None:
 			out_path = f"./data/{cfg.start_time}.wav"
@@ -136,9 +136,9 @@ class TTS():
 		phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)

 		with torch.autocast(self.device, dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
-			resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_length_penalty=length_penalty)
+			resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty)
 			resps_list = [r.unsqueeze(-1) for r in resps_list]
-			resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, sampling_temperature=nar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_length_penalty=length_penalty)
+			resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, sampling_temperature=nar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty)

 		wav, sr = qnt.decode_to_file(resps_list[0], out_path)
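For completeness, this is roughly how the updated `inference()` keyword threads through when calling the class directly rather than through the CLI hunk above. The import path, file paths, and penalty values are hypothetical placeholders, not values from this commit; only the keyword names and defaults mirror the signatures shown in these hunks.

```python
# Sketch of calling the updated API directly; the import path and file paths are
# hypothetical placeholders -- only the keyword arguments mirror this diff.
from inference import TTS  # hypothetical module name for the TTS class above

tts = TTS( config="config.yaml", ar_ckpt=None, nar_ckpt=None, device="cuda" )
tts.inference(
	text="Hello world.",
	references=["reference.wav"],        # illustrative reference clip path
	out_path="out.wav",
	max_ar_steps=6 * 75,
	ar_temp=0.95,
	nar_temp=0.5,
	top_p=1.0,
	top_k=0,
	repetition_penalty=1.1,              # illustrative; 1.0 leaves the logits untouched
	repetition_penalty_decay=0.5,        # the new knob added by this commit
	length_penalty=0.0,
)
```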
@@ -89,6 +89,7 @@ class AR(Base):
 		sampling_top_k: int = -100,
 		sampling_top_p: float = 1.0,
 		sampling_repetition_penalty: float = 1.0,
+		sampling_repetition_penalty_decay: float = 0.0,
 		sampling_length_penalty: float = 0.0,
 	):
 		if resps_list is not None:

@@ -128,6 +129,7 @@ class AR(Base):
 			sampling_top_p=sampling_top_p,
 			sampling_top_k=sampling_top_k,
 			sampling_repetition_penalty=sampling_repetition_penalty,
+			sampling_repetition_penalty_decay=sampling_repetition_penalty_decay,
 			sampling_length_penalty=sampling_length_penalty,
 			state=state
 		)
@@ -74,6 +74,7 @@ class AR_NAR(Base):
 		sampling_top_k: int = -100,
 		sampling_top_p: float = 1.0,
 		sampling_repetition_penalty: float = 1.0,
+		sampling_repetition_penalty_decay: float = 0.0,
 		sampling_length_penalty: float = 0.0,
 	):
 		device = text_list[0].device

@@ -127,6 +128,7 @@ class AR_NAR(Base):
 			sampling_top_p=sampling_top_p,
 			sampling_top_k=sampling_top_k,
 			sampling_repetition_penalty=sampling_repetition_penalty,
+			sampling_repetition_penalty_decay=sampling_repetition_penalty_decay,
 			sampling_length_penalty=sampling_length_penalty,
 		)

@@ -157,6 +159,7 @@ class AR_NAR(Base):
 			sampling_top_p=sampling_top_p,
 			sampling_top_k=sampling_top_k,
 			sampling_repetition_penalty=sampling_repetition_penalty,
+			sampling_repetition_penalty_decay=sampling_repetition_penalty_decay,
 			sampling_length_penalty=sampling_length_penalty,
 			state=state
 		)
@@ -50,19 +50,32 @@ def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
 	return x, m

 # Simple filter to modify a token's probability if it shows up in the past
-# To-do: have its effect decay based on distance
-def reptition_penalize( logits, previous, factor=1.0 ):
+# `one_time` will only apply the penalty once
+# `decay` is a factor that will exponentially apply to how far away it is
+def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=True ):
 	if factor == 1.0:
 		return logits

-	priors = set(previous.tolist())
-	for token in priors:
-		logits[:, token] /= factor
+	unique = set()
+	priors = reversed(previous.tolist())
+	for distance, token in enumerate(priors):
+		# skip if we're only applying the decay once
+		if one_time and token in unique:
+			continue
+
+		distance += 1
+		logits[:, token] /= factor * (distance ** decay)
+
+		# add to set if we care about it
+		if one_time:
+			unique.add(token)

 	return logits

 # Simple "filter" that modifies the logit for the stop token, based on the sequence length
 # `length` is the length of the sequence currently
 # `factor` is the power the length is raised to, so values > 0 will yield longer sequences, values < 0 will yield shorter sequences
 # `token` is the stop token.
 def length_penalize( logits, length, factor=0.0, token=-1 ):
 	if factor == 0.0:
 		return logits
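To see the decay in action, `reptition_penalize` from the hunk above can be exercised directly on a toy tensor (assuming the function is in scope; the values are made up). The `length_penalize` body is cut off by the hunk, so the commented line at the end is only a guessed shape, not code from this commit.

```python
# Toy check of the decayed repetition penalty defined in the hunk above.
# Assumes reptition_penalize() from that hunk is in scope.
import torch

logits = torch.ones(1, 8)               # batch of 1 over a made-up vocab of 8
previous = torch.tensor([3, 5, 3, 2])   # token 2 was generated most recently

out = reptition_penalize(logits.clone(), previous, factor=1.2, decay=0.5)
# Walking the history backwards, distance starts at 1 for the newest token:
#   token 2 -> /= 1.2 * 1**0.5
#   token 3 -> /= 1.2 * 2**0.5
#   token 5 -> /= 1.2 * 3**0.5
# The older, second occurrence of token 3 is skipped because one_time=True,
# and with decay > 0 tokens further back end up divided by a larger value.

# length_penalize's body is truncated above; one *guess* at its shape, by analogy
# with how reptition_penalize divides by its scale (assumption, not from this commit):
#   logits[:, token] /= (length ** factor)
```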
@@ -325,6 +338,7 @@ class Base(nn.Module):
 		sampling_top_k: int = -100,
 		sampling_top_p: float = 1.0,
 		sampling_repetition_penalty: float = 1.0,
+		sampling_repetition_penalty_decay: float = 0.0,
 		sampling_length_penalty: float = 0.0,

 		state: dict | None = None,
@@ -423,7 +437,7 @@ class Base(nn.Module):
 			logits = [ logit[-1:] for logit in logits ]

 			# perform repetition penalizing
-			logits = [ reptition_penalize(logit, previous=resps[:, 0], factor=sampling_repetition_penalty) for logit, resps in zip( logits, resps_list ) ]
+			logits = [ reptition_penalize(logit, previous=resps[:, 0], factor=sampling_repetition_penalty, decay=sampling_repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]

 			# (AR) perform length penalizing
 			if quant_levels is None and self.causal:
@@ -60,6 +60,7 @@ class NAR(Base):
 		sampling_top_k: int = -100,
 		sampling_top_p: float = 1.0,
 		sampling_repetition_penalty: float = 1.0,
+		sampling_repetition_penalty_decay: float = 0.0,
 		sampling_length_penalty: float = 0.0,
 	):
 		"""

@@ -119,6 +120,7 @@ class NAR(Base):
 			sampling_top_p=sampling_top_p,
 			sampling_top_k=sampling_top_k,
 			sampling_repetition_penalty=sampling_repetition_penalty,
+			sampling_repetition_penalty_decay=sampling_repetition_penalty_decay,
 			sampling_length_penalty=sampling_length_penalty,
 		)