cleaned up demo page creation, added option to pass in RVQ level sampling distribution for training

This commit is contained in:
mrq 2024-07-21 19:12:03 -05:00
parent ba7ee8c0ee
commit e19aa643a6
6 changed files with 117 additions and 111 deletions

View File

@ -1,30 +1,27 @@
<html>
<head>
<meta charset="UTF-8">
</head>
<body>
<h1>VALL-E Demo</h1>
<p>Below are some samples from my VALL-E implementation: <a href="https://git.ecker.tech/mrq/vall-e/">https://git.ecker.tech/mrq/vall-e/</a>. I do not consider these to be state of the art. Below are samples from LibriSpeech, comparing against the samples the original VALL-E demo sampled.</p>
<table>
<tr>
<th>Text</th>
<th>Prompt</th>
<th>Ground Truth</th>
<th>Our VALL-E</th>
<th>Original VALL-E</th>
<th>YourTTS</th>
</tr>
${ENTRIES}
</table>
<p>Below are some extra samples.</p>
<table>
<tr>
<th>Text</th>
<th>Prompt</th>
<th>Ground Truth</th>
<th>Our VALL-E</th>
</tr>
${SAMPLES}
</table>
</body>
<head>
<meta charset="UTF-8">
</head>
<body>
<h1>VALL-E Demo</h1>
<p>${PREAMBLE}</p>
<table>
<tr>
<th>Text</th>
<th>Prompt</th>
<th>Ground Truth</th>
<th>Our VALL-E</th>
<th>Original VALL-E</th>
<th>YourTTS</th>
</tr>${LIBRISPEECH_SAMPLES}
</table>
<table>
<tr>
<th>Text</th>
<th>Prompt</th>
<th>Ground Truth</th>
<th>Our VALL-E</th>
</tr>${DATASET_SAMPLES}
</table>
</body>
</html>

View File

@ -209,7 +209,7 @@ class ModelExperimentalSettings:
audio_embedding_sums: bool = False # whether each pass uses the previous RVQ codes or only the current level
audio_embedding_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings
kv_heads: int = 0 # MHA or GQA (for supported backends)
p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
p_rvq_levels: str | list = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range for LoRAs, isn't necesary
unified_position_ids: bool = True # False will generate position IDs partitioned for each section
tie_classifier_to_embedding: bool = False # Ties the classifier output to their respective embeddings, this does not seem to do anything good in testing

View File

@ -486,7 +486,7 @@ class Dataset(_Dataset):
duration = self.duration_map[path]
self.duration += duration
# only calc duration if we're tot going to order by duration
# only calc duration if we're going to order by duration
if self.sampler_order != "duration":
continue
@ -845,11 +845,6 @@ class Dataset(_Dataset):
# might be better to decode => concat waveforms with silence in between => reencode
# as you technically can't just append encodec sequences together like this without issues
resps = concat_audio( resps, qnt, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device )
"""
resps = resps[:, :cfg.model.resp_levels]
proms = proms[:, :cfg.model.resp_levels]
"""
task = random.choice(self.tasks)
@ -1031,7 +1026,8 @@ class Dataset(_Dataset):
text=text,
proms=proms,
resps=resps,
text_string=text_string,
metadata=metadata,
)
def head_(self, n):

View File

