added rudimentary demo page creator (currently just embeds base64 wavs into the page, need to test not doing that)
parent d53038a9e4
commit d87b492295
data/demo/index.template.html (new file, 30 lines)

@@ -0,0 +1,30 @@
<html>
<head>
	<meta charset="UTF-8">
</head>
<body>
	<h1>VALL-E Demo</h1>
	<p>Below are some samples from my VALL-E implementation: <a href="https://git.ecker.tech/mrq/vall-e/">https://git.ecker.tech/mrq/vall-e/</a>. I do not consider these to be state of the art. Below are samples from LibriSpeech, compared against the samples used in the original VALL-E demo.</p>
	<table>
		<tr>
			<th>Text</th>
			<th>Prompt</th>
			<th>Ground Truth</th>
			<th>Our VALL-E</th>
			<th>Original VALL-E</th>
			<th>YourTTS</th>
		</tr>
		${ENTRIES}
	</table>
	<p>Below are some extra samples.</p>
	<table>
		<tr>
			<th>Text</th>
			<th>Prompt</th>
			<th>Ground Truth</th>
			<th>Our VALL-E</th>
		</tr>
		${SAMPLES}
	</table>
</body>
</html>
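For reference, a minimal sketch of how the two placeholders get consumed; this is effectively what vall_e/demo.py below does (row contents illustrative):

    from pathlib import Path

    template = Path("data/demo/index.template.html").read_text(encoding="utf-8")
    rows = [ '<tr><td>Hello world.</td><td><audio controls="controls"><source src="prompt.wav"/></audio></td></tr>' ]
    # substitute the generated rows into the template, blanking the unused slot
    html = template.replace(r"${ENTRIES}", "\n".join(rows)).replace(r"${SAMPLES}", "")
    Path("data/demo/index.html").write_text(html, encoding="utf-8")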

vall_e/data.py

@@ -575,7 +575,9 @@ class Dataset(_Dataset):
 		self.samplers = { name: PoolSampler( paths, keep_all=True, shuffle=self.sampler_shuffle ) for name, paths in self.paths_by_spkr_name.items() }
 		self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle ) for name, speakers in self.spkrs_by_spkr_group.items() }
 
-		self.load_state_dict()
+		# loading validation state dict causes issues
+		if self.dataset_type != "validation":
+			self.load_state_dict()
 
 	@cached_property
 	def sampler_state_dict_path(self):
@@ -804,12 +806,14 @@ class Dataset(_Dataset):
 
 			lang = metadata["language"] if "language" in metadata else None
 			tone = metadata["tone"] if "tone" in metadata else None
+			text_string = metadata["text"] if "text" in metadata else None
 		else:
 			resps, metadata = _load_quants(path, return_metadata=True)
 			text = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
 
 			lang = metadata["language"] if "language" in metadata else None
 			tone = metadata["tone"] if "tone" in metadata else None
+			text_string = metadata["text"] if "text" in metadata else None
 
 		if not lang:
 			lang = self.get_language(spkr_group)
@@ -1027,6 +1031,7 @@ class Dataset(_Dataset):
 			text=text,
 			proms=proms,
 			resps=resps,
+			text_string=text_string,
 		)
 
 	def head_(self, n):
@@ -1080,6 +1085,31 @@ def create_datasets():
 
 	return train_dataset, val_dataset
 
+def create_train_dataloader():
+	train_dataset = Dataset( training=True )
+	train_dl = _create_dataloader(train_dataset, training=True)
+
+	_logger.info(str(train_dataset.phone_symmap))
+	_logger.info(str(train_dataset.spkr_symmap))
+	_logger.info(str(train_dataset.spkr_group_symmap))
+
+	_logger.info(f"#samples (train): {len(train_dataset)}.")
+	_logger.info(f"#duration (train): {str(train_dataset.duration)}.")
+
+	return train_dl
+
+def create_val_dataloader():
+	val_dataset = Dataset( training=False )
+	val_dl = _create_dataloader(val_dataset, training=False)
+
+	_logger.info(str(val_dataset.phone_symmap))
+	_logger.info(str(val_dataset.spkr_symmap))
+	_logger.info(str(val_dataset.spkr_group_symmap))
+
+	_logger.info(f"#samples (val): {len(val_dataset)}.")
+	_logger.info(f"#duration (val): {str(val_dataset.duration)}.")
+
+	return val_dl
+
 def create_train_val_dataloader():
 	train_dataset, val_dataset = create_datasets()
 
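A hedged usage sketch of the new dataloader helper, as the demo script below consumes it (assumes a training config has already been loaded into vall_e.config.cfg):

    from vall_e.data import create_train_dataloader

    dataloader = create_train_dataloader()   # logs symmaps, sample count, and duration
    batch = dataloader.dataset[0]            # one sample dict from the underlying Dataset
    print(batch["text_string"])              # the raw text now threaded through __getitem__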
vall_e/demo.py (new file, 202 lines)

@@ -0,0 +1,202 @@
"""
|
||||
A helper script to generate a demo page.
|
||||
|
||||
Layout as expected:
|
||||
./data/demo/:
|
||||
{speaker ID}:
|
||||
out:
|
||||
ours.wav (generated)
|
||||
ms_valle.wav
|
||||
yourtts.wav
|
||||
prompt.txt (text to generate)
|
||||
prompt.wav (reference clip to serve as the prompt)
|
||||
reference.wav (ground truth utterance)
|
||||
|
||||
Will also generate samples from a provided datset, if requested.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import random
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from .inference import TTS
|
||||
from .config import cfg
|
||||
from .data import create_train_dataloader, create_val_dataloader
|
||||
from .emb.qnt import decode_to_file
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
def encode(path):
|
||||
return "data:audio/wav;base64," + base64.b64encode(open(path, "rb").read()).decode('utf-8')
|
||||
|
||||
# Would be downright sugoi if I could incorporate this with into __main__
|
||||
def main():
|
||||
parser = argparse.ArgumentParser("VALL-E TTS Demo")
|
||||
|
||||
parser.add_argument("--yaml", type=Path, default=None)
|
||||
|
||||
parser.add_argument("--demo-dir", type=Path, default=None)
|
||||
parser.add_argument("--skip-existing", action="store_true")
|
||||
parser.add_argument("--sample-from-dataset", action="store_true")
|
||||
parser.add_argument("--dataset-samples", type=int, default=0)
|
||||
parser.add_argument("--audio-path-root", type=str, default=None)
|
||||
|
||||
parser.add_argument("--language", type=str, default="en")
|
||||
|
||||
parser.add_argument("--max-ar-steps", type=int, default=12 * cfg.dataset.frames_per_second)
|
||||
parser.add_argument("--max-nar-levels", type=int, default=7)
|
||||
|
||||
parser.add_argument("--ar-temp", type=float, default=1.0)
|
||||
parser.add_argument("--nar-temp", type=float, default=0.0)
|
||||
parser.add_argument("--min-ar-temp", type=float, default=-1.0)
|
||||
parser.add_argument("--min-nar-temp", type=float, default=-1.0)
|
||||
parser.add_argument("--input-prompt-length", type=float, default=3.0)
|
||||
|
||||
parser.add_argument("--top-p", type=float, default=1.0)
|
||||
parser.add_argument("--top-k", type=int, default=16)
|
||||
parser.add_argument("--repetition-penalty", type=float, default=1.0)
|
||||
parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
|
||||
parser.add_argument("--length-penalty", type=float, default=0.0)
|
||||
parser.add_argument("--beam-width", type=int, default=0)
|
||||
|
||||
parser.add_argument("--mirostat-tau", type=float, default=0)
|
||||
parser.add_argument("--mirostat-eta", type=float, default=0)
|
||||
|
||||
parser.add_argument("--seed", type=int, default=None)
|
||||
|
||||
parser.add_argument("--device", type=str, default=None)
|
||||
parser.add_argument("--amp", action="store_true")
|
||||
parser.add_argument("--dtype", type=str, default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp )
|
||||
|
||||
if not args.demo_dir:
|
||||
args.demo_dir = Path("./data/demo/")
|
||||
|
||||
entries = []
|
||||
|
||||
# pull from provided samples
|
||||
sample_dir = args.demo_dir / "librispeech"
|
||||
if sample_dir.exists():
|
||||
speakers = [ dir for dir in sample_dir.iterdir() if dir.is_dir() ]
|
||||
sources = ["ms_valle", "yourtts"]
|
||||
|
||||
# generate demo output
|
||||
for dir in tqdm(speakers, desc=f"Generating demo for speaker"):
|
||||
text = open(dir / "prompt.txt").read()
|
||||
prompt = dir / "prompt.wav"
|
||||
out_path = dir / "out" / "ours.wav"
|
||||
|
||||
entries.append((
|
||||
text,
|
||||
[ prompt, dir / "reference.wav", out_path ] + [ dir / "out" / f"{source}.wav" for source in sources ]
|
||||
))
|
||||
|
||||
if args.skip_existing and out_path.exists():
|
||||
continue
|
||||
|
||||
tts.inference(
|
||||
text=text,
|
||||
references=[prompt],
|
||||
language=args.language,
|
||||
out_path=out_path,
|
||||
input_prompt_length=args.input_prompt_length,
|
||||
max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
|
||||
ar_temp=args.ar_temp, nar_temp=args.nar_temp,
|
||||
min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
|
||||
top_p=args.top_p, top_k=args.top_k,
|
||||
repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
|
||||
length_penalty=args.length_penalty,
|
||||
beam_width=args.beam_width,
|
||||
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
|
||||
seed=args.seed,
|
||||
tqdm=False,
|
||||
)
|
||||
|
||||
entries = [
|
||||
f'<tr><td>{text}</td>'+
|
||||
"".join( [
|
||||
f'<td><audio controls="controls" autobuffer="autobuffer"><source src="{args.audio_path_root + audio if args.audio_path_root else encode(audio)}"/></audio></td>'
|
||||
for audio in audios
|
||||
] )+
|
||||
'</tr>'
|
||||
for text, audios in entries
|
||||
]
|
||||
|
||||
# read html template
|
||||
html = open(args.demo_dir / "index.template.html", "r", encoding="utf-8").read()
|
||||
# create html table, in one messy line
|
||||
# replace in our template
|
||||
html = html.replace(r"${ENTRIES}", "\n".join(entries) )
|
||||
|
||||
samples = []
|
||||
|
||||
# pull from dataset samples
|
||||
if args.sample_from_dataset:
|
||||
print("Loading dataloader...")
|
||||
dataloader = create_train_dataloader()
|
||||
print("Loaded dataloader.")
|
||||
|
||||
num = args.dataset_samples if args.dataset_samples else cfg.evaluation.size
|
||||
|
||||
length = len( dataloader.dataset )
|
||||
for i in range( num ):
|
||||
idx = random.randint( 0, length )
|
||||
batch = dataloader.dataset[idx]
|
||||
|
||||
dir = args.demo_dir / "samples" / f'{i}'
|
||||
|
||||
(dir / "out").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
text = batch["text_string"]
|
||||
|
||||
prompt = dir / "prompt.wav"
|
||||
reference = dir / "reference.wav"
|
||||
out_path = dir / "out" / "ours.wav"
|
||||
|
||||
decode_to_file( batch["proms"].to("cuda"), prompt, device="cuda" )
|
||||
decode_to_file( batch["resps"].to("cuda"), reference, device="cuda" )
|
||||
|
||||
samples.append((
|
||||
text,
|
||||
[ prompt, reference, out_path ]
|
||||
))
|
||||
|
||||
tts.inference(
|
||||
text=text,
|
||||
references=[prompt],
|
||||
language=args.language,
|
||||
out_path=out_path,
|
||||
input_prompt_length=args.input_prompt_length,
|
||||
max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
|
||||
ar_temp=args.ar_temp, nar_temp=args.nar_temp,
|
||||
min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
|
||||
top_p=args.top_p, top_k=args.top_k,
|
||||
repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
|
||||
length_penalty=args.length_penalty,
|
||||
beam_width=args.beam_width,
|
||||
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
|
||||
seed=args.seed,
|
||||
tqdm=False,
|
||||
)
|
||||
|
||||
samples = [
|
||||
f'<tr><td>{text}</td>'+
|
||||
"".join( [
|
||||
f'<td><audio controls="controls" autobuffer="autobuffer"><source src="{args.audio_path_root + audio if args.audio_path_root else encode(audio)}"/></audio></td>'
|
||||
for audio in audios
|
||||
] )+
|
||||
'</tr>'
|
||||
for text, audios in samples
|
||||
]
|
||||
|
||||
html = html.replace(r"${SAMPLES}", "\n".join(samples) )
|
||||
|
||||
open( args.demo_dir / "index.html", "w", encoding="utf-8" ).write( html )
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
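On the commit note about not embedding base64 wavs: passing --audio-path-root makes the page reference the files by URL instead of inlining them. A rough sketch of the two modes (paths and root illustrative):

    audio = Path("data/demo/librispeech/1089/prompt.wav")
    # embedded: the page is fully self-contained, but every wav is inlined as a data URI
    src = encode(audio)                        # "data:audio/wav;base64,UklGR..."
    # linked: the page stays small, but the wavs must be served under the given root
    src = "https://example.com/" + str(audio)  # with --audio-path-root https://example.com/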
@@ -352,6 +352,10 @@ class Engines(dict[str, Engine]):
 				"userdata": userdata,
 				"config": config
 			}
+
+			if lora is None:
+				del state_dict['lora']
+
 			if callback:
 				state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path )
 
@@ -81,12 +81,32 @@ def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
 
 	return state_dict
 
+def split_classifier_heads( state_dict, config = cfg.model, save_path = None, dtype = None):
+	levels = config.max_levels
+
+	if "classifier.weight" not in state_dict['module']:
+		return state_dict
+
+	# copy to new AudioClassifier
+	for i in range(levels):
+		tokens = 1025 if i == 0 else 1024
+
+		# trim per RVQ level (since level 0 has a stop token)
+		state_dict['module'][f'classifiers.proj.{i}.weight'] = state_dict['module']['classifier.weight'][:tokens, :]
+		state_dict['module'][f'classifiers.proj.{i}.bias'] = state_dict['module']['classifier.bias'][:tokens]
+
+	# delete old weights
+	del state_dict['module']['classifier.weight']
+	del state_dict['module']['classifier.bias']
+
+	return state_dict
+
 def main():
 	parser = argparse.ArgumentParser("Save trained model to path.")
 	parser.add_argument("--module-only", action='store_true')
 	parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style
 	parser.add_argument("--lora", action='store_true', default=None) # exports LoRA
+	parser.add_argument("--split-classifiers", action='store_true', default=None) # splits classifier heads
 	parser.add_argument("--dtype", type=str, default="auto") # set target dtype to export to
 	args, unknown = parser.parse_known_args()
 

@@ -98,6 +118,8 @@ def main():
 		callback = convert_to_hf
 	elif args.lora:
 		callback = extract_lora
+	elif args.split_classifiers:
+		callback = split_classifier_heads
 
 	if args.hf and args.lora:
 		raise Exception("Requesting more than one callback")
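For intuition, splitting the shared classifier is just row-slicing one projection into per-level heads, where level 0 keeps one extra row for its stop token. A self-contained sketch with hypothetical shapes:

    import torch

    d_model = 1024                              # hypothetical model width
    shared_weight = torch.randn(1025, d_model)  # one head sized for the largest vocab
    heads = [ shared_weight[:(1025 if i == 0 else 1024), :] for i in range(8) ]
    assert heads[0].shape == (1025, d_model)    # level 0: 1024 codes + stop token
    assert heads[1].shape == (1024, d_model)    # other levels: codes only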

vall_e/inference.py

@@ -140,7 +140,9 @@ class TTS():
 
 		seed = None,
 
-		out_path=None
+		out_path=None,
+
+		tqdm=True,
 	):
 		lines = text.split("\n")
 

@@ -194,6 +196,8 @@ class TTS():
 				sampling_beam_width=beam_width,
 				sampling_mirostat_tau=mirostat_tau,
 				sampling_mirostat_eta=mirostat_eta,
+
+				disable_tqdm=not tqdm,
 			)
 			resps_list = model_nar(
 				text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list,

@@ -202,15 +206,19 @@ class TTS():
 				sampling_min_temperature=min_nar_temp,
 				sampling_top_p=top_p, sampling_top_k=top_k,
 				sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
+
+				disable_tqdm=not tqdm,
 			)
 		elif model_len is not None:
-			len_list = model_len( text_list=[phns], proms_list=[prom], max_steps=10 ) # don't need more than that
+			len_list = model_len( text_list=[phns], proms_list=[prom], max_steps=10, disable_tqdm=not tqdm ) # don't need more than that
 			resps_list = model_nar( text_list=[phns], proms_list=[prom], len_list=len_list,
 				max_levels=max_nar_levels,
 				sampling_temperature=nar_temp,
 				sampling_min_temperature=min_nar_temp,
 				sampling_top_p=top_p, sampling_top_k=top_k,
 				sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
+
+				disable_tqdm=not tqdm,
 			)
 		else:
 			raise Exception("!")
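A hedged usage sketch of the new toggle on TTS.inference, mirroring how demo.py calls it (config path illustrative; other sampling kwargs assumed to keep their defaults):

    from pathlib import Path
    from vall_e.inference import TTS

    tts = TTS( config=Path("./training/config.yaml") )
    tts.inference(
        text="Hello world.",
        references=[Path("./prompt.wav")],
        out_path=Path("./out.wav"),
        tqdm=False,  # passed through as disable_tqdm=True to the AR/NAR samplers
    )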

@@ -114,6 +114,8 @@ class AR_NAR(Base):
 		sampling_beam_width: int = 0,
 		sampling_mirostat_tau: float = 0.0,
 		sampling_mirostat_eta: float = 0.1,
+
+		disable_tqdm=False,
 	):
 		device = text_list[0].device
 		batch_size = len(text_list)

@@ -206,7 +208,7 @@ class AR_NAR(Base):
 
 		prev_list = resps_list
 
-		for n in trange( max_levels, desc="NAR" ):
+		for n in trange( max_levels, desc="NAR", disable=disable_tqdm ):
 			level = prev_list[0].shape[-1]
 			if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
 				break

@@ -271,7 +273,7 @@ class AR_NAR(Base):
 			scores = [ 1.0 ] * sampling_beam_width
 
 		# get next in sequence
-		for n in trange(max_steps // max(1, self.causal_size), desc="AR"):
+		for n in trange(max_steps // max(1, self.causal_size), desc="AR", disable=disable_tqdm):
 			resps_list = self._unsqueeze_list(sequence_list)
 
 			inputs = self.inputs(
@@ -111,6 +111,8 @@ class NAR(Base):
 		sampling_beam_width: int = 0,
 		sampling_mirostat_tau: float = 0.0,
 		sampling_mirostat_eta: float = 0.1,
+
+		disable_tqdm=False,
 	):
 		device = text_list[0].device
 		batch_size = len(text_list)

@@ -188,7 +190,7 @@ class NAR(Base):
 			prev_list = [ torch.Tensor([ self.stop_token for _ in range(resp_len) ]).to(device=device, dtype=torch.int16) for resp_len in len_list ]
 
 			start = True
-			for n in trange( max_levels, desc="NAR" ):
+			for n in trange( max_levels, desc="NAR", disable=disable_tqdm ):
 				level = 0 if n == 0 else prev_list[0].shape[-1]
 				if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
 					break

@@ -243,7 +245,7 @@ class NAR(Base):
 			stop_token = 10
 			task_list = [ "len" for _ in range(batch_size) ]
 
-			for n in trange(10, desc="AR"):
+			for n in trange(10, desc="AR", disable=disable_tqdm):
 				len_list = sequence_list
 
 				inputs = self.inputs(
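For context, the plumbing above leans entirely on tqdm's stock disable flag, which iterates normally but draws no bar; a trivial demonstration:

    from tqdm import trange

    # loops ten times as usual, but renders no progress bar
    for n in trange(10, desc="AR", disable=True):
        pass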