diff --git a/vall_e/config.py b/vall_e/config.py
index e78ed23..da35daa 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -427,11 +427,8 @@ class Evaluation:
 	batch_size: int = 64 # number of samples per batch during eval / val
 	frequency: int = 250 # do eval / val every X iterations
 	size: int = 64 # number of samples to generate during eval / val
-
-	steps: int = 500
-	ar_temperature: float = 0.0 # AR temp for inferencing
-	nar_temperature: float = 0.0 # NAR temp for inferencing
-	nar_levels: int = 0 # maximum NAR levels to use for inferencing
+	ar_kwargs: dict = field(default_factory=lambda: {}) # inferencing kwargs
+	nar_kwargs: dict = field(default_factory=lambda: {}) # inferencing kwargs
 
 @dataclass()
 class DeepSpeed:
diff --git a/vall_e/demo.py b/vall_e/demo.py
index 1865cca..9d46e5d 100644
--- a/vall_e/demo.py
+++ b/vall_e/demo.py
@@ -68,7 +68,7 @@ def main():
 	parser.add_argument("--top-p", type=float, default=1.0)
 	parser.add_argument("--top-k", type=int, default=0)
 	parser.add_argument("--min-p", type=float, default=0.0)
-	parser.add_argument("--repetition-penalty", type=float, default=1.0)
+	parser.add_argument("--repetition-penalty", type=float, default=1.125)
 	parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
 	parser.add_argument("--length-penalty", type=float, default=0.0)
 	parser.add_argument("--beam-width", type=int, default=0)
@@ -122,7 +122,9 @@ def main():
 		comparison_kwargs["titles"] = ["LoRA", "No LoRA"]
 
 		comparison_kwargs["disabled"]["use_lora"] = True
+		comparison_kwargs["disabled"]["ar_temp"] = 0.0
 		comparison_kwargs["enabled"]["use_lora"] = False
+		comparison_kwargs["enabled"]["ar_temp"] = 0.95
 	elif args.comparison == "entropix-sampling":
 		comparison_kwargs["suffix"] = "entropix_sampling"
 		comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index b29056e..7e81818 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -375,6 +375,8 @@ class AR_NAR(Base):
 				text_list = text_list * sampling_beam_width
 				proms_list = proms_list * sampling_beam_width
 				sequence_list = sequence_list * sampling_beam_width
+				task_list = task_list * sampling_beam_width
+				start_slice = start_slice * sampling_beam_width
 				stopped = torch.zeros(batch_size, device=device).bool()
 
 			scores = [ scores[i] + score for i, score in enumerate(scores) ]
@@ -399,7 +401,8 @@
 		# pick the best scoring candidate
 		# desu this is always going to be candidate 0
 		if sampling_beam_width:
-			sequence_list = [ sequence_list[0] ]
+			sequence_list = sequence_list[:1]
+			task_list = task_list[:1]
 
 		# remove stop token
 		sequence_list = [self._prune(r, audio_stop_token if task_list[i] not in text_task else text_stop_token) for i, r in enumerate(sequence_list)]
diff --git a/vall_e/samplers.py b/vall_e/samplers.py
index 7c4794a..fc55291 100644
--- a/vall_e/samplers.py
+++ b/vall_e/samplers.py
@@ -11,13 +11,16 @@ from dataclasses import asdict, dataclass, field
 # Simple filter to modify a token's probability if it shows up in the past
 # `one_time` will only apply the penalty once
 # `decay` is a factor that will exponentially apply to how far away it is
-def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=False ):
+def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=False, limit=75 ):
 	if factor == 1.0 or previous is None:
 		return logits
 
 	unique = set()
 	priors = reversed(previous)
 	for distance, token in enumerate(priors):
+		# rep-pen range
+		if limit and distance >= limit:
+			continue
 		# skip if we're only applying the decay once
 		if one_time and token in unique:
 			continue
diff --git a/vall_e/train.py b/vall_e/train.py
index 9f0b969..1fc003e 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -133,7 +133,7 @@ def run_eval(engines, eval_name, dl, args=None):
 
 		engine = engines[name]
 
-		kwargs = dict(
+		base_kwargs = dict(
 			text_list=batch["text"],
 			proms_list=batch["proms"],
 			lang_list=batch["lang"],
@@ -141,19 +141,23 @@
 		)
 
 		if engine.hyper_config.experimental.hf:
-			resps_list = engine( **kwargs )
+			resps_list = engine( **base_kwargs )
 		elif "len" in engine.hyper_config.capabilities:
-			len_list = engine( **kwargs, max_steps=10 ) # don't need more than that
+			len_list = engine( **base_kwargs, max_steps=10 ) # don't need more than that
 			len_list = [ min( l, cfg.evaluation.steps ) for l in len_list ]
-			resps_list = engine( **kwargs, len_list=len_list, max_levels=cfg.evaluation.nar_levels )
+
+			kwargs = base_kwargs | cfg.evaluation.nar_kwargs
+			resps_list = engine( **kwargs, len_list=len_list )
 		else:
 			if "ar" in engine.hyper_config.capabilities:
-				resps_list = engine( **kwargs, max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
+				kwargs = base_kwargs | cfg.evaluation.ar_kwargs
+				resps_list = engine( **kwargs )
 			else:
 				resps_list = [ resp[:, 0] for resp in batch["resps"] ]
 
 			if "nar" in engine.hyper_config.capabilities:
-				resps_list = engine( **kwargs, resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature, max_levels=cfg.evaluation.nar_levels )
+				kwargs = base_kwargs | cfg.evaluation.nar_kwargs
+				resps_list = engine( **kwargs, resps_list=resps_list )
 
 		process( name, batch, resps_list )
diff --git a/vall_e/webui.py b/vall_e/webui.py
index 6797911..2bde04b 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -203,6 +203,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 		top_p=args.top_p,
 		top_k=args.top_k,
 		min_p=args.min_p,
+		beam_width=args.beam_width,
 		repetition_penalty=args.repetition_penalty,
 		repetition_penalty_decay=args.repetition_penalty_decay,
 		length_penalty=args.length_penalty,
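
A note on the config change: the fixed ar_temperature / nar_temperature / nar_levels fields give way to free-form ar_kwargs / nar_kwargs dicts, which run_eval() now merges over the base call arguments with the Python 3.9+ dict-union operator. A minimal sketch of that merge; the keyword names are borrowed from the removed hard-coded call and are only illustrative, since whatever ends up in cfg.evaluation.ar_kwargs is forwarded to the engine as-is:

# Sketch of the evaluation kwargs merge used in run_eval() (assumes Python 3.9+).
base_kwargs = dict(
	text_list=["hello world"],
	proms_list=["<prompt codes>"],
	lang_list=["en"],
)
# stand-in for cfg.evaluation.ar_kwargs populated from the config file
ar_kwargs = {"max_steps": 500, "sampling_temperature": 0.95}

# right-hand side wins on key collisions, so config-supplied values override
kwargs = base_kwargs | ar_kwargs
# then: resps_list = engine( **kwargs )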
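
And a note on the sampler change: the new limit parameter turns reptition_penalize into a windowed penalty, where only tokens seen within the last `limit` positions of the history are touched. A standalone sketch of just that windowing behaviour, assuming a torch logits tensor; the division used here is illustrative rather than the repo's exact penalty math:

import torch

def penalize_recent(logits, previous, factor=1.125, decay=1.0, limit=75):
	# walk the history newest-to-oldest; `distance` grows as we look further back
	for distance, token in enumerate(reversed(previous)):
		# rep-pen range: occurrences older than `limit` tokens are left untouched
		if limit and distance >= limit:
			continue
		# illustrative penalty: dampen the logit; decay < 1 fades it with distance
		logits[:, token] /= 1.0 + (factor - 1.0) * (decay ** distance)
	return logits

# toy usage: 8-token vocab, 200-token history; only the most recent 75 matter
logits = torch.ones(1, 8)
history = [i % 8 for i in range(200)]
penalize_recent(logits, history, factor=1.125, decay=0.9, limit=75)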