From 2ea978f318f1f82a27597b1408f7dd6e5917214f Mon Sep 17 00:00:00 2001
From: mrq
Date: Thu, 10 Oct 2024 13:40:25 -0500
Subject: [PATCH] added --eval-random-text-prompts to use random text prompts
 for eval pass, added --random-prompts for demo page and --lora to use a
 sample with the lora disabled, probably finally fixed validation dataloader
 breaking on eval

---
 README.md             |  4 +++
 vall_e/data.py        | 75 ++++++++++++++++++++++++++++++++++++++++---
 vall_e/demo.py        | 53 ++++++++++++++++++------------
 vall_e/emb/similar.py |  4 +++
 vall_e/inference.py   |  8 +++++
 vall_e/models/ar.py   |  2 ++
 vall_e/train.py       | 58 +++++++++++++++++++++------------
 vall_e/webui.py       | 39 ++--------------------
 8 files changed, 161 insertions(+), 82 deletions(-)

diff --git a/README.md b/README.md
index e8136c3..1c3eb9b 100755
--- a/README.md
+++ b/README.md
@@ -96,6 +96,10 @@ You can enter `save` to save the state at any time, or `quit` to save and quit t
 
 The `lr` command will also let you adjust the learning rate on the fly. For example: `lr 1.0e-3` will set the learning rate to `0.001`.
 
+Some additional flags can be passed as well:
+* `--eval`: only run the evaluation / validation pass, then exit afterwards.
+* `--eval-random-text-prompts`: use random text prompts for the evaluation pass, rather than the provided text prompts in the dataset.
+
 ### Finetuning
 
 Finetuning can be done by training the full model, or using a LoRA.
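[ A minimal usage sketch for the flags documented above; not part of the patch. It assumes the `python3 -m vall_e.train --yaml=...` entry point described elsewhere in the README, so adjust it to however you normally invoke the trainer. Note that `--eval-random-text-prompts` only swaps texts during the `subtrain` pass, per the `run_eval` change in `vall_e/train.py` further down. ]

    # run only the evaluation / validation pass, then exit
    python3 -m vall_e.train --yaml="./training/config.yaml" --eval
    # same, but with random text prompts substituted during the subtrain pass
    python3 -m vall_e.train --yaml="./training/config.yaml" --eval --eval-random-text-prompts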
diff --git a/vall_e/data.py b/vall_e/data.py
index 209e0de..d80b179 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -35,6 +35,72 @@ from tqdm.auto import tqdm
 
 _logger = logging.getLogger(__name__)
 
+@cache
+def get_random_prompts( validation=True, length=0, tokenized=False ):
+	sentences = [
+		"The birch canoe slid on the smooth planks.",
+		"Glue the sheet to the dark blue background.",
+		"It's easy to tell the depth of a well.",
+		"These days a chicken leg is a rare dish.",
+		"Rice is often served in round bowls.",
+		"The juice of lemons makes fine punch.",
+		"The box was thrown beside the parked truck.",
+		"The hogs were fed chopped corn and garbage.",
+		"Four hours of steady work faced us.",
+		"A large size in stockings is hard to sell.",
+		"The boy was there when the sun rose.",
+		"A rod is used to catch pink salmon.",
+		"The source of the huge river is the clear spring.",
+		"Kick the ball straight and follow through.",
+		"Help the woman get back to her feet.",
+		"A pot of tea helps to pass the evening.",
+		"Smoky fires lack flame and heat.",
+		"The soft cushion broke the man's fall.",
+		"The salt breeze came across from the sea.",
+		"The girl at the booth sold fifty bonds.",
+		"The small pup gnawed a hole in the sock.",
+		"The fish twisted and turned on the bent hook.",
+		"Press the pants and sew a button on the vest.",
+		"The swan dive was far short of perfect.",
+		"The beauty of the view stunned the young boy.",
+		"Two blue fish swam in the tank.",
+		"Her purse was full of useless trash.",
+		"The colt reared and threw the tall rider.",
+		"It snowed, rained, and hailed the same morning.",
+		"Read verse out loud for pleasure.",
+	]
+
+	# Pull from validation dataset if existing + requested
+	if validation and cfg.dataset.validation:
+		paths = _load_paths(cfg.dataset.validation, type="validation")
+		paths = list(itertools.chain.from_iterable(paths.values()))
+
+		for path in paths:
+			text_string = ""
+			if cfg.dataset.use_hdf5:
+				key = _get_hdf5_path(path)
+
+				metadata = { f'{k}': f'{v}' for k, v in cfg.hdf5[key].attrs.items() }
+				text_string = metadata["text"] if "text" in metadata else ""
+			else:
+				_, metadata = _load_quants(path, return_metadata=True)
+				text_string = metadata["text"] if "text" in metadata else ""
+
+			if len( text_string ) < length:
+				continue
+
+			sentences.append( text_string )
+
+	if tokenized:
+		return [ torch.tensor( tokenize( encode_phns( text ) ) ).to(dtype=torch.uint8) for text in sentences ]
+
+	return sentences
+
+# samples a random text prompt
+def get_random_prompt( *args, **kwargs ):
+	# Harvard sentences
+	return random.choice(get_random_prompts( *args, **kwargs ))
+
 # fold into a typical LLM sequence (one embedding rather than split embeddings)
 def fold_inputs(
 	text_list = [],
@@ -718,7 +784,7 @@ class Dataset(_Dataset):
 		if len(self.paths) == 0:
 			raise ValueError(f"No valid path is found for {self.dataset_type}")
 
-		if self.sampler_type == "path":
+		if self.sampler_type == "path" and self.training:
 			if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
 				self.sampler = BatchedOrderedSampler(
 					self.duration_buckets if not self.sampler_state_dict_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
@@ -735,9 +801,7 @@ class Dataset(_Dataset):
 			self.samplers = { name: PoolSampler( paths, keep_all=True, shuffle=self.sampler_shuffle ) for name, paths in self.paths_by_spkr_name.items() }
 			self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle ) for name, speakers in self.spkrs_by_spkr_group.items() }
 
-		# loading validation state dict causes issues
-		if self.dataset_type != "validation":
-			self.load_state_dict()
+		self.load_state_dict()
 
 	@cached_property
 	def sampler_state_dict_path(self):
@@ -789,6 +853,9 @@ class Dataset(_Dataset):
 		torch_save(state_dict, path)
 
 	def load_state_dict(self, path = None):
+		if not self.training:
+			return
+
 		if path is None:
 			path = self.sampler_state_dict_path
 
diff --git a/vall_e/demo.py b/vall_e/demo.py
index 80da923..1abfd8d 100644
--- a/vall_e/demo.py
+++ b/vall_e/demo.py
@@ -26,7 +26,7 @@ from pathlib import Path
 
 from .inference import TTS
 from .config import cfg
-from .data import create_train_dataloader, create_val_dataloader
+from .data import create_train_dataloader, create_val_dataloader, get_random_prompt
 from .emb.qnt import decode_to_file
 
 from tqdm import tqdm, trange
@@ -78,6 +78,9 @@ def main():
 	parser.add_argument("--device", type=str, default=None)
 	parser.add_argument("--amp", action="store_true")
 	parser.add_argument("--dtype", type=str, default=None)
+
+	parser.add_argument("--random-prompts", action="store_true")
+	parser.add_argument("--lora", action="store_true")
 
 	args = parser.parse_args()
 
@@ -142,7 +145,7 @@
 
 		metadata = batch["metadata"]
 
-		text = metadata["text"]
+		text = get_random_prompt() if args.random_prompts else metadata["text"]
 		language = metadata["language"].lower()
 
 		prompt = dir / "prompt.wav"
@@ -174,8 +177,9 @@
 		prompt = dir / "prompt.wav"
 		reference = dir / "reference.wav"
 		out_path = dir / "out" / "ours.wav"
+		out_path_lora = dir / "out" / "ours_lora.wav"
 
-		extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else []
+		extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else ([ out_path_lora ] if args.lora else [])
 
 		samples.append((
 			text,
@@ -185,27 +189,36 @@
 		if args.skip_existing and out_path.exists():
 			continue
 
+		kwargs = dict(
+			text=text,
+			references=[prompt],
+			language=language,
+			input_prompt_length=args.input_prompt_length,
+			max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
+			ar_temp=args.ar_temp, nar_temp=args.nar_temp,
+			min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
+			top_p=args.top_p, top_k=args.top_k,
+			repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
+			length_penalty=args.length_penalty,
+			beam_width=args.beam_width,
+			mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
+			seed=args.seed,
+			tqdm=False,
+		)
+
+		if args.lora:
+			tts.enable_lora()
+			try:
+				tts.inference( out_path=out_path_lora, **kwargs )
+			except Exception as e:
+				print(f'Error while processing {out_path_lora}: {e}')
+			tts.disable_lora()
+
 		try:
-			tts.inference(
-				text=text,
-				references=[prompt],
-				language=language,
-				out_path=out_path,
-				input_prompt_length=args.input_prompt_length,
-				max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
-				ar_temp=args.ar_temp, nar_temp=args.nar_temp,
-				min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
-				top_p=args.top_p, top_k=args.top_k,
-				repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
-				length_penalty=args.length_penalty,
-				beam_width=args.beam_width,
-				mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
-				seed=args.seed,
-				tqdm=False,
-			)
+			tts.inference( out_path=out_path, **kwargs )
 		except Exception as e:
 			print(f'Error while processing {out_path}: {e}')
 
+	# collate entries into HTML
 	samples = [
 		f'\n\t\t\t<tr>\n\t\t\t\t<td>{text}</td>'+
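[ A quick sketch of how the new helper is consumed; not part of the patch. `get_random_prompts` is wrapped in `@cache`, so the Harvard list plus any harvested validation texts are built once per process and per argument combination, and every `get_random_prompt()` call then just draws from that cached pool. ]

    from vall_e.data import get_random_prompt

    text = get_random_prompt()                  # plain string, as used by the demo page's --random-prompts
    tokens = get_random_prompt(tokenized=True)  # uint8 tensor of phoneme tokens, as used by the eval pass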
diff --git a/vall_e/emb/similar.py b/vall_e/emb/similar.py
index 8936094..3bd7193 100644
--- a/vall_e/emb/similar.py
+++ b/vall_e/emb/similar.py
@@ -221,6 +221,7 @@ def main():
 	parser = argparse.ArgumentParser()
 
 	parser.add_argument("--input-speaker", type=Path, default=None)
+	parser.add_argument("--input-voice", type=str, default=None)
 	parser.add_argument("--use-dataset", action="store_true")
 
 	parser.add_argument("--yaml", type=Path)
@@ -254,6 +255,9 @@
 			if "LibriTTS-R" in speaker_name:
 				speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox")
 			"""
+
+			if args.input_voice and speaker_name != args.input_voice:
+				return
 
 			metadata_path = cfg.metadata_dir / f'{speaker_name}.json'
 			metadata = json_read( metadata_path, default={} )
diff --git a/vall_e/inference.py b/vall_e/inference.py
index ce8d1c3..9740c6a 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -16,6 +16,7 @@ from .utils import to_device, set_seed, wrapper as ml
 
 from .config import cfg, Config
 from .models import get_models
+from .models.lora import enable_lora
 from .engines import load_engines, deepspeed_available
 from .data import get_phone_symmap, get_lang_symmap, _load_quants, _cleanup_phones, tokenize
 
@@ -77,6 +78,13 @@ class TTS():
 		self.symmap = get_phone_symmap()
 		_logger.info("Loaded model")
 
+	def enable_lora( self, enabled=True ):
+		for name, engine in self.engines.items():
+			enable_lora( engine.module, mode = enabled )
+
+	def disable_lora( self ):
+		return self.enable_lora( enabled=False )
+
 	def encode_text( self, text, language="en" ):
 		# already a tensor, return it
 		if isinstance( text, Tensor ):
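[ A minimal sketch of the new LoRA toggle; not part of the patch, and the config path and text are hypothetical. `TTS.enable_lora()` flips the LoRA wrappers on every loaded engine's module, which is what lets the demo render the same sample with and without the LoRA applied. ]

    from vall_e.inference import TTS

    tts = TTS( config="./training/config.yaml" )  # hypothetical config path
    tts.enable_lora()   # LoRA-applied sample
    tts.inference( text="The birch canoe slid on the smooth planks.", references=["./prompt.wav"], out_path="./out/ours_lora.wav" )
    tts.disable_lora()  # base-model sample for comparison
    tts.inference( text="The birch canoe slid on the smooth planks.", references=["./prompt.wav"], out_path="./out/ours.wav" )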
diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py
index 383cba9..28aa653 100644
--- a/vall_e/models/ar.py
+++ b/vall_e/models/ar.py
@@ -156,8 +156,10 @@
 			)
 
 		# is AR
+		"""
 		if cfg.lora is not None:
 			enable_lora( self, cfg.lora.active_level( 0 ) )
+		"""
 
 		sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
 		stopped = torch.zeros(batch_size, device=device).bool()
diff --git a/vall_e/train.py b/vall_e/train.py
index 58811be..9f0b969 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -1,8 +1,8 @@
 # todo: clean this mess up
 
 from .config import cfg
-from .data import create_train_val_dataloader
-from .emb import qnt
+from .data import create_train_val_dataloader, get_random_prompt, tokenize
+from .emb import qnt, g2p
 
 from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
 from .data import fold_inputs, unfold_outputs
@@ -57,10 +57,13 @@ def train_feeder(engine, batch):
 	return loss, stats
 
 @torch.inference_mode()
-def run_eval(engines, eval_name, dl):
+def run_eval(engines, eval_name, dl, args=None):
 	stats = defaultdict(list)
 	stats['loss'] = []
 
+	if cfg.evaluation.size == 0:
+		return
+
 	def process( name, batch, resps_list ):
 		for speaker, path, ref, hyp, prom, task in zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"], batch["task"]):
 			if len(hyp) == 0:
@@ -84,16 +87,21 @@
 
 			ref_path.parent.mkdir(parents=True, exist_ok=True)
 			prom_path.parent.mkdir(parents=True, exist_ok=True)
 
-			ref_audio, sr = qnt.decode_to_file(ref, ref_path)
 			hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path)
-			prom_audio, sr = qnt.decode_to_file(prom, prom_path)
+
+			if ref is not None:
+				ref_audio, sr = qnt.decode_to_file(ref, ref_path)
+
+			if prom is not None:
+				prom_audio, sr = qnt.decode_to_file(prom, prom_path)
 
-			# pseudo loss calculation since we don't get the logits during eval
-			min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
-			ref_audio = ref_audio[..., 0:min_length]
-			hyp_audio = hyp_audio[..., 0:min_length]
-			stats['loss'].append(mel_stft_loss(hyp_audio[None, :, :], ref_audio[None, :, :]).item())
+			# naive loss calculation
+			# to-do: find a better way to calculate this / a better metric
+			if ref is not None:
+				min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
+				ref_audio = ref_audio[..., 0:min_length]
+				hyp_audio = hyp_audio[..., 0:min_length]
+				stats['loss'].append(mel_stft_loss(hyp_audio[None, :, :], ref_audio[None, :, :]).item())
 
 	processed = 0
 	while processed < cfg.evaluation.size:
@@ -105,19 +113,25 @@
 
 		batch_size = len(batch["text"])
 
-		processed += batch_size
+		# to-do: eval for text tasks
+		has_stt = False
+		for i, task in enumerate( batch["task"] ):
+			# easier to just change it to a tts task than drop stt tasks from the batch
+			if task == "stt":
+				# has_stt = True
+				batch["task"][i] = "tts"
+				batch["proms"][i] = batch["resps"][i][:75*3, :]
+
+		# random prompts requested
+		if args and args.eval_random_text_prompts and eval_name == "subtrain":
+			for i, _ in enumerate(batch["text"]):
+				batch["text"][i] = get_random_prompt(tokenized=True).to(device=cfg.device)
+				batch["resps"][i] = None
+
+		processed += batch_size
 		for name in engines:
 			engine = engines[name]
 
-			# to-do: eval for text tasks
-			has_stt = False
-			for i, task in enumerate( batch["task"] ):
-				# easier to just change it to a tts task than drop stt tasks from the batch
-				if task == "stt":
-					# has_stt = True
-					batch["task"][i] = "tts"
-					batch["proms"][i] = batch["resps"][i][:75*3, :]
-
 			kwargs = dict(
 				text_list=batch["text"],
@@ -157,7 +171,7 @@
 
 		_logger.info(f"Validation Metrics (STT): {text_list}")
 
-	stats = {k: sum(v) / len(v) for k, v in stats.items()}
+	stats = {k: sum(v) / len(v) for k, v in stats.items() if v}
 	engines_stats = {
 		f'{name}.{eval_name}': stats,
 		"it": engines.global_step,
@@ -170,6 +184,8 @@
 def train():
 	parser = argparse.ArgumentParser("VALL-E TTS")
 	parser.add_argument("--eval", action="store_true", default=None)
+	parser.add_argument("--eval-random-text-prompts", action="store_true", default=None)
+	#parser.add_argument("--eval-random-audio-prompts", action="store_true", default=None)
 	args, unknown = parser.parse_known_args()
 
 	# create log folder
@@ -185,8 +201,8 @@
 			engines.eval()
 			# wrapped in a try block because it's sometimes prone to breaking
 			try:
-				run_eval(engines, "subtrain", subtrain_dl)
-				run_eval(engines, "val", val_dl)
+				run_eval(engines, "subtrain", subtrain_dl, args)
+				run_eval(engines, "val", val_dl, args)
 			except Exception as e:
 				_logger.warning(f"Error occurred while performing eval: {str(e)}")
 				_logger.warning(traceback.format_exc())
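[ A note on the arithmetic above; not part of the patch. In the stt-to-tts conversion, `batch["resps"][i][:75*3, :]` crops the ground-truth response into an acoustic prompt of roughly three seconds, assuming the 75-frames-per-second codec rate this codebase works at; rows are codec frames, columns are RVQ levels. And because items given a random text prompt have `batch["resps"][i]` set to `None`, `process()` skips both the reference decode and the mel-STFT pseudo-loss for them, which is why the final reduction now guards with `if v` instead of dividing by the length of an empty list. ]

    frames_per_second = 75  # assumed codec frame rate
    prompt_seconds = 3
    prom = resps[: frames_per_second * prompt_seconds, :]  # ~3 s, keeping every RVQ level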
parser.add_argument("--eval", action="store_true", default=None) + parser.add_argument("--eval-random-text-prompts", action="store_true", default=None) + #parser.add_argument("--eval-random-audio-prompts", action="store_true", default=None) args, unknown = parser.parse_known_args() # create log folder @@ -185,8 +201,8 @@ def train(): engines.eval() # wrapped in a try block because it's sometimes prone to breaking try: - run_eval(engines, "subtrain", subtrain_dl) - run_eval(engines, "val", val_dl) + run_eval(engines, "subtrain", subtrain_dl, args) + run_eval(engines, "val", val_dl, args) except Exception as e: _logger.warning(f"Error occurred while performing eval: {str(e)}") _logger.warning(traceback.format_exc()) diff --git a/vall_e/webui.py b/vall_e/webui.py index 7d5a5e7..f8307ca 100644 --- a/vall_e/webui.py +++ b/vall_e/webui.py @@ -19,7 +19,7 @@ from .train import train from .utils import get_devices, setup_logging, timer from .utils.io import json_read, json_stringify from .emb.qnt import decode_to_wave -from .data import get_lang_symmap +from .data import get_lang_symmap, get_random_prompt tts = None @@ -284,42 +284,7 @@ def do_training( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): while True: metrics = next(it) yield metrics -""" - -def get_random_prompt(): - harvard_sentences=[ - "The birch canoe slid on the smooth planks.", - "Glue the sheet to the dark blue background.", - "It's easy to tell the depth of a well.", - "These days a chicken leg is a rare dish.", - "Rice is often served in round bowls.", - "The juice of lemons makes fine punch.", - "The box was thrown beside the parked truck.", - "The hogs were fed chopped corn and garbage.", - "Four hours of steady work faced us.", - "A large size in stockings is hard to sell.", - "The boy was there when the sun rose.", - "A rod is used to catch pink salmon.", - "The source of the huge river is the clear spring.", - "Kick the ball straight and follow through.", - "Help the woman get back to her feet.", - "A pot of tea helps to pass the evening.", - "Smoky fires lack flame and heat.", - "The soft cushion broke the man's fall.", - "The salt breeze came across from the sea.", - "The girl at the booth sold fifty bonds.", - "The small pup gnawed a hole in the sock.", - "The fish twisted and turned on the bent hook.", - "Press the pants and sew a button on the vest.", - "The swan dive was far short of perfect.", - "The beauty of the view stunned the young boy.", - "Two blue fish swam in the tank.", - "Her purse was full of useless trash.", - "The colt reared and threw the tall rider.", - "It snowed, rained, and hailed the same morning.", - "Read verse out loud for pleasure.", - ] - return random.choice(harvard_sentences) +""" # setup args parser = argparse.ArgumentParser(allow_abbrev=False)