@ -26,7 +26,7 @@ from .config import cfg
from .data import create_train_dataloader, create_val_dataloader
from .emb.qnt import decode_to_file
from tqdm import tqdm
from tqdm import tqdm, trange
def encode(path):
return "data:audio/wav;base64," + base64.b64encode(open(path, "rb").read()).decode('utf-8')
@ -42,6 +42,7 @@ def main():
parser.add_argument("--sample-from-dataset", action="store_true")
parser.add_argument("--dataset-samples", type=int, default=0)
parser.add_argument("--audio-path-root", type=str, default=None)
parser.add_argument("--preamble", type=str, default=None)
parser.add_argument("--language", type=str, default="en")
@ -77,66 +78,28 @@ def main():
if not args.demo_dir:
args.demo_dir = Path("./data/demo/")
entries = []
# pull from provided samples
sample_dir = args.demo_dir / "librispeech"
if sample_dir.exists():
speakers = [ dir for dir in sample_dir.iterdir() if dir.is_dir() ]
sources = ["ms_valle", "yourtts"]
# generate demo output
for dir in tqdm(speakers, desc=f"Generating demo for speaker"):
text = open(dir / "prompt.txt").read()
prompt = dir / "prompt.wav"
out_path = dir / "out" / "ours.wav"
entries.append((
text,
[ prompt, dir / "reference.wav", out_path ] + [ dir / "out" / f"{source}.wav" for source in sources ]
))
if args.skip_existing and out_path.exists():
continue
tts.inference(
text=text,
references=[prompt],
language=args.language,
out_path=out_path,
input_prompt_length=args.input_prompt_length,
max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
ar_temp=args.ar_temp, nar_temp=args.nar_temp,
min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
top_p=args.top_p, top_k=args.top_k,
repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
length_penalty=args.length_penalty,
beam_width=args.beam_width,
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
seed=args.seed,
tqdm=False,
)
entries = [
f'<tr><td>{text}</td>'+
"".join( [
f'<td><audio controls="controls" preload="none"><source src="{str(audio).replace(str(args.demo_dir), args.audio_path_root) if args.audio_path_root else encode(audio)}"/></audio></td>'
for audio in audios
] )+
'</tr>'
for text, audios in entries
]
if not args.preamble:
args.preamble = "<br>".join([
'Below are some samples from my VALL-E implementation: <a href="https://git.ecker.tech/mrq/vall-e/">https://git.ecker.tech/mrq/vall-e/</a>.',
'LibriSpeech, comparing against the samples the original VALL-E demo sampled.',
'I do not consider these to be state of the art, as the model does not follow close to the prompt as I would like for general speakers.',
])
# read html template
html = open(args.demo_dir / "index.template.html", "r", encoding="utf-8").read()
# create html table, in one messy line
# replace in our template
html = html.replace(r"${ENTRIES}", "\n".join(entries) )
samples = []
# replace values in our template
html = html.replace(r"${PREAMBLE}", args.preamble )
# pull from provided samples
samples_dirs = {
"librispeech": args.demo_dir / "librispeech",
}
# pull from dataset samples
if args.sample_from_dataset:
samples_dirs["dataset"] = args.demo_dir / "dataset"
print("Loading dataloader...")
dataloader = create_train_dataloader()
print("Loaded dataloader.")
@ -144,35 +107,62 @@ def main():
num = args.dataset_samples if args.dataset_samples else cfg.evaluation.size
length = len( dataloader.dataset )
for i in range( num ):
for i in trange( num, desc="Sampling dataset for samples" ):
idx = random.randint( 0, length )
batch = dataloader.dataset[idx]
dir = args.demo_dir / "samples" / f'{i}'
dir = args.demo_dir / "dataset" / f'{i}'
(dir / "out").mkdir(parents=True, exist_ok=True)
text = batch["text_string"]
metadata = batch["metadata"]
text = metadata["text"]
language = metadata["language"]
prompt = dir / "prompt.wav"
reference = dir / "reference.wav"
out_path = dir / "out" / "ours.wav"
if args.skip_existing and out_path.exists():
continue
open( dir / "prompt.txt", "w", encoding="utf-8" ).write( text )
open( dir / "language.txt", "w", encoding="utf-8" ).write( language )
decode_to_file( batch["proms"].to("cuda"), prompt, device="cuda" )
decode_to_file( batch["resps"].to("cuda"), reference, device="cuda" )
for k, sample_dir in samples_dirs.items():
if not sample_dir.exists():
continue
speakers = [ dir for dir in sample_dir.iterdir() if dir.is_dir() ]
sources = [ "ms_valle", "yourtts" ]
samples = []
# generate demo output
for dir in tqdm(speakers, desc=f"Generating demo for speaker"):
text = open(dir / "prompt.txt").read()
language = open(dir / "language.txt").read() if (dir / "language.txt").exists() else "en"
prompt = dir / "prompt.wav"
out_path = dir / "out" / "ours.wav"
extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else []
samples.append((
text,
[ prompt, reference, out_path ]
[ prompt, dir / "reference.wav", out_path ] + extra_sources
))
if args.skip_existing and out_path.exists():
continue
decode_to_file( batch["proms"].to("cuda"), prompt, device="cuda" )
decode_to_file( batch["resps"].to("cuda"), reference, device="cuda" )
tts.inference(
text=text,
references=[prompt],
language=args.language,
language=language,
out_path=out_path,
input_prompt_length=args.input_prompt_length,
max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
@ -187,18 +177,21 @@ def main():
tqdm=False,
)
# collate entries into HTML
samples = [
f'<tr><td>{text}</td>'+
f'\n\t\t\t<tr>\n\t\t\t\t<td>{text}</td>'+
"".join( [
f'<td><audio controls="controls" preload="none"><source src="{str(audio).replace(str(args.demo_dir), args.audio_path_root) if args.audio_path_root else encode(audio)}"/></audio></td>'
f'\n\t\t\t\t<td><audio controls="controls" preload="none"><source src="{str(audio).replace(str(args.demo_dir), args.audio_path_root) if args.audio_path_root else encode(audio)}"/></audio></td>'
for audio in audios
] )+
'</tr>'
'\n\t\t\t</tr>'
for text, audios in samples
]
html = html.replace(r"${SAMPLES}", "\n".join(samples) )
# write audio into template
html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
# write demo page
open( args.demo_dir / "index.html", "w", encoding="utf-8" ).write( html )
if __name__ == "__main__":

View File

@ -144,6 +144,7 @@ class AR_NAR(Base):
quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
else: # if p_rvq_levels == "auto":
# makes higher levels less likely
"""
def generate( lo=0, hi=8 ):
index = lo
p = random.random()
@ -151,8 +152,17 @@ class AR_NAR(Base):
if p < 1.0 / (2 ** i):
index = i
return int(index)
"""
quant_levels = [ generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
# allow passing a specific distribution of RVQ levels
pool = p_rvq_levels if isinstance(p_rvq_levels, list) else []
if not pool:
lo, hi = quant_level_range[0], quant_level_range[1]
for i in range( lo, hi ):
rep = hi - i
pool += [i] * rep
quant_levels = [ random.choice( pool ) for i in range(batch_size) ]
# these two are techinically equivalent if the audio embeddings handle things properly
resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]

View File

@ -143,15 +143,25 @@ class NAR(Base):
quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
else: # if p_rvq_levels == "auto":
# makes higher levels less likely
def generate( lo=0, hi=8 ):
index = lo
p = random.random()
for i in range(lo, hi):
if p < 1.0 / (2 ** i):
index = i
return int(index)
"""
def generate( lo=0, hi=8 ):
index = lo
p = random.random()
for i in range(lo, hi):
if p < 1.0 / (2 ** i):
index = i
return int(index)
"""
quant_levels = [ generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
# allow passing a specific distribution of RVQ levels
pool = p_rvq_levels if isinstance(p_rvq_levels, list) else []
if not pool:
lo, hi = quant_level_range[0], quant_level_range[1]
for i in range( lo, hi ):
rep = hi - i
pool += [i] * rep
quant_levels = [ random.choice( pool ) for i in range(batch_size) ]
# clamp quant_levels because some of my audio was saved for only 8 out of 9 RVQ levels for DAC...
for i in range(batch_size):