From e19aa643a63b3c2f03b48e367da623048f507421 Mon Sep 17 00:00:00 2001 From: mrq Date: Sun, 21 Jul 2024 19:12:03 -0500 Subject: [PATCH] cleaned up demo page creation, added option to pass in RVQ level sampling distribution for training --- data/demo/index.template.html | 53 +++++++------- vall_e/config.py | 2 +- vall_e/data.py | 10 +-- vall_e/demo.py | 125 ++++++++++++++++------------------ vall_e/models/ar_nar.py | 12 +++- vall_e/models/nar.py | 26 ++++--- 6 files changed, 117 insertions(+), 111 deletions(-) diff --git a/data/demo/index.template.html b/data/demo/index.template.html index ed12bf9..0c0224e 100644 --- a/data/demo/index.template.html +++ b/data/demo/index.template.html @@ -1,30 +1,27 @@ - - - - -

VALL-E Demo

-

Below are some samples from my VALL-E implementation: https://git.ecker.tech/mrq/vall-e/. I do not consider these to be state of the art. Below are samples from LibriSpeech, comparing against the samples the original VALL-E demo sampled.

- - - - - - - - - - ${ENTRIES} -
TextPromptGround TruthOur VALL-EOriginal VALL-EYourTTS
-

Below are some extra samples.

- - - - - - - - ${SAMPLES} -
TextPromptGround TruthOur VALL-E
- + + + + +

VALL-E Demo

+

${PREAMBLE}

