diff --git a/data/demo/index.template.html b/data/demo/index.template.html
new file mode 100644
index 0000000..ed12bf9
--- /dev/null
+++ b/data/demo/index.template.html
@@ -0,0 +1,30 @@
+<!DOCTYPE html>
+<html>
+<head>
+	<meta charset="utf-8">
+	<title>VALL-E Demo</title>
+</head>
+<body>
+	<h1>VALL-E Demo</h1>
+	<p>Below are some samples from my VALL-E implementation: https://git.ecker.tech/mrq/vall-e/. I do not consider these to be state of the art. The first table contains samples from LibriSpeech, compared against the utterances showcased on the original VALL-E demo page.</p>
+	<table>
+		<tr><th>Text</th><th>Prompt</th><th>Ground Truth</th><th>Our VALL-E</th><th>Original VALL-E</th><th>YourTTS</th></tr>
+		${ENTRIES}
+	</table>
+	<p>Below are some extra samples.</p>
+	<table>
+		<tr><th>Text</th><th>Prompt</th><th>Ground Truth</th><th>Our VALL-E</th></tr>
+		${SAMPLES}
+	</table>
+</body>
+</html>
\ No newline at end of file
diff --git a/vall_e/data.py b/vall_e/data.py
index 34b0f22..b57ba6d 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -575,7 +575,9 @@ class Dataset(_Dataset):
 		self.samplers = { name: PoolSampler( paths, keep_all=True, shuffle=self.sampler_shuffle ) for name, paths in self.paths_by_spkr_name.items() }
 		self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle ) for name, speakers in self.spkrs_by_spkr_group.items() }
 
-		self.load_state_dict()
+		# loading validation state dict causes issues
+		if self.dataset_type != "validation":
+			self.load_state_dict()
 
 	@cached_property
 	def sampler_state_dict_path(self):
@@ -804,12 +806,14 @@ class Dataset(_Dataset):
 
 			lang = metadata["language"] if "language" in metadata else None
 			tone = metadata["tone"] if "tone" in metadata else None
+			text_string = metadata["text"] if "text" in metadata else None
 		else:
 			resps, metadata = _load_quants(path, return_metadata=True)
 			text = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
 
 			lang = metadata["language"] if "language" in metadata else None
 			tone = metadata["tone"] if "tone" in metadata else None
+			text_string = metadata["text"] if "text" in metadata else None
 
 		if not lang:
 			lang = self.get_language(spkr_group)
@@ -1027,6 +1031,7 @@ class Dataset(_Dataset):
 			text=text,
 			proms=proms,
 			resps=resps,
+			text_string=text_string,
 		)
 
 	def head_(self, n):
@@ -1080,6 +1085,31 @@ def create_datasets():
 
 	return train_dataset, val_dataset
 
+def create_train_dataloader():
+	train_dataset = Dataset( training=True )
+	train_dl = _create_dataloader(train_dataset, training=True)
+
+	_logger.info(str(train_dataset.phone_symmap))
+	_logger.info(str(train_dataset.spkr_symmap))
+	_logger.info(str(train_dataset.spkr_group_symmap))
+
+	_logger.info(f"#samples (train): {len(train_dataset)}.")
+	_logger.info(f"#duration (train): {str(train_dataset.duration)}.")
+
+	return train_dl
+
+def create_val_dataloader():
+	val_dataset = Dataset( training=False )
+	val_dl = _create_dataloader(val_dataset, training=False)
+
+	_logger.info(str(val_dataset.phone_symmap))
+	_logger.info(str(val_dataset.spkr_symmap))
+	_logger.info(str(val_dataset.spkr_group_symmap))
+
+	_logger.info(f"#samples (val): {len(val_dataset)}.")
+	_logger.info(f"#duration (val): {str(val_dataset.duration)}.")
+
+	return val_dl
 
 def create_train_val_dataloader():
 	train_dataset, val_dataset = create_datasets()
diff --git a/vall_e/demo.py b/vall_e/demo.py
new file mode 100644
index 0000000..9fe8645
--- /dev/null
+++ b/vall_e/demo.py
@@ -0,0 +1,202 @@
+"""
+A helper script to generate a demo page.
+
+Layout as expected:
+	./data/demo/:
+		{speaker ID}:
+			out:
+				ours.wav (generated)
+				ms_valle.wav
+				yourtts.wav
+			prompt.txt (text to generate)
+			prompt.wav (reference clip to serve as the prompt)
+			reference.wav (ground truth utterance)
+
+Will also generate samples from a provided dataset, if requested.
+""" + +import argparse +import base64 +import random + +from pathlib import Path + +from .inference import TTS +from .config import cfg +from .data import create_train_dataloader, create_val_dataloader +from .emb.qnt import decode_to_file + +from tqdm import tqdm + +def encode(path): + return "data:audio/wav;base64," + base64.b64encode(open(path, "rb").read()).decode('utf-8') + +# Would be downright sugoi if I could incorporate this with into __main__ +def main(): + parser = argparse.ArgumentParser("VALL-E TTS Demo") + + parser.add_argument("--yaml", type=Path, default=None) + + parser.add_argument("--demo-dir", type=Path, default=None) + parser.add_argument("--skip-existing", action="store_true") + parser.add_argument("--sample-from-dataset", action="store_true") + parser.add_argument("--dataset-samples", type=int, default=0) + parser.add_argument("--audio-path-root", type=str, default=None) + + parser.add_argument("--language", type=str, default="en") + + parser.add_argument("--max-ar-steps", type=int, default=12 * cfg.dataset.frames_per_second) + parser.add_argument("--max-nar-levels", type=int, default=7) + + parser.add_argument("--ar-temp", type=float, default=1.0) + parser.add_argument("--nar-temp", type=float, default=0.0) + parser.add_argument("--min-ar-temp", type=float, default=-1.0) + parser.add_argument("--min-nar-temp", type=float, default=-1.0) + parser.add_argument("--input-prompt-length", type=float, default=3.0) + + parser.add_argument("--top-p", type=float, default=1.0) + parser.add_argument("--top-k", type=int, default=16) + parser.add_argument("--repetition-penalty", type=float, default=1.0) + parser.add_argument("--repetition-penalty-decay", type=float, default=0.0) + parser.add_argument("--length-penalty", type=float, default=0.0) + parser.add_argument("--beam-width", type=int, default=0) + + parser.add_argument("--mirostat-tau", type=float, default=0) + parser.add_argument("--mirostat-eta", type=float, default=0) + + parser.add_argument("--seed", type=int, default=None) + + parser.add_argument("--device", type=str, default=None) + parser.add_argument("--amp", action="store_true") + parser.add_argument("--dtype", type=str, default=None) + + args = parser.parse_args() + + tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp ) + + if not args.demo_dir: + args.demo_dir = Path("./data/demo/") + + entries = [] + + # pull from provided samples + sample_dir = args.demo_dir / "librispeech" + if sample_dir.exists(): + speakers = [ dir for dir in sample_dir.iterdir() if dir.is_dir() ] + sources = ["ms_valle", "yourtts"] + + # generate demo output + for dir in tqdm(speakers, desc=f"Generating demo for speaker"): + text = open(dir / "prompt.txt").read() + prompt = dir / "prompt.wav" + out_path = dir / "out" / "ours.wav" + + entries.append(( + text, + [ prompt, dir / "reference.wav", out_path ] + [ dir / "out" / f"{source}.wav" for source in sources ] + )) + + if args.skip_existing and out_path.exists(): + continue + + tts.inference( + text=text, + references=[prompt], + language=args.language, + out_path=out_path, + input_prompt_length=args.input_prompt_length, + max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels, + ar_temp=args.ar_temp, nar_temp=args.nar_temp, + min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp, + top_p=args.top_p, top_k=args.top_k, + repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay, + length_penalty=args.length_penalty, + beam_width=args.beam_width, + 
+				mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
+				seed=args.seed,
+				tqdm=False,
+			)
+
+	# create html table, in one messy line
+	entries = [
+		f'<tr><td>{text}</td>'+
+		"".join( [
+			f'<td><audio controls src="{encode(audio)}"></audio></td>'
+			for audio in audios
+		] )+
+		'</tr>'
+		for text, audios in entries
+	]
+
+	# read html template
+	html = open(args.demo_dir / "index.template.html", "r", encoding="utf-8").read()
+	# replace in our template
+	html = html.replace(r"${ENTRIES}", "\n".join(entries) )
+
+	samples = []
+
+	# pull from dataset samples
+	if args.sample_from_dataset:
+		print("Loading dataloader...")
+		dataloader = create_train_dataloader()
+		print("Loaded dataloader.")
+
+		num = args.dataset_samples if args.dataset_samples else cfg.evaluation.size
+
+		length = len( dataloader.dataset )
+		for i in range( num ):
+			idx = random.randint( 0, length - 1 )
+			batch = dataloader.dataset[idx]
+
+			dir = args.demo_dir / "samples" / f'{i}'
+
+			(dir / "out").mkdir(parents=True, exist_ok=True)
+
+			text = batch["text_string"]
+
+			prompt = dir / "prompt.wav"
+			reference = dir / "reference.wav"
+			out_path = dir / "out" / "ours.wav"
+
+			decode_to_file( batch["proms"].to("cuda"), prompt, device="cuda" )
+			decode_to_file( batch["resps"].to("cuda"), reference, device="cuda" )
+
+			samples.append((
+				text,
+				[ prompt, reference, out_path ]
+			))
+
+			tts.inference(
+				text=text,
+				references=[prompt],
+				language=args.language,
+				out_path=out_path,
+				input_prompt_length=args.input_prompt_length,
+				max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
+				ar_temp=args.ar_temp, nar_temp=args.nar_temp,
+				min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
+				top_p=args.top_p, top_k=args.top_k,
+				repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
+				length_penalty=args.length_penalty,
+				beam_width=args.beam_width,
+				mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
+				seed=args.seed,
+				tqdm=False,
+			)
+
+	samples = [
+		f'<tr><td>{text}</td>'+
+		"".join( [
+			f'<td><audio controls src="{encode(audio)}"></audio></td>'
+			for audio in audios
+		] )+
+		'</tr>'
+		for text, audios in samples
+	]
+
+	html = html.replace(r"${SAMPLES}", "\n".join(samples) )
+
+	open( args.demo_dir / "index.html", "w", encoding="utf-8" ).write( html )
+
+if __name__ == "__main__":
+	main()
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 0232f28..3312f7b 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -352,6 +352,10 @@ class Engines(dict[str, Engine]):
 				"userdata": userdata,
 				"config": config
 			}
+
+			if lora is None:
+				del state_dict['lora']
+
 			if callback:
 				state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path )
diff --git a/vall_e/export.py b/vall_e/export.py
index b6a8b19..5649f50 100755
--- a/vall_e/export.py
+++ b/vall_e/export.py
@@ -81,12 +81,32 @@ def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
 
 	return state_dict
 
+def split_classifier_heads( state_dict, config = cfg.model, save_path = None, dtype = None):
+	levels = config.max_levels
+
+	if "classifier.weight" not in state_dict['module']:
+		return state_dict
+
+	# copy to new AudioClassifier
+	for i in range(levels):
+		tokens = 1025 if i == 0 else 1024
+
+		# trim per RVQ level (since level 0 has a stop token)
+		state_dict['module'][f'classifiers.proj.{i}.weight'] = state_dict['module']['classifier.weight'][:tokens, :]
+		state_dict['module'][f'classifiers.proj.{i}.bias'] = state_dict['module']['classifier.bias'][:tokens]
+
+	# delete old weights
+	del state_dict['module']['classifier.weight']
+	del state_dict['module']['classifier.bias']
+
+	return state_dict
 
 def main():
 	parser = argparse.ArgumentParser("Save trained model to path.")
 	parser.add_argument("--module-only", action='store_true')
 	parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style
 	parser.add_argument("--lora", action='store_true', default=None) # exports LoRA
+	parser.add_argument("--split-classifiers", action='store_true', default=None) # splits classifier heads
 	parser.add_argument("--dtype", type=str, default="auto") # set target dtype to export to
 	args, unknown = parser.parse_known_args()
@@ -98,6 +118,8 @@ def main():
 		callback = convert_to_hf
 	elif args.lora:
 		callback = extract_lora
+	elif args.split_classifiers:
+		callback = split_classifier_heads
 
 	if args.hf and args.lora:
 		raise Exception("Requesting more than one callback")
diff --git a/vall_e/inference.py b/vall_e/inference.py
index 6b9a7e2..c38b494 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -140,7 +140,9 @@ class TTS():
 		seed = None,
 
-		out_path=None
+		out_path=None,
+
+		tqdm=True,
 	):
 		lines = text.split("\n")
@@ -194,6 +196,8 @@ class TTS():
 				sampling_beam_width=beam_width,
 				sampling_mirostat_tau=mirostat_tau,
 				sampling_mirostat_eta=mirostat_eta,
+
+				disable_tqdm=not tqdm,
 			)
 			resps_list = model_nar( text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list,
@@ -202,15 +206,19 @@ class TTS():
 				sampling_min_temperature=min_nar_temp,
 				sampling_top_p=top_p, sampling_top_k=top_k,
 				sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
+
+				disable_tqdm=not tqdm,
 			)
 		elif model_len is not None:
-			len_list = model_len( text_list=[phns], proms_list=[prom], max_steps=10 ) # don't need more than that
+			len_list = model_len( text_list=[phns], proms_list=[prom], max_steps=10, disable_tqdm=not tqdm ) # don't need more than that
 			resps_list = model_nar( text_list=[phns], proms_list=[prom], len_list=len_list,
 				max_levels=max_nar_levels,
 				sampling_temperature=nar_temp,
 				sampling_min_temperature=min_nar_temp,
 				sampling_top_p=top_p, sampling_top_k=top_k,
 				sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
+
+				disable_tqdm=not tqdm,
 			)
 		else:
 			raise Exception("!")
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 220e045..a598611 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -114,6 +114,8 @@ class AR_NAR(Base):
 		sampling_beam_width: int = 0,
 		sampling_mirostat_tau: float = 0.0,
 		sampling_mirostat_eta: float = 0.1,
+
+		disable_tqdm=False,
 	):
 		device = text_list[0].device
 		batch_size = len(text_list)
@@ -206,7 +208,7 @@ class AR_NAR(Base):
 
 			prev_list = resps_list
 
-			for n in trange( max_levels, desc="NAR" ):
+			for n in trange( max_levels, desc="NAR", disable=disable_tqdm ):
 				level = prev_list[0].shape[-1]
 				if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
 					break
@@ -271,7 +273,7 @@ class AR_NAR(Base):
 		scores = [ 1.0 ] * sampling_beam_width
 
 		# get next in sequence
-		for n in trange(max_steps // max(1, self.causal_size), desc="AR"):
+		for n in trange(max_steps // max(1, self.causal_size), desc="AR", disable=disable_tqdm):
 			resps_list = self._unsqueeze_list(sequence_list)
 
 			inputs = self.inputs(
diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index d3b803d..a251340 100644
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -111,6 +111,8 @@ class NAR(Base):
 		sampling_beam_width: int = 0,
 		sampling_mirostat_tau: float = 0.0,
 		sampling_mirostat_eta: float = 0.1,
+
+		disable_tqdm=False,
 	):
 		device = text_list[0].device
 		batch_size = len(text_list)
@@ -188,7 +190,7 @@ class NAR(Base):
 			prev_list = [ torch.Tensor([ self.stop_token for _ in range(resp_len) ]).to(device=device, dtype=torch.int16) for resp_len in len_list ]
 			start = True
 
-			for n in trange( max_levels, desc="NAR" ):
+			for n in trange( max_levels, desc="NAR", disable=disable_tqdm ):
 				level = 0 if n == 0 else prev_list[0].shape[-1]
 				if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
 					break
@@ -243,7 +245,7 @@ class NAR(Base):
 		stop_token = 10
 		task_list = [ "len" for _ in range(batch_size) ]
 
-		for n in trange(10, desc="AR"):
+		for n in trange(10, desc="AR", disable=disable_tqdm):
 			len_list = sequence_list
 
 			inputs = self.inputs(
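
A minimal usage sketch for the new demo script, assuming the vall_e package is importable (e.g. run from the repository root) and a trained checkpoint whose config lives at ./training/config.yaml (the path is illustrative; the flags are the ones defined in vall_e/demo.py above):

    python3 -m vall_e.demo --yaml ./training/config.yaml --sample-from-dataset --dataset-samples 8 --skip-existing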