diff --git a/data/demo/index.template.html b/data/demo/index.template.html
index 788ecb0..23eeddc 100644
--- a/data/demo/index.template.html
+++ b/data/demo/index.template.html
@@ -25,7 +25,6 @@
 				<tr>
 					<th>Text</th>
 					<th>Prompt</th>
-					<th>Our VALL-E (No LoRA)</th>
 					<th>Our VALL-E</th>
 					<th>Ground Truth</th>
 				</tr>
diff --git a/vall_e/data.py b/vall_e/data.py
index 5310c49..9c2972e 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -36,7 +36,7 @@ from tqdm.auto import tqdm
 _logger = logging.getLogger(__name__)
 
 @cache
-def get_random_prompts( validation=True, length=0, tokenized=False ):
+def get_random_prompts( validation=True, min_length=0, min_duration=6, tokenized=False ):
 	sentences = [
 		"The birch canoe slid on the smooth planks.",
 		"Glue the sheet to the dark blue background.",
@@ -76,21 +76,27 @@ def get_random_prompts( validation=True, length=0, tokenized=False ):
 	paths = list(itertools.chain.from_iterable(paths.values()))
 
 	for path in paths:
+		duration = 0
 		text_string = ""
 		if cfg.dataset.use_hdf5:
 			key = _get_hdf5_path(path)
 			metadata = { f'{k}': f'{v}' for k, v in cfg.hdf5[key].attrs.items() }
+			metadata = process_artifact_metadata( { "metadata": metadata } )
 			text_string = metadata["text"] if "text" in metadata else ""
+			duration = metadata["duration"] if "duration" in metadata else 0
 		else:
 			_, metadata = _load_quants(path, return_metadata=True)
+			metadata = process_artifact_metadata( { "metadata": metadata } )
 			text_string = metadata["text"] if "text" in metadata else ""
+			duration = metadata["duration"] if "duration" in metadata else 0
 
-		if len( text_string ) < length:
+		if len( text_string ) < min_length or duration < min_duration:
 			continue
 
 		sentences.append( text_string )
 
+	# tokenize here, since the Harvard sentences need to be phonemized anyway
 	if tokenized:
 		return [ torch.tensor( tokenize( encode_phns( text ) ) ).to(dtype=torch.uint8) for text in sentences ]
 
 	return sentences
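Taken together, the data.py change filters candidate prompts on both transcript length and audio duration. A minimal usage sketch, assuming a configured dataset; the argument values below are illustrative, while the signature and defaults come from this patch:

# Hypothetical call: entries whose transcript is shorter than min_length
# characters, or whose audio is shorter than min_duration seconds, are skipped.
from vall_e.data import get_random_prompts

sentences = get_random_prompts( validation=True, min_length=4, min_duration=6 )

# With tokenized=True, each entry is returned as a torch.uint8 tensor of
# phoneme tokens (phonemized via encode_phns, then passed through tokenize).
tokens = get_random_prompts( validation=True, tokenized=True )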
diff --git a/vall_e/demo.py b/vall_e/demo.py
index 33910f2..c7cfa00 100644
--- a/vall_e/demo.py
+++ b/vall_e/demo.py
@@ -19,6 +19,7 @@ import argparse
 import base64
 import random
 import logging
+import time
 
 _logger = logging.getLogger(__name__)
 
@@ -42,6 +43,7 @@
 	parser.add_argument("--demo-dir", type=Path, default=None)
 	parser.add_argument("--skip-existing", action="store_true")
+	parser.add_argument("--dataset-dir-name", type=str, default="dataset")
 	parser.add_argument("--sample-from-dataset", action="store_true")
 	parser.add_argument("--skip-loading-dataloader", action="store_true")
 	parser.add_argument("--dataset-samples", type=int, default=0)
@@ -118,8 +120,8 @@
 		"librispeech": args.demo_dir / "librispeech",
 	}
 
-	if (args.demo_dir / "dataset").exists():
-		samples_dirs["dataset"] = args.demo_dir / "dataset"
+	if (args.demo_dir / args.dataset_dir_name).exists():
+		samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name
 
 	# pull from dataset samples
 	if args.sample_from_dataset:
@@ -127,7 +129,7 @@
 		cfg.dataset.sample_type = "path" if args.lora else "speaker"
 		cfg.dataset.tasks_list = [ 'tts' ]
 
-		samples_dirs["dataset"] = args.demo_dir / "dataset"
+		samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name
 
 		_logger.info("Loading dataloader...")
 		dataloader = create_train_dataloader()
@@ -139,7 +141,7 @@
 		for i in trange( num, desc="Sampling dataset for samples" ):
 			batch = dataloader.dataset[i]
 
-			dir = args.demo_dir / "dataset" / f'{i}'
+			dir = args.demo_dir / args.dataset_dir_name / f'{i}'
 
 			(dir / "out").mkdir(parents=True, exist_ok=True)
 
@@ -181,14 +183,19 @@
 			extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else ([ out_path_lora ] if args.lora else [])
 
+			if not args.random_prompts:
+				extra_sources += [ reference ]
+
 			samples.append((
 				text,
-				[ prompt, out_path ] + extra_sources + [ reference ],
+				[ prompt, out_path ] + extra_sources,
 			))
 
 			if args.skip_existing and out_path.exists():
 				continue
 
+			seed = args.seed if args.seed else int(time.time())
+
 			kwargs = dict(
 				text=text,
 				references=[prompt],
@@ -202,17 +209,19 @@
 				length_penalty=args.length_penalty,
 				beam_width=args.beam_width,
 				mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
-				seed=args.seed,
+				seed=seed,
 				tqdm=False,
 			)
 
 			if args.lora:
-				tts.enable_lora()
+				tts.enable_lora() # likely redundant now that use_lora is passed below
+				kwargs["use_lora"] = True
 				try:
 					tts.inference( out_path=out_path_lora, **kwargs )
 				except Exception as e:
 					print(f'Error while processing {out_path}: {e}')
 				tts.disable_lora()
+				kwargs["use_lora"] = False
 			try:
 				tts.inference( out_path=out_path, **kwargs )
 			except Exception as e:
@@ -233,8 +242,11 @@
 		# write audio into template
 		html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
 
-	if not args.lora:
-		html = html.replace("\n\t\t\t\t\t<th>Our VALL-E (No LoRA)</th>", "")
+	if args.lora:
+		if args.random_prompts:
+			html = html.replace("<th>Our VALL-E</th>\n\t\t\t\t\t<th>Ground Truth</th>", "<th>Our VALL-E (No LoRA)</th>\n\t\t\t\t\t<th>Our VALL-E (LoRA)</th>")
+		else:
+			html = html.replace("<th>Our VALL-E</th>", "<th>Our VALL-E (No LoRA)</th>\n\t\t\t\t\t<th>Our VALL-E (LoRA)</th>")
 
 	# write demo page
 	open( args.demo_dir / "index.html", "w", encoding="utf-8" ).write( html )
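For reference, a minimal sketch of driving the new override from the API side. The bare TTS() constructor call is an assumption; inference() and its use_lora keyword come from this patch:

# Sketch only: assumes a TTS() that loads the model from the active config.
from vall_e.inference import TTS

tts = TTS()  # hypothetical; constructor arguments depend on your config

kwargs = dict(
	text="The birch canoe slid on the smooth planks.",
	references=["./prompt.wav"],  # hypothetical reference audio path
	seed=0,
)

# use_lora=True forces the LoRA weights on for this call, use_lora=False
# forces them off, and use_lora=None (the default) defers to
# cfg.lora.active_level(), preserving the old behavior.
tts.inference( out_path="./out.lora.wav", use_lora=True, **kwargs )
tts.inference( out_path="./out.base.wav", use_lora=False, **kwargs )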
diff --git a/vall_e/inference.py b/vall_e/inference.py
index 9740c6a..5dd1459 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -211,6 +211,7 @@
 		out_path=None,
 
 		tqdm=True,
+		use_lora=None,
 	):
 
 		lines = text.split("\n")
@@ -255,6 +256,7 @@
 					sampling_dry_allowed_length=dry_allowed_length,
 
 					disable_tqdm=not tqdm,
+					use_lora=use_lora,
 				)
 			else:
 				raise Exception("!")
@@ -298,6 +300,7 @@
 					sampling_dry_allowed_length=dry_allowed_length,
 
 					disable_tqdm=not tqdm,
+					use_lora=use_lora,
 				)
 				resps_list = model_nar( text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list,
@@ -309,6 +312,7 @@
 					sampling_repetition_penalty=repetition_penalty,
 					sampling_repetition_penalty_decay=repetition_penalty_decay,
 					disable_tqdm=not tqdm,
+					use_lora=use_lora,
 				)
 			elif model_len is not None:
 				len_list = model_len( text_list=[phns], proms_list=[prom], max_steps=10, disable_tqdm=not tqdm ) # don't need more than that
@@ -320,6 +324,7 @@
 					sampling_repetition_penalty=repetition_penalty,
 					sampling_repetition_penalty_decay=repetition_penalty_decay,
 					disable_tqdm=not tqdm,
+					use_lora=use_lora,
 				)
 			else:
 				raise Exception("!")
diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py
index 28aa653..3d59da8 100644
--- a/vall_e/models/ar.py
+++ b/vall_e/models/ar.py
@@ -58,6 +58,7 @@
 		sampling_dry_allowed_length=2,
 
 		disable_tqdm=False,
+		use_lora=None,
 	):
 		device = text_list[0].device
 		batch_size = len(text_list)
@@ -156,10 +157,8 @@
 			)
 
 		# is AR
-		"""
 		if cfg.lora is not None:
-			enable_lora( self, cfg.lora.active_level( 0 ) )
-		"""
+			enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora )
 
 		sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
 		stopped = torch.zeros(batch_size, device=device).bool()
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index b4a6c61..5e08c3f 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -65,6 +65,7 @@
 		sampling_dry_allowed_length=2,
 
 		disable_tqdm=False,
+		use_lora=None,
 	):
 
 		text_task = [ "stt" ]
@@ -204,7 +205,7 @@
 				break
 
 			if cfg.lora is not None:
-				enable_lora( self, cfg.lora.active_level( level ) )
+				enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )
 
 			quant_levels = [ level for _ in range(batch_size) ] # torch.full((len(text_list),), level)
 
@@ -243,14 +244,11 @@
 			prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device=device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
 
-			if cfg.lora is not None:
-				enable_lora( self )
-
 			return prev_list
 
 		# is AR
 		if cfg.lora is not None:
-			enable_lora( self, cfg.lora.active_level( 0 ) )
+			enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora )
 
 		# STT
 		start_slice = [ 0 for _ in range(batch_size) ]
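Both models apply the same three-state override when a LoRA is configured. A tiny standalone illustration of its semantics; the helper name here is illustrative, not from the patch:

# None defers to the configured LoRA level; True/False force LoRA on or off
# for the current forward pass (only consulted when cfg.lora is not None).
def resolve_lora( use_lora, active_level ):
	return active_level if use_lora is None else use_lora

assert resolve_lora( None, 1 ) == 1      # default: follow cfg.lora.active_level(...)
assert resolve_lora( True, 0 ) is True   # demo.py sets kwargs["use_lora"] = True
assert resolve_lora( False, 1 ) is False # ...then False for the non-LoRA pass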