2024-07-20 01:49:40 +00:00
|
|
|
"""
|
|
|
|
A helper script to generate a demo page.
|
|
|
|
|
|
|
|
Layout as expected:
|
|
|
|
./data/demo/:
|
|
|
|
{speaker ID}:
|
|
|
|
out:
|
|
|
|
ours.wav (generated)
|
|
|
|
ms_valle.wav
|
|
|
|
yourtts.wav
|
|
|
|
prompt.txt (text to generate)
|
|
|
|
prompt.wav (reference clip to serve as the prompt)
|
|
|
|
reference.wav (ground truth utterance)
|
|
|
|
|
|
|
|
Will also generate samples from a provided dataset, if requested.
|
|
|
|
"""
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
import base64
|
|
|
|
import random
|
2024-08-29 18:27:16 +00:00
|
|
|
import logging
|
2024-10-11 00:04:12 +00:00
|
|
|
import time
|
2024-08-29 18:27:16 +00:00
|
|
|
|
|
|
|
# Module-level logger, named after this module so its records can be filtered per module.
_logger = logging.getLogger(__name__)
|
2024-07-20 01:49:40 +00:00
|
|
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
from .inference import TTS
|
|
|
|
from .config import cfg
|
2024-10-10 18:40:25 +00:00
|
|
|
from .data import create_train_dataloader, create_val_dataloader, get_random_prompt
|
2024-07-20 01:49:40 +00:00
|
|
|
from .emb.qnt import decode_to_file
|
|
|
|
|
2024-07-22 00:12:03 +00:00
|
|
|
from tqdm import tqdm, trange
|
2024-07-20 01:49:40 +00:00
|
|
|
|
|
|
|
def encode(path):
    """Return the file at *path* as a base64 ``data:audio/wav`` URI.

    Args:
        path: Location (string or path-like) of the audio file to embed.

    Returns:
        The string ``"data:audio/wav;base64,"`` followed by the base64
        encoding of the file's raw bytes.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    # A context manager closes the handle deterministically; the demo page can
    # embed many clips, so leaked handles would otherwise pile up.
    with open(path, "rb") as f:
        payload = base64.b64encode(f.read()).decode('utf-8')
    return "data:audio/wav;base64," + payload
|
|
|
|
|
|
|
|
# Would be downright sugoi if I could incorporate this into __main__
|
|
|
|
def _build_parser():
    """Build the command-line parser for the demo generator."""
    parser = argparse.ArgumentParser("VALL-E TTS Demo")

    parser.add_argument("--yaml", type=Path, default=None)

    parser.add_argument("--demo-dir", type=Path, default=None)
    parser.add_argument("--skip-existing", action="store_true")
    parser.add_argument("--dataset-dir-name", type=str, default="dataset")
    parser.add_argument("--sample-from-dataset", action="store_true")
    parser.add_argument("--skip-loading-dataloader", action="store_true")
    parser.add_argument("--dataset-samples", type=int, default=0)
    parser.add_argument("--audio-path-root", type=str, default=None)
    parser.add_argument("--preamble", type=str, default=None)
    parser.add_argument("--output-filename", type=str, default="index.html")

    parser.add_argument("--language", type=str, default="en")

    parser.add_argument("--max-ar-steps", type=int, default=12 * cfg.dataset.frames_per_second)
    parser.add_argument("--max-nar-levels", type=int, default=7)

    parser.add_argument("--ar-temp", type=float, default=1.0)
    parser.add_argument("--nar-temp", type=float, default=0.0)
    parser.add_argument("--min-ar-temp", type=float, default=-1.0)
    parser.add_argument("--min-nar-temp", type=float, default=-1.0)
    parser.add_argument("--input-prompt-length", type=float, default=0.0)

    parser.add_argument("--top-p", type=float, default=1.0)
    parser.add_argument("--top-k", type=int, default=0)
    parser.add_argument("--min-p", type=float, default=0.0)
    parser.add_argument("--repetition-penalty", type=float, default=1.0)
    parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
    parser.add_argument("--length-penalty", type=float, default=0.0)
    parser.add_argument("--beam-width", type=int, default=0)

    parser.add_argument("--mirostat-tau", type=float, default=0)
    parser.add_argument("--mirostat-eta", type=float, default=0)

    parser.add_argument("--dry-multiplier", type=float, default=0)
    parser.add_argument("--dry-base", type=float, default=1.75)
    parser.add_argument("--dry-allowed-length", type=int, default=2)

    parser.add_argument("--entropix-sampling", action="store_true")

    parser.add_argument("--seed", type=int, default=None)

    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--amp", action="store_true")
    parser.add_argument("--dtype", type=str, default=None)

    parser.add_argument("--random-prompts", action="store_true")
    parser.add_argument("--lora", action="store_true")
    parser.add_argument("--comparison", action="store_true")

    return parser


def _build_comparison_kwargs(args):
    """Return the settings for the optional second "Our VALL-E" column.

    The "before" overrides are applied when generating the extra comparison
    file (``ours_{suffix}.wav``); the "after" overrides are applied when
    generating ``ours.wav``. ``titles`` label the ``ours.wav`` column and the
    comparison column, in that order.
    """
    comparison_kwargs = {
        "enabled": False,
        "titles": [],
        "suffix": "_after",
        "before": {},
        "after": {},
    }

    if args.lora:
        comparison_kwargs["enabled"] = True
        comparison_kwargs["suffix"] = "_lora"
        comparison_kwargs["titles"] = ["No LoRA", "LoRA"]
        comparison_kwargs["before"]["use_lora"] = True
        comparison_kwargs["after"]["use_lora"] = False
    # to-do: make this user definable
    elif args.comparison:
        comparison_kwargs["enabled"] = True
        comparison_kwargs["suffix"] = "_entropix"
        comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]

        comparison_kwargs["before"]["entropix_sampling"] = True
        comparison_kwargs["before"]["ar_temp"] = 0.666
        comparison_kwargs["before"]["top_k"] = 27
        comparison_kwargs["before"]["top_p"] = 0.9

        comparison_kwargs["after"]["entropix_sampling"] = False
        comparison_kwargs["after"]["ar_temp"] = args.ar_temp
        comparison_kwargs["after"]["top_k"] = args.top_k
        comparison_kwargs["after"]["top_p"] = args.top_p

    return comparison_kwargs


def _sample_dataset(args):
    """Write prompt/reference clips and prompt text for dataset samples into the demo directory."""
    cfg.dataset.cache = False
    cfg.dataset.sample_type = "path" if len(cfg.dataset.training) < cfg.evaluation.batch_size else "speaker"
    cfg.dataset.tasks_list = ['tts']

    _logger.info("Loading dataloader...")
    dataloader = create_train_dataloader()
    _logger.info("Loaded dataloader.")

    length = min(len(dataloader.dataset), cfg.evaluation.batch_size)
    num = args.dataset_samples if args.dataset_samples else length

    # Honour --device when given; fall back to the previously hard-coded "cuda".
    device = args.device or "cuda"

    for i in trange(num, desc="Sampling dataset for samples"):
        batch = dataloader.dataset[i]

        sample_dir = args.demo_dir / args.dataset_dir_name / f'{i}'
        (sample_dir / "out").mkdir(parents=True, exist_ok=True)

        metadata = batch["metadata"]

        text = get_random_prompt() if args.random_prompts else metadata["text"]
        language = metadata["language"].lower()

        prompt = sample_dir / "prompt.wav"
        reference = sample_dir / "reference.wav"
        out_path = sample_dir / "out" / "ours.wav"

        if args.skip_existing and out_path.exists():
            continue

        (sample_dir / "prompt.txt").write_text(text, encoding="utf-8")
        (sample_dir / "language.txt").write_text(language, encoding="utf-8")

        decode_to_file(batch["proms"].to(device), prompt, device=device)
        decode_to_file(batch["resps"].to(device), reference, device=device)


def _safe_inference(tts, out_path, kwargs):
    """Run one synthesis into *out_path*, reporting (not raising) any failure."""
    # Deliberately best-effort: one failing sample should not abort the page.
    try:
        tts.inference(out_path=out_path, **kwargs)
    except Exception as e:
        print(f'Error while processing {out_path}: {e}')


def _render_sample_rows(samples, demo_dir, audio_path_root):
    """Render each ``(text, audio paths)`` sample as one HTML table row."""
    rows = []
    for text, audios in samples:
        cells = []
        for audio in audios:
            if audio_path_root:
                # Link to the clip, swapping the local demo dir for the given root.
                src = str(audio).replace(str(demo_dir), audio_path_root)
            else:
                # No root given: embed the clip inline as a data URI.
                src = encode(audio)
            cells.append(
                f'\n\t\t\t\t<td><audio controls="controls" preload="none"><source src="{src}"/></audio></td>'
            )
        # NOTE(review): `text` is inserted without HTML escaping — confirm prompts never contain markup.
        rows.append(f'\n\t\t\t<tr>\n\t\t\t\t<td>{text}</td>' + "".join(cells) + '\n\t\t\t</tr>')
    return rows


def main():
    """Generate the demo HTML page from the template and sample directories.

    Parses the command line, constructs the ``TTS`` instance, optionally
    dumps samples from the training dataset into the demo directory, runs
    inference for every speaker directory that is found, and writes the
    filled-in template to ``{demo_dir}/{output_filename}``.

    Template placeholders replaced: ``${PREAMBLE}``, ``${SETTINGS}`` and one
    ``${<NAME>_SAMPLES}`` per sample directory key (``librispeech``,
    ``dataset``).
    """
    args = _build_parser().parse_args()

    tts = TTS(config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp)

    if not args.demo_dir:
        args.demo_dir = Path("./data/demo/")

    if not args.preamble:
        args.preamble = "<br>".join([
            'Below are some samples from my VALL-E implementation: <a href="https://git.ecker.tech/mrq/vall-e/">https://git.ecker.tech/mrq/vall-e/</a>.',
            'Unlike the original VALL-E demo page, I\'m placing emphasis on the input prompt, as the model adheres to it stronger than others.',
        ])

    # comparison kwargs
    comparison_kwargs = _build_comparison_kwargs(args)

    # read html template
    html = (args.demo_dir / "index.template.html").read_text(encoding="utf-8")

    # replace values in our template
    html = html.replace(r"${PREAMBLE}", args.preamble)
    html = html.replace(r"${SETTINGS}", str(dict(
        input_prompt_length=args.input_prompt_length,
        max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
        ar_temp=args.ar_temp, nar_temp=args.nar_temp,
        min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
        top_p=args.top_p, top_k=args.top_k, min_p=args.min_p,
        repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
        length_penalty=args.length_penalty,
        beam_width=args.beam_width,
        mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
        dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
        entropix_sampling=args.entropix_sampling,
    )))

    # pull from provided samples
    samples_dirs = {
        "librispeech": args.demo_dir / "librispeech",
    }

    if (args.demo_dir / args.dataset_dir_name).exists():
        samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name

    # pull from dataset samples
    if args.sample_from_dataset:
        samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name
        _sample_dataset(args)

    for k, sample_dir in samples_dirs.items():
        if not sample_dir.exists():
            continue

        speakers = [entry for entry in sample_dir.iterdir() if entry.is_dir()]
        sources = ["ms_valle", "yourtts"]

        samples = []

        # generate demo output
        for speaker_dir in tqdm(speakers, desc=f"Generating demo for {k}"):
            # These files are written as UTF-8 when sampled from the dataset,
            # so read them back as UTF-8 instead of the locale default.
            text = (speaker_dir / "prompt.txt").read_text(encoding="utf-8")
            language_file = speaker_dir / "language.txt"
            language = language_file.read_text(encoding="utf-8") if language_file.exists() else "en"

            prompt = speaker_dir / "prompt.wav"
            reference = speaker_dir / "reference.wav"
            out_path = speaker_dir / "out" / "ours.wav"
            # Single quotes inside the f-string keep this valid before Python 3.12.
            # The suffix already starts with "_"; the resulting name is kept
            # unchanged so previously generated files are still found.
            out_path_comparison = speaker_dir / "out" / f"ours_{comparison_kwargs['suffix']}.wav"

            if k == "librispeech":
                extra_sources = [speaker_dir / "out" / f"{source}.wav" for source in sources]
            elif comparison_kwargs["enabled"]:
                extra_sources = [out_path_comparison]
            else:
                extra_sources = []

            if not args.random_prompts or k == "librispeech":
                extra_sources += [reference]

            # Recorded before the skip check so existing outputs still appear on the page.
            samples.append((
                text,
                [prompt, out_path] + extra_sources,
            ))

            if args.skip_existing and out_path.exists():
                continue

            # `is not None` so that an explicit `--seed 0` is respected.
            seed = args.seed if args.seed is not None else int(time.time())

            # NOTE(review): --min-p is reported in ${SETTINGS} but is not forwarded
            # here — confirm whether TTS.inference accepts `min_p` before adding it.
            kwargs = dict(
                text=text,
                references=[prompt],
                language=language,
                input_prompt_length=args.input_prompt_length,
                max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
                ar_temp=args.ar_temp, nar_temp=args.nar_temp,
                min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
                top_p=args.top_p, top_k=args.top_k,
                repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
                length_penalty=args.length_penalty,
                beam_width=args.beam_width,
                mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
                dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
                entropix_sampling=args.entropix_sampling,
                seed=seed,
                tqdm=False,
            )

            if comparison_kwargs["enabled"]:
                kwargs.update(comparison_kwargs["before"])
                _safe_inference(tts, out_path_comparison, kwargs)
                kwargs.update(comparison_kwargs["after"])

            _safe_inference(tts, out_path, kwargs)

        # collate entries into HTML
        rows = _render_sample_rows(samples, args.demo_dir, args.audio_path_root)

        # write audio into template
        html = html.replace("${" + k.upper() + "_SAMPLES}", "\n".join(rows))

        if comparison_kwargs["enabled"]:
            before, after = comparison_kwargs["titles"]
            # NOTE(review): str.replace rewrites every matching header in the
            # template, including the librispeech table, whose rows carry no
            # comparison clip — confirm the template accounts for this.
            if args.random_prompts:
                html = html.replace("<th>Our VALL-E</th>\n\t\t\t\t\t<th>Ground Truth</th>", f"<th>Our VALL-E ({before})</th>\n\t\t\t\t\t<th>Our VALL-E ({after})</th>")
            else:
                html = html.replace("<th>Our VALL-E</th>", f"<th>Our VALL-E ({before})</th>\n\t\t\t\t\t<th>Our VALL-E ({after})</th>")

    # write demo page
    (args.demo_dir / args.output_filename).write_text(html, encoding="utf-8")
|
2024-07-20 01:49:40 +00:00
|
|
|
|
|
|
|
# Script entry point: only generate the demo page when this module is run
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|