+ + + + + + + + + ${LIBRISPEECH_SAMPLES} +
TextPromptGround TruthOur VALL-EOriginal VALL-EYourTTS
+ + + + + + + ${DATASET_SAMPLES} +
TextPromptGround TruthOur VALL-E
+ \ No newline at end of file diff --git a/vall_e/config.py b/vall_e/config.py index 721b05d..ced46cb 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -209,7 +209,7 @@ class ModelExperimentalSettings: audio_embedding_sums: bool = False # whether each pass uses the previous RVQ codes or only the current level audio_embedding_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings kv_heads: int = 0 # MHA or GQA (for supported backends) - p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely + p_rvq_levels: str | list = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range for LoRAs, isn't necesary unified_position_ids: bool = True # False will generate position IDs partitioned for each section tie_classifier_to_embedding: bool = False # Ties the classifier output to their respective embeddings, this does not seem to do anything good in testing diff --git a/vall_e/data.py b/vall_e/data.py index b57ba6d..ec13ea0 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -486,7 +486,7 @@ class Dataset(_Dataset): duration = self.duration_map[path] self.duration += duration - # only calc duration if we're tot going to order by duration + # only calc duration if we're going to order by duration if self.sampler_order != "duration": continue @@ -845,11 +845,6 @@ class Dataset(_Dataset): # might be better to decode => concat waveforms with silence in between => reencode # as you technically can't just append encodec sequences together like this without issues resps = concat_audio( resps, qnt, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device ) - - """ - resps = resps[:, :cfg.model.resp_levels] - proms = proms[:, :cfg.model.resp_levels] - """ task = random.choice(self.tasks) @@ -1031,7 +1026,8 @@ class Dataset(_Dataset): text=text, proms=proms, resps=resps, - text_string=text_string, + + metadata=metadata, ) def head_(self, n): diff --git a/vall_e/demo.py b/vall_e/demo.py index 08d2dca..1e59ec5 100644 --- a/vall_e/demo.py +++ b/vall_e/demo.py @@ -26,7 +26,7 @@ from .config import cfg from .data import create_train_dataloader, create_val_dataloader from .emb.qnt import decode_to_file -from tqdm import tqdm +from tqdm import tqdm, trange def encode(path): return "data:audio/wav;base64," + base64.b64encode(open(path, "rb").read()).decode('utf-8') @@ -42,6 +42,7 @@ def main(): parser.add_argument("--sample-from-dataset", action="store_true") parser.add_argument("--dataset-samples", type=int, default=0) parser.add_argument("--audio-path-root", type=str, default=None) + parser.add_argument("--preamble", type=str, default=None) parser.add_argument("--language", type=str, default="en") @@ -77,66 +78,28 @@ def main(): if not args.demo_dir: args.demo_dir = Path("./data/demo/") - entries = [] - - # pull from provided samples - sample_dir = args.demo_dir / "librispeech" - if sample_dir.exists(): - speakers = [ dir for dir in sample_dir.iterdir() if dir.is_dir() ] - sources = ["ms_valle", "yourtts"] - - # generate demo output - for dir in tqdm(speakers, desc=f"Generating demo for speaker"): - text = open(dir / "prompt.txt").read() - prompt = dir / "prompt.wav" - out_path = dir / "out" / "ours.wav" - - entries.append(( - text, - [ prompt, dir / "reference.wav", out_path ] + [ dir / "out" / f"{source}.wav" for source in sources ] - )) - - if args.skip_existing and out_path.exists(): - continue - - tts.inference( - text=text, - references=[prompt], - language=args.language, - out_path=out_path, - input_prompt_length=args.input_prompt_length, - max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels, - ar_temp=args.ar_temp, nar_temp=args.nar_temp, - min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp, - top_p=args.top_p, top_k=args.top_k, - repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay, - length_penalty=args.length_penalty, - beam_width=args.beam_width, - mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta, - seed=args.seed, - tqdm=False, - ) - - entries = [ - f'{text}'+ - "".join( [ - f'' - for audio in audios - ] )+ - '' - for text, audios in entries - ] + if not args.preamble: + args.preamble = "
".join([ + 'Below are some samples from my VALL-E implementation: https://git.ecker.tech/mrq/vall-e/.', + 'LibriSpeech, comparing against the samples the original VALL-E demo sampled.', + 'I do not consider these to be state of the art, as the model does not follow close to the prompt as I would like for general speakers.', + ]) # read html template html = open(args.demo_dir / "index.template.html", "r", encoding="utf-8").read() - # create html table, in one messy line - # replace in our template - html = html.replace(r"${ENTRIES}", "\n".join(entries) ) - samples = [] + # replace values in our template + html = html.replace(r"${PREAMBLE}", args.preamble ) + + # pull from provided samples + samples_dirs = { + "librispeech": args.demo_dir / "librispeech", + } # pull from dataset samples if args.sample_from_dataset: + samples_dirs["dataset"] = args.demo_dir / "dataset" + print("Loading dataloader...") dataloader = create_train_dataloader() print("Loaded dataloader.") @@ -144,35 +107,62 @@ def main(): num = args.dataset_samples if args.dataset_samples else cfg.evaluation.size length = len( dataloader.dataset ) - for i in range( num ): + for i in trange( num, desc="Sampling dataset for samples" ): idx = random.randint( 0, length ) batch = dataloader.dataset[idx] - dir = args.demo_dir / "samples" / f'{i}' + dir = args.demo_dir / "dataset" / f'{i}' (dir / "out").mkdir(parents=True, exist_ok=True) - text = batch["text_string"] + metadata = batch["metadata"] + + text = metadata["text"] + language = metadata["language"] prompt = dir / "prompt.wav" reference = dir / "reference.wav" out_path = dir / "out" / "ours.wav" + if args.skip_existing and out_path.exists(): + continue + + open( dir / "prompt.txt", "w", encoding="utf-8" ).write( text ) + open( dir / "language.txt", "w", encoding="utf-8" ).write( language ) + + decode_to_file( batch["proms"].to("cuda"), prompt, device="cuda" ) + decode_to_file( batch["resps"].to("cuda"), reference, device="cuda" ) + + for k, sample_dir in samples_dirs.items(): + if not sample_dir.exists(): + continue + + speakers = [ dir for dir in sample_dir.iterdir() if dir.is_dir() ] + sources = [ "ms_valle", "yourtts" ] + + samples = [] + + # generate demo output + for dir in tqdm(speakers, desc=f"Generating demo for speaker"): + text = open(dir / "prompt.txt").read() + language = open(dir / "language.txt").read() if (dir / "language.txt").exists() else "en" + prompt = dir / "prompt.wav" + out_path = dir / "out" / "ours.wav" + + extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else [] + samples.append(( text, - [ prompt, reference, out_path ] + [ prompt, dir / "reference.wav", out_path ] + extra_sources )) if args.skip_existing and out_path.exists(): continue - decode_to_file( batch["proms"].to("cuda"), prompt, device="cuda" ) - decode_to_file( batch["resps"].to("cuda"), reference, device="cuda" ) - tts.inference( text=text, references=[prompt], - language=args.language, + language=language, out_path=out_path, input_prompt_length=args.input_prompt_length, max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels, @@ -187,18 +177,21 @@ def main(): tqdm=False, ) + # collate entries into HTML samples = [ - f'{text}'+ + f'\n\t\t\t\n\t\t\t\t{text}'+ "".join( [ - f'' + f'\n\t\t\t\t' for audio in audios ] )+ - '' + '\n\t\t\t' for text, audios in samples ] - html = html.replace(r"${SAMPLES}", "\n".join(samples) ) + # write audio into template + html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) ) + # write demo page open( args.demo_dir / "index.html", "w", encoding="utf-8" ).write( html ) if __name__ == "__main__": diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index a598611..a9b8b76 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -144,6 +144,7 @@ class AR_NAR(Base): quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ] else: # if p_rvq_levels == "auto": # makes higher levels less likely + """ def generate( lo=0, hi=8 ): index = lo p = random.random() @@ -151,8 +152,17 @@ class AR_NAR(Base): if p < 1.0 / (2 ** i): index = i return int(index) + """ - quant_levels = [ generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ] + # allow passing a specific distribution of RVQ levels + pool = p_rvq_levels if isinstance(p_rvq_levels, list) else [] + if not pool: + lo, hi = quant_level_range[0], quant_level_range[1] + for i in range( lo, hi ): + rep = hi - i + pool += [i] * rep + + quant_levels = [ random.choice( pool ) for i in range(batch_size) ] # these two are techinically equivalent if the audio embeddings handle things properly resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)] diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py index a251340..571ccdd 100644 --- a/vall_e/models/nar.py +++ b/vall_e/models/nar.py @@ -143,15 +143,25 @@ class NAR(Base): quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ] else: # if p_rvq_levels == "auto": # makes higher levels less likely - def generate( lo=0, hi=8 ): - index = lo - p = random.random() - for i in range(lo, hi): - if p < 1.0 / (2 ** i): - index = i - return int(index) + """ + def generate( lo=0, hi=8 ): + index = lo + p = random.random() + for i in range(lo, hi): + if p < 1.0 / (2 ** i): + index = i + return int(index) + """ - quant_levels = [ generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ] + # allow passing a specific distribution of RVQ levels + pool = p_rvq_levels if isinstance(p_rvq_levels, list) else [] + if not pool: + lo, hi = quant_level_range[0], quant_level_range[1] + for i in range( lo, hi ): + rep = hi - i + pool += [i] * rep + + quant_levels = [ random.choice( pool ) for i in range(batch_size) ] # clamp quant_levels because some of my audio was saved for only 8 out of 9 RVQ levels for DAC... for i in range(batch_size):