diff --git a/README.md b/README.md index e029b91..00f0ff2 100755 --- a/README.md +++ b/README.md @@ -139,8 +139,8 @@ And some experimental sampling flags you can use too (your mileage will ***defin * `--top-k`: limits the sampling pool to the top `K` values in the probability distribution. * `--repetition-penalty`: modifies the probability of tokens if they have appeared before. In the context of audio generation, this is a very iffy parameter to use. * `--repetition-penalty-decay`: modifies the above factor applied to scale based on how far away it is in the past sequence. -* `--length-penalty`: (AR only) modifies the probability of the stop token based on the current sequence length. This is ***very*** finnicky. - +* `--length-penalty`: (AR only) modifies the probability of the stop token based on the current sequence length. This is ***very*** finnicky due to the AR already being well correlated with the length. +* `--beam-width`: (AR only) specifies the number of branches to search through for beam sampling. This is a very naive implementation that's effectively just greedy sampling across `B` spaces. ## To-Do * reduce load time for creating / preparing dataloaders (hint: remove use of `Path.glob` and `Path.rglob`). diff --git a/vall_e/__main__.py b/vall_e/__main__.py index defe698..7e3ff44 100755 --- a/vall_e/__main__.py +++ b/vall_e/__main__.py @@ -27,6 +27,7 @@ def main(): parser.add_argument("--repetition-penalty", type=float, default=1.0) parser.add_argument("--repetition-penalty-decay", type=float, default=0.0) parser.add_argument("--length-penalty", type=float, default=0.0) + parser.add_argument("--beam-width", type=int, default=0) parser.add_argument("--device", type=str, default=None) parser.add_argument("--amp", action="store_true") @@ -34,7 +35,7 @@ def main(): args = parser.parse_args() tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device, dtype=args.dtype, amp=args.amp ) - tts.inference( text=args.text, references=args.references, out_path=args.out_path, input_prompt_length=args.input_prompt_length, max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels, ar_temp=args.ar_temp, nar_temp=args.nar_temp, top_p=args.top_p, top_k=args.top_k, repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay, length_penalty=args.length_penalty ) + tts.inference( text=args.text, references=args.references, out_path=args.out_path, input_prompt_length=args.input_prompt_length, max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels, ar_temp=args.ar_temp, nar_temp=args.nar_temp, top_p=args.top_p, top_k=args.top_k, repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay, length_penalty=args.length_penalty, beam_width=args.beam_width ) if __name__ == "__main__": main() diff --git a/vall_e/inference.py b/vall_e/inference.py index aed22d0..2d4b2f0 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -139,7 +139,23 @@ class TTS(): return res @torch.inference_mode() - def inference( self, text, references, max_ar_steps=6 * 75, max_nar_levels=7, input_prompt_length=0.0, ar_temp=0.95, nar_temp=0.5, top_p=1.0, top_k=0, repetition_penalty=1.0, repetition_penalty_decay=0.0, length_penalty=0.0, out_path=None ): + def inference( + self, + text, + references, + max_ar_steps=6 * 75, + max_nar_levels=7, + input_prompt_length=0.0, + ar_temp=0.95, + nar_temp=0.5, + top_p=1.0, + top_k=0, + repetition_penalty=1.0, + repetition_penalty_decay=0.0, + length_penalty=0.0, + beam_width=0, + out_path=None + ): if out_path is None: out_path = f"./data/{cfg.start_time}.wav" @@ -150,9 +166,9 @@ class TTS(): phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16) with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp): - resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty) + resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty, sampling_beam_width=beam_width) resps_list = [r.unsqueeze(-1) for r in resps_list] - resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, max_levels=max_nar_levels, sampling_temperature=nar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty) + resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, max_levels=max_nar_levels, sampling_temperature=nar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty, sampling_beam_width=beam_width) wav, sr = qnt.decode_to_file(resps_list[0], out_path, device=self.device) diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py index 455197b..2b65e33 100755 --- a/vall_e/models/ar.py +++ b/vall_e/models/ar.py @@ -91,12 +91,14 @@ class AR(Base): proms_list: list[Tensor], resps_list: list[Tensor] | None = None, max_steps: int = 1000, + sampling_temperature: float = 1.0, sampling_top_k: int = -100, sampling_top_p: float = 1.0, sampling_repetition_penalty: float = 1.0, sampling_repetition_penalty_decay: float = 0.0, sampling_length_penalty: float = 0.0, + sampling_beam_width: int = 0, ): if resps_list is not None: if self.interleave: @@ -126,24 +128,39 @@ class AR(Base): for n in trange(max_steps // max(1, self.recurrent_chunk_size)): # get next in sequence - r = super().forward( + logits = super().forward( text_list=text_list, proms_list=proms_list, resps_list=self._unsqueeze_list(resps_list), quant_levels=None, - sampling_temperature=sampling_temperature, - sampling_top_p=sampling_top_p, - sampling_top_k=sampling_top_k, - sampling_repetition_penalty=sampling_repetition_penalty, - sampling_repetition_penalty_decay=sampling_repetition_penalty_decay, - sampling_length_penalty=sampling_length_penalty, state=state ) + r = super().sample( + logits=logits, + resps_list=resps_list, + + temperature=sampling_temperature, + top_p=sampling_top_p, + top_k=sampling_top_k, + repetition_penalty=sampling_repetition_penalty, + repetition_penalty_decay=sampling_repetition_penalty_decay, + length_penalty=sampling_length_penalty, + beam_width=sampling_beam_width, + ) + + # first step, expand batch + # we do it here because the sampler will already expand our logits list + if sampling_beam_width > 0 and batch_size == 1: + text_list = text_list * sampling_beam_width + proms_list = proms_list * sampling_beam_width + resps_list = resps_list * sampling_beam_width + # append tokens for i, ri in enumerate(r): if self.stop_token in ri: stopped[i] = True + resps_list[i] = torch.cat([resps_list[i], ri]) # stop token found diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index f8d2521..730b3d8 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -83,6 +83,7 @@ class AR_NAR(Base): sampling_repetition_penalty: float = 1.0, sampling_repetition_penalty_decay: float = 0.0, sampling_length_penalty: float = 0.0, + sampling_beam_width: int = 0, ): device = text_list[0].device batch_size = len(text_list) @@ -119,28 +120,33 @@ class AR_NAR(Base): quant_levels = torch.full((len(text_list),), level, device=device) - resps_list = super().forward( - text_list, - proms_list, - prev_list, + logits = super().forward( + text_list=text_list, + proms_list=proms_list, + resps_list=prev_list, quant_levels=quant_levels, - sampling_temperature=sampling_temperature, - sampling_top_p=sampling_top_p, - sampling_top_k=sampling_top_k, - sampling_repetition_penalty=sampling_repetition_penalty, - sampling_repetition_penalty_decay=sampling_repetition_penalty_decay, - sampling_length_penalty=sampling_length_penalty, ) - prev_list = [ - torch.cat([rs, r.unsqueeze(-1)], dim=-1) - for rs, r in zip(prev_list, resps_list) - ] + resps_list = super().sample( + logits=logits, + resps_list=prev_list, + quant_levels=quant_levels, + + temperature=sampling_temperature, + top_p=sampling_top_p, + top_k=sampling_top_k, + repetition_penalty=sampling_repetition_penalty, + repetition_penalty_decay=sampling_repetition_penalty_decay, + #length_penalty=sampling_length_penalty, + #beam_width=sampling_beam_width, + ) + + prev_list = [ torch.cat([rs, r.unsqueeze(-1)], dim=-1) for rs, r in zip(prev_list, resps_list) ] return prev_list # is AR - resps_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in text_list ] + sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in text_list ] stopped = torch.zeros(batch_size, device=device).bool() state = {} if cfg.inference.recurrent_forward else None @@ -151,31 +157,53 @@ class AR_NAR(Base): for n in trange(max_steps // max(1, self.recurrent_chunk_size)): # get next in sequence - r = super().forward( - text_list, - proms_list, - self._unsqueeze_list(resps_list), - sampling_temperature=sampling_temperature, - sampling_top_p=sampling_top_p, - sampling_top_k=sampling_top_k, - sampling_repetition_penalty=sampling_repetition_penalty, - sampling_repetition_penalty_decay=sampling_repetition_penalty_decay, - sampling_length_penalty=sampling_length_penalty, + resps_list = self._unsqueeze_list(sequence_list) + logits = super().forward( + text_list=text_list, + proms_list=proms_list, + resps_list=resps_list, + state=state ) + r = super().sample( + logits=logits, + resps_list=resps_list, + + temperature=sampling_temperature, + top_p=sampling_top_p, + top_k=sampling_top_k, + repetition_penalty=sampling_repetition_penalty, + repetition_penalty_decay=sampling_repetition_penalty_decay, + length_penalty=sampling_length_penalty, + beam_width=sampling_beam_width, + ) + + # first step, expand batch + # we do it here because the sampler will already expand our logits list + if sampling_beam_width > 0 and batch_size == 1: + batch_size *= sampling_beam_width + text_list = text_list * sampling_beam_width + proms_list = proms_list * sampling_beam_width + sequence_list = sequence_list * sampling_beam_width + stopped = torch.zeros(batch_size, device=device).bool() + # append tokens for i, ri in enumerate(r): if self.stop_token in ri: stopped[i] = True - resps_list[i] = torch.cat([resps_list[i], ri]) + sequence_list[i] = torch.cat([sequence_list[i], ri]) # stop token found stopped |= r == self.stop_token if stopped.all().item(): break - return [self._prune(r) for r in resps_list] + # pick the first candidate + if sampling_beam_width: + sequence_list = sequence_list[:1] + + return [self._prune(r) for r in sequence_list] def example_usage(): @@ -200,11 +228,9 @@ def example_usage(): qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.models.prom_levels].to(device) text_list = [ - #torch.tensor([1, 2, 3], device=device), tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device), ] proms_list = [ - #x8(torch.tensor([1, 2, 3], device=device)), qnt[:75*3, :].to(device), ] resps_list = [ @@ -232,7 +258,7 @@ def example_usage(): model = AR_NAR(**kwargs).to(device) #steps = 500 #optimizer = ml.Prodigy(model.parameters(), lr=1.0) - steps = 500 + steps = 1000 optimizer = ml.AdamW(model.parameters(), lr=1.0e-4) engine = Engine(model=model, optimizer=optimizer) @@ -241,7 +267,7 @@ def example_usage(): @torch.inference_mode() def sample( name, steps=600 ): engine.eval() - resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 ) + resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95, sampling_beam_width=16 ) for i, o in enumerate(resps_list): _ = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device) diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 2e9dad4..5f1022f 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -2,6 +2,7 @@ import math import torch import torch.nn.functional as F import traceback +import numpy as np from typing import Literal, overload from functools import partial @@ -53,7 +54,7 @@ def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"): # `one_time` will only apply the penalty once # `decay` is a factor that will exponentially apply to how far away it is def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=True ): - if factor == 1.0: + if factor == 1.0 or previous is None: return logits unique = set() @@ -115,6 +116,7 @@ def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf" # scatter sorted tensors to original indexing indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) logits[indices_to_remove] = filter_value + return logits @@ -341,13 +343,6 @@ class Base(nn.Module): targ_list: list[Tensor] | None = None, quant_levels: Tensor | None = None, - sampling_temperature: float = 1.0, - sampling_top_k: int = -100, - sampling_top_p: float = 1.0, - sampling_repetition_penalty: float = 1.0, - sampling_repetition_penalty_decay: float = 0.0, - sampling_length_penalty: float = 0.0, - state: dict | None = None, ): x_list = self._samplewise_merge_tensors( @@ -428,8 +423,25 @@ class Base(nn.Module): precision = self.precision_metric( inputs, target ), ) - return logits + return logits + + def sample( + self, + logits: list[Tensor], + resps_list: list[Tensor], + quant_levels: Tensor | None = None, + + temperature: float = 1.0, + top_k: int = -100, + top_p: float = 1.0, + + repetition_penalty: float = 1.0, + repetition_penalty_decay: float = 0.0, + length_penalty: float = 0.0, + + beam_width: int = 0, + ): # (NAR) return the entire generated response if quant_levels is not None: logits = [ logit[-l:] for logit, l in zip(logits, map(len, resps_list)) ] @@ -441,19 +453,37 @@ class Base(nn.Module): logits = [ logit[-1:] for logit in logits ] # perform repetition penalizing - logits = [ reptition_penalize(logit, previous=resps[:, 0], factor=sampling_repetition_penalty, decay=sampling_repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ] + logits = [ reptition_penalize(logit, previous=resps[:, -1], factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ] # (AR) perform length penalizing if quant_levels is None and self.causal: - logits = [ length_penalize(logit, length=l + 1, factor=sampling_length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, resps_list) ) ] + logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, resps_list) ) ] # scale our logits by the temp - logits = [ logit / sampling_temperature for logit in logits ] + logits = [ logit / temperature for logit in logits ] # perform top_k/top_p filtering of our logits - if sampling_top_k > 0: - logits = [ top_k_top_p_filtering(logit, top_k=sampling_top_k, top_p=sampling_top_p) for logit in logits ] - + if top_k > 0 or top_p < 1.0: + logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ] + + # do beam search (naive implementation) + # picks the top-k across all batches, and re-batches those resultant tokens + # this doesn't do any other mumbo with previous logits + # to-do: not naively implement beam searching + if beam_width > 1: + # ( batch, tokens ) => ( batch x tokens ) + flattened = torch.cat( logits ) + candidates = list(torch.topk(flattened.flatten(), beam_width).indices.tolist()) # perform top-k across all logits + for i, index in enumerate(candidates): + t = [] + N = np.prod(flattened.size()) + for n in flattened.size(): + N //= n + t.append(index // N) + index %= N + candidates[i] = tuple(t) + return [ torch.tensor(token, device=logits[batch].device, dtype=torch.int16).unsqueeze(dim=-1) for batch, token in candidates ] #, [ logits[batch] for batch, token in candidates ] + # and sample # the original implementation used this instead of argmax; it's probably placebo but it performs better than argmax return [ Categorical(logits=logit).sample() for logit in logits ] diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py index cd394e2..b0e60c6 100755 --- a/vall_e/models/nar.py +++ b/vall_e/models/nar.py @@ -98,11 +98,11 @@ class NAR(Base): quant_levels = quant_levels.to(device=device) - _ = super().forward( - text_list, - proms_list, - prev_list, - targ_list, + logits = super().forward( + text_list=text_list, + proms_list=proms_list, + resps_list=prev_list, + targ_list=targ_list, quant_levels=quant_levels, ) @@ -120,23 +120,28 @@ class NAR(Base): quant_levels = torch.full((len(text_list),), level, device=device) - resps_list = super().forward( - text_list, - proms_list, - prev_list, + logits = super().forward( + text_list=text_list, + proms_list=proms_list, + resps_list=prev_list, quant_levels=quant_levels, - sampling_temperature=sampling_temperature, - sampling_top_p=sampling_top_p, - sampling_top_k=sampling_top_k, - sampling_repetition_penalty=sampling_repetition_penalty, - sampling_repetition_penalty_decay=sampling_repetition_penalty_decay, - sampling_length_penalty=sampling_length_penalty, ) - prev_list = [ - torch.cat([rs, r.unsqueeze(-1)], dim=-1) - for rs, r in zip(prev_list, resps_list) - ] + resps_list = super().sample( + logits=logits, + resps_list=resps_list, + quant_levels=quant_levels, + + temperature=sampling_temperature, + top_p=sampling_top_p, + top_k=sampling_top_k, + repetition_penalty=sampling_repetition_penalty, + repetition_penalty_decay=sampling_repetition_penalty_decay, + #length_penalty=sampling_length_penalty, + #beam_width=sampling_beam_width, + ) + + prev_list = [ torch.cat([rs, r.unsqueeze(-1)], dim=-1) for rs, r in zip(prev_list, resps_list) ] return prev_list diff --git a/vall_e/train.py b/vall_e/train.py index c874874..f9f503d 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -46,10 +46,6 @@ def train_feeder(engine, batch): @torch.inference_mode() def run_eval(engines, disabled_engines, eval_name, dl): - engines_stats = { - 'eval': eval_name - } - AR = None NAR = None AR_NAR = None @@ -156,7 +152,7 @@ def run_eval(engines, disabled_engines, eval_name, dl): stats = {k: sum(v) / len(v) for k, v in stats.items()} - engines_stats.update(flatten_dict({ name: stats })) + engines_stats.update({ f'{name}.{eval_name}': stats }) iteration = engines.global_step engines_stats['it'] = iteration diff --git a/vall_e/webui.py b/vall_e/webui.py index 9872485..9824a02 100644 --- a/vall_e/webui.py +++ b/vall_e/webui.py @@ -65,6 +65,7 @@ def init_tts(restart=False): @gradio_wrapper(inputs=layout["inference"]["inputs"].keys()) def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): parser = argparse.ArgumentParser(allow_abbrev=False) + # I'm very sure I can procedurally generate this list parser.add_argument("--text", type=str, default=kwargs["text"]) parser.add_argument("--references", type=str, default=kwargs["reference"]) parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"]) @@ -77,6 +78,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"]) parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"]) parser.add_argument("--length-penalty", type=float, default=kwargs["length-penalty"]) + parser.add_argument("--beam-width", type=int, default=kwargs["beam-width"]) args, unknown = parser.parse_known_args() tmp = tempfile.NamedTemporaryFile(suffix='.wav') @@ -181,7 +183,7 @@ with ui: with gr.Column(scale=7): with gr.Row(): layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=6, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.") - layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=3, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.") + layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.") layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.") with gr.Row(): layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR.") @@ -190,6 +192,7 @@ with ui: with gr.Row(): layout["inference"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info="Limits the samples that are outside the top P%% of probabilities.") layout["inference"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.") + layout["inference"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.") with gr.Row(): layout["inference"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.") layout["inference"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")