diff --git a/data/demo/index.template.html b/data/demo/index.template.html
index 788ecb0..23eeddc 100644
--- a/data/demo/index.template.html
+++ b/data/demo/index.template.html
@@ -25,7 +25,6 @@
<th>Text</th>
<th>Prompt</th>
- <th>Our VALL-E (No LoRA)</th>
<th>Our VALL-E</th>
<th>Ground Truth</th>
diff --git a/vall_e/data.py b/vall_e/data.py
index 5310c49..9c2972e 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -36,7 +36,7 @@ from tqdm.auto import tqdm
_logger = logging.getLogger(__name__)
@cache
-def get_random_prompts( validation=True, length=0, tokenized=False ):
+def get_random_prompts( validation=True, min_length=0, min_duration=6, tokenized=False ):
sentences = [
"The birch canoe slid on the smooth planks.",
"Glue the sheet to the dark blue background.",
@@ -76,21 +76,27 @@ def get_random_prompts( validation=True, length=0, tokenized=False ):
paths = list(itertools.chain.from_iterable(paths.values()))
for path in paths:
+ duration = 0
text_string = ""
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
metadata = { f'{k}': f'{v}' for k, v in cfg.hdf5[key].attrs.items() }
+ metadata = process_artifact_metadata( { "metadata": metadata } )
text_string = metadata["text"] if "text" in metadata else ""
+ duration = metadata['duration'] if "duration" in metadata else 0
else:
_, metadata = _load_quants(path, return_metadata=True)
+ metadata = process_artifact_metadata( { "metadata": metadata } )
text_string = metadata["text"] if "text" in metadata else ""
+ duration = metadata['duration'] if "duration" in metadata else 0
- if len( text_string ) < length:
+ if len( text_string ) < min_length or duration < min_duration:
continue
sentences.append( text_string )
+ # tokenize here because our Harvard sentences need to be phonemized anyway
if tokenized:
return [ torch.tensor( tokenize( encode_phns( text ) ) ).to(dtype=torch.uint8) for text in sentences ]
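A minimal usage sketch of the reworked filter above (get_random_prompts and its min_length/min_duration/tokenized parameters come from this diff; the threshold values and the vall_e.data import path are assumptions):

    from vall_e.data import get_random_prompts

    # keep only validation utterances whose transcript has at least 8 characters
    # and whose audio runs at least 6 seconds
    prompts = get_random_prompts( validation=True, min_length=8, min_duration=6 )

    # same filter, but phonemized and tokenized into uint8 tensors
    prompt_tokens = get_random_prompts( validation=True, min_length=8, min_duration=6, tokenized=True )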
diff --git a/vall_e/demo.py b/vall_e/demo.py
index 33910f2..c7cfa00 100644
--- a/vall_e/demo.py
+++ b/vall_e/demo.py
@@ -19,6 +19,7 @@ import argparse
import base64
import random
import logging
+import time
_logger = logging.getLogger(__name__)
@@ -42,6 +43,7 @@ def main():
parser.add_argument("--demo-dir", type=Path, default=None)
parser.add_argument("--skip-existing", action="store_true")
+ parser.add_argument("--dataset-dir-name", type=str, default="dataset")
parser.add_argument("--sample-from-dataset", action="store_true")
parser.add_argument("--skip-loading-dataloader", action="store_true")
parser.add_argument("--dataset-samples", type=int, default=0)
@@ -118,8 +120,8 @@ def main():
"librispeech": args.demo_dir / "librispeech",
}
- if (args.demo_dir / "dataset").exists():
- samples_dirs["dataset"] = args.demo_dir / "dataset"
+ if (args.demo_dir / args.dataset_dir_name).exists():
+ samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name
# pull from dataset samples
if args.sample_from_dataset:
@@ -127,7 +129,7 @@ def main():
cfg.dataset.sample_type = "path" if args.lora else "speaker"
cfg.dataset.tasks_list = [ 'tts' ]
- samples_dirs["dataset"] = args.demo_dir / "dataset"
+ samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name
_logger.info("Loading dataloader...")
dataloader = create_train_dataloader()
@@ -139,7 +141,7 @@ def main():
for i in trange( num, desc="Sampling dataset for samples" ):
batch = dataloader.dataset[i]
- dir = args.demo_dir / "dataset" / f'{i}'
+ dir = args.demo_dir / args.dataset_dir_name / f'{i}'
(dir / "out").mkdir(parents=True, exist_ok=True)
@@ -181,14 +183,19 @@ def main():
extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else ([ out_path_lora ] if args.lora else [])
+ if not args.random_prompts:
+ extra_sources += [ reference ]
+
samples.append((
text,
- [ prompt, out_path ] + extra_sources + [ reference ],
+ [ prompt, out_path ] + extra_sources,
))
if args.skip_existing and out_path.exists():
continue
+ seed = args.seed if args.seed else int(time.time())
+
kwargs = dict(
text=text,
references=[prompt],
@@ -202,17 +209,19 @@ def main():
length_penalty=args.length_penalty,
beam_width=args.beam_width,
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
- seed=args.seed,
+ seed=seed,
tqdm=False,
)
if args.lora:
- tts.enable_lora()
+ tts.enable_lora() # probably redundant now that use_lora is passed below
+ kwargs["use_lora"] = True
try:
tts.inference( out_path=out_path_lora, **kwargs )
except Exception as e:
print(f'Error while processing {out_path}: {e}')
tts.disable_lora()
+ kwargs["use_lora"] = False
try:
tts.inference( out_path=out_path, **kwargs )
except Exception as e:
@@ -233,8 +242,11 @@ def main():
# write audio into template
html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
- if not args.lora:
- html = html.replace("\n\t\t\t\t\t<th>Our VALL-E (No LoRA)</th>", "")
+ if args.lora:
+ if args.random_prompts:
+ html = html.replace("<th>Our VALL-E</th>\n\t\t\t\t\t<th>Ground Truth</th>", "<th>Our VALL-E (No LoRA)</th>\n\t\t\t\t\t<th>Our VALL-E (LoRA)</th>")
+ else:
+ html = html.replace("<th>Our VALL-E</th>", "<th>Our VALL-E (No LoRA)</th>\n\t\t\t\t\t<th>Our VALL-E (LoRA)</th>")
# write demo page
open( args.demo_dir / "index.html", "w", encoding="utf-8" ).write( html )
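For reference, combining the template change in the first hunk with the header rewrite above, the demo table's columns come out as follows (derived from the replace() calls; the per-row audio order matches [ prompt, out_path ] + extra_sources):

    without --lora:                  Text | Prompt | Our VALL-E | Ground Truth
    --lora with --random-prompts:    Text | Prompt | Our VALL-E (No LoRA) | Our VALL-E (LoRA)
    --lora without --random-prompts: Text | Prompt | Our VALL-E (No LoRA) | Our VALL-E (LoRA) | Ground Truth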
diff --git a/vall_e/inference.py b/vall_e/inference.py
index 9740c6a..5dd1459 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -211,6 +211,7 @@ class TTS():
out_path=None,
tqdm=True,
+ use_lora=None,
):
lines = text.split("\n")
@@ -255,6 +256,7 @@ class TTS():
sampling_dry_allowed_length=dry_allowed_length,
disable_tqdm=not tqdm,
+ use_lora=use_lora,
)
else:
raise Exception("!")
@@ -298,6 +300,7 @@ class TTS():
sampling_dry_allowed_length=dry_allowed_length,
disable_tqdm=not tqdm,
+ use_lora=use_lora,
)
resps_list = model_nar(
text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list,
@@ -309,6 +312,7 @@ class TTS():
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
disable_tqdm=not tqdm,
+ use_lora=use_lora,
)
elif model_len is not None:
len_list = model_len( text_list=[phns], proms_list=[prom], max_steps=10, disable_tqdm=not tqdm ) # don't need more than that
@@ -320,6 +324,7 @@ class TTS():
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
disable_tqdm=not tqdm,
+ use_lora=use_lora,
)
else:
raise Exception("!")
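A hedged sketch of how a caller can drive the new pass-through (text, references, out_path and use_lora are the parameters shown in this diff; the bare TTS() construction and the file paths are placeholders):

    from vall_e.inference import TTS

    tts = TTS()  # placeholder construction; real arguments depend on the local config/checkpoint

    kwargs = dict( text="The birch canoe slid on the smooth planks.", references=["./prompt.wav"] )

    # use_lora=None (the default) keeps the old behaviour: each model consults cfg.lora.active_level(...)
    tts.inference( out_path="./out.default.wav", **kwargs )

    # force the LoRA weights on or off, e.g. for the side-by-side comparison in demo.py
    tts.inference( out_path="./out.lora.wav", use_lora=True, **kwargs )
    tts.inference( out_path="./out.no_lora.wav", use_lora=False, **kwargs )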
diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py
index 28aa653..3d59da8 100644
--- a/vall_e/models/ar.py
+++ b/vall_e/models/ar.py
@@ -58,6 +58,7 @@ class AR(Base):
sampling_dry_allowed_length=2,
disable_tqdm=False,
+ use_lora=None,
):
device = text_list[0].device
batch_size = len(text_list)
@@ -156,10 +157,8 @@ class AR(Base):
)
# is AR
- """
if cfg.lora is not None:
- enable_lora( self, cfg.lora.active_level( 0 ) )
- """
+ enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora )
sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
stopped = torch.zeros(batch_size, device=device).bool()
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index b4a6c61..5e08c3f 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -65,6 +65,7 @@ class AR_NAR(Base):
sampling_dry_allowed_length=2,
disable_tqdm=False,
+ use_lora=None,
):
text_task = [ "stt" ]
@@ -204,7 +205,7 @@ class AR_NAR(Base):
break
if cfg.lora is not None:
- enable_lora( self, cfg.lora.active_level( level ) )
+ enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )
quant_levels = [ level for _ in range(batch_size) ] # torch.full((len(text_list),), level)
@@ -243,14 +244,11 @@ class AR_NAR(Base):
prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device=device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
- if cfg.lora is not None:
- enable_lora( self )
-
return prev_list
# is AR
if cfg.lora is not None:
- enable_lora( self, cfg.lora.active_level( 0 ) )
+ enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora )
# STT
start_slice = [ 0 for _ in range(batch_size) ]
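The same tri-state ternary now gates enable_lora() in both models; a standalone illustration of its semantics (this helper is hypothetical, the real code inlines the expression):

    def resolve_lora_level( configured_level, use_lora=None ):
        # None  -> defer to the configured per-quant-level policy (previous behaviour)
        # True  -> force the LoRA weights on for this call
        # False -> force them off
        return configured_level if use_lora is None else use_lora

    assert resolve_lora_level( 1, None ) == 1
    assert resolve_lora_level( 1, use_lora=False ) is False
    assert resolve_lora_level( 0, use_lora=True ) is True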