cleaned up demo page creation, added option to pass in RVQ level sampling distribution for training
commit e19aa643a6 (parent ba7ee8c0ee)
index.template.html

@@ -1,30 +1,27 @@
 <html>
-	<head>
-		<meta charset="UTF-8">
-	</head>
-	<body>
-		<h1>VALL-E Demo</h1>
-		<p>Below are some samples from my VALL-E implementation: <a href="https://git.ecker.tech/mrq/vall-e/">https://git.ecker.tech/mrq/vall-e/</a>. I do not consider these to be state of the art. Below are samples from LibriSpeech, comparing against the samples the original VALL-E demo sampled.</p>
-		<table>
-			<tr>
-				<th>Text</th>
-				<th>Prompt</th>
-				<th>Ground Truth</th>
-				<th>Our VALL-E</th>
-				<th>Original VALL-E</th>
-				<th>YourTTS</th>
-			</tr>
-			${ENTRIES}
-		</table>
-		<p>Below are some extra samples.</p>
-		<table>
-			<tr>
-				<th>Text</th>
-				<th>Prompt</th>
-				<th>Ground Truth</th>
-				<th>Our VALL-E</th>
-			</tr>
-			${SAMPLES}
-		</table>
-	</body>
+	<head>
+		<meta charset="UTF-8">
+	</head>
+	<body>
+		<h1>VALL-E Demo</h1>
+		<p>${PREAMBLE}</p>
+		<table>
+			<tr>
+				<th>Text</th>
+				<th>Prompt</th>
+				<th>Ground Truth</th>
+				<th>Our VALL-E</th>
+				<th>Original VALL-E</th>
+				<th>YourTTS</th>
+			</tr>${LIBRISPEECH_SAMPLES}
+		</table>
+		<table>
+			<tr>
+				<th>Text</th>
+				<th>Prompt</th>
+				<th>Ground Truth</th>
+				<th>Our VALL-E</th>
+			</tr>${DATASET_SAMPLES}
+		</table>
+	</body>
 </html>
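Note: the placeholders above are filled by plain string substitution in vall_e/demo.py (shown later in this diff). A minimal sketch of that flow — the row strings and table contents here are made-up stand-ins, not the repo's actual values:

```python
# Minimal sketch of how demo.py fills the template placeholders.
html = open("index.template.html", "r", encoding="utf-8").read()
html = html.replace("${PREAMBLE}", "Below are some samples.")

tables = {
    "librispeech": ['<tr><td>sample text</td></tr>'],  # hypothetical row
    "dataset": ['<tr><td>sample text</td></tr>'],      # hypothetical row
}
for k, rows in tables.items():
    # "librispeech" -> ${LIBRISPEECH_SAMPLES}, "dataset" -> ${DATASET_SAMPLES}
    html = html.replace("${" + k.upper() + "_SAMPLES}", "\n".join(rows))

open("index.html", "w", encoding="utf-8").write(html)
```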
vall_e/config.py

@@ -209,7 +209,7 @@ class ModelExperimentalSettings:
 	audio_embedding_sums: bool = False # whether each pass uses the previous RVQ codes or only the current level
 	audio_embedding_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings
 	kv_heads: int = 0 # MHA or GQA (for supported backends)
-	p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
+	p_rvq_levels: str | list = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
 	rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range for LoRAs, isn't necessary
 	unified_position_ids: bool = True # False will generate position IDs partitioned for each section
 	tie_classifier_to_embedding: bool = False # Ties the classifier output to their respective embeddings, this does not seem to do anything good in testing
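With the widened `str | list` type, a list is used verbatim as the sampling pool by the training code later in this commit; the string values keep their old meaning. Possible values (the list below is a hypothetical example, not a recommended setting):

```python
p_rvq_levels = "auto"               # default: lower RVQ levels weighted higher
p_rvq_levels = "equal"              # every level in range equally likely
p_rvq_levels = [0, 0, 0, 1, 2, 3]   # custom pool: level 0 drawn half the time
```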
vall_e/data.py

@@ -486,7 +486,7 @@ class Dataset(_Dataset):
 			duration = self.duration_map[path]
 			self.duration += duration

-			# only calc duration if we're tot going to order by duration
+			# only calc duration if we're going to order by duration
 			if self.sampler_order != "duration":
 				continue
@@ -845,11 +845,6 @@ class Dataset(_Dataset):
 				# might be better to decode => concat waveforms with silence in between => reencode
 				# as you technically can't just append encodec sequences together like this without issues
 				resps = concat_audio( resps, qnt, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device )

-		"""
-		resps = resps[:, :cfg.model.resp_levels]
-		proms = proms[:, :cfg.model.resp_levels]
-		"""
-
 		task = random.choice(self.tasks)
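The comment above sketches the safer concatenation scheme that `reencode_on_concat` enables. As a rough illustration of the decode => pad with silence => reencode idea — the `decode`/`encode` callables here are stand-ins, not the actual `vall_e.emb.qnt` signatures:

```python
import torch

def concat_with_silence(codes_a, codes_b, decode, encode, sr=24_000, silence_sec=0.5):
    # RVQ codes -> waveforms; EnCodec sequences can't simply be appended,
    # so go back to the waveform domain first
    wav_a, wav_b = decode(codes_a), decode(codes_b)
    silence = torch.zeros(1, int(sr * silence_sec))
    wav = torch.cat([wav_a, silence, wav_b], dim=-1)
    # waveform -> RVQ codes again
    return encode(wav)
```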
@@ -1031,7 +1026,8 @@ class Dataset(_Dataset):
 			text=text,
 			proms=proms,
 			resps=resps,
 			text_string=text_string,
+			metadata=metadata,
 		)

 	def head_(self, n):
vall_e/demo.py

@@ -26,7 +26,7 @@ from .config import cfg
 from .data import create_train_dataloader, create_val_dataloader
 from .emb.qnt import decode_to_file

-from tqdm import tqdm
+from tqdm import tqdm, trange

 def encode(path):
 	return "data:audio/wav;base64," + base64.b64encode(open(path, "rb").read()).decode('utf-8')
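`encode()` inlines a wav file as a base64 data URI, so the generated page works without serving audio files separately. Hypothetical usage (the path is made up):

```python
src = encode("data/demo/librispeech/speaker/out/ours.wav")  # hypothetical path
cell = f'<td><audio controls="controls" preload="none"><source src="{src}"/></audio></td>'
```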
@@ -42,6 +42,7 @@ def main():
 	parser.add_argument("--sample-from-dataset", action="store_true")
 	parser.add_argument("--dataset-samples", type=int, default=0)
 	parser.add_argument("--audio-path-root", type=str, default=None)
+	parser.add_argument("--preamble", type=str, default=None)

 	parser.add_argument("--language", type=str, default="en")
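A hypothetical invocation exercising the new flag (module path assumed from the file layout; any required config flags omitted):

```
python3 -m vall_e.demo --sample-from-dataset --dataset-samples 16 --preamble "Nightly demo samples"
```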
@@ -77,66 +78,28 @@ def main():
 	if not args.demo_dir:
 		args.demo_dir = Path("./data/demo/")

-	entries = []
-
-	# pull from provided samples
-	sample_dir = args.demo_dir / "librispeech"
-	if sample_dir.exists():
-		speakers = [ dir for dir in sample_dir.iterdir() if dir.is_dir() ]
-		sources = ["ms_valle", "yourtts"]
-
-		# generate demo output
-		for dir in tqdm(speakers, desc=f"Generating demo for speaker"):
-			text = open(dir / "prompt.txt").read()
-			prompt = dir / "prompt.wav"
-			out_path = dir / "out" / "ours.wav"
-
-			entries.append((
-				text,
-				[ prompt, dir / "reference.wav", out_path ] + [ dir / "out" / f"{source}.wav" for source in sources ]
-			))
-
-			if args.skip_existing and out_path.exists():
-				continue
-
-			tts.inference(
-				text=text,
-				references=[prompt],
-				language=args.language,
-				out_path=out_path,
-				input_prompt_length=args.input_prompt_length,
-				max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
-				ar_temp=args.ar_temp, nar_temp=args.nar_temp,
-				min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
-				top_p=args.top_p, top_k=args.top_k,
-				repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
-				length_penalty=args.length_penalty,
-				beam_width=args.beam_width,
-				mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
-				seed=args.seed,
-				tqdm=False,
-			)
-
-	entries = [
-		f'<tr><td>{text}</td>'+
-		"".join( [
-			f'<td><audio controls="controls" preload="none"><source src="{str(audio).replace(str(args.demo_dir), args.audio_path_root) if args.audio_path_root else encode(audio)}"/></audio></td>'
-			for audio in audios
-		] )+
-		'</tr>'
-		for text, audios in entries
-	]
+	if not args.preamble:
+		args.preamble = "<br>".join([
+			'Below are some samples from my VALL-E implementation: <a href="https://git.ecker.tech/mrq/vall-e/">https://git.ecker.tech/mrq/vall-e/</a>.',
+			'LibriSpeech, comparing against the samples the original VALL-E demo sampled.',
+			'I do not consider these to be state of the art, as the model does not follow close to the prompt as I would like for general speakers.',
+		])

 	# read html template
 	html = open(args.demo_dir / "index.template.html", "r", encoding="utf-8").read()
-	# create html table, in one messy line
-	# replace in our template
-	html = html.replace(r"${ENTRIES}", "\n".join(entries) )
-
-	samples = []
+	# replace values in our template
+	html = html.replace(r"${PREAMBLE}", args.preamble )
+
+	# pull from provided samples
+	samples_dirs = {
+		"librispeech": args.demo_dir / "librispeech",
+	}

 	# pull from dataset samples
 	if args.sample_from_dataset:
+		samples_dirs["dataset"] = args.demo_dir / "dataset"
+
 		print("Loading dataloader...")
 		dataloader = create_train_dataloader()
 		print("Loaded dataloader.")
@@ -144,35 +107,62 @@ def main():
 		num = args.dataset_samples if args.dataset_samples else cfg.evaluation.size

 		length = len( dataloader.dataset )
-		for i in range( num ):
+		for i in trange( num, desc="Sampling dataset for samples" ):
 			idx = random.randint( 0, length )
 			batch = dataloader.dataset[idx]

-			dir = args.demo_dir / "samples" / f'{i}'
+			dir = args.demo_dir / "dataset" / f'{i}'

 			(dir / "out").mkdir(parents=True, exist_ok=True)

-			text = batch["text_string"]
+			metadata = batch["metadata"]
+
+			text = metadata["text"]
+			language = metadata["language"]

 			prompt = dir / "prompt.wav"
 			reference = dir / "reference.wav"
 			out_path = dir / "out" / "ours.wav"

 			if args.skip_existing and out_path.exists():
 				continue

+			open( dir / "prompt.txt", "w", encoding="utf-8" ).write( text )
+			open( dir / "language.txt", "w", encoding="utf-8" ).write( language )
+
+			decode_to_file( batch["proms"].to("cuda"), prompt, device="cuda" )
+			decode_to_file( batch["resps"].to("cuda"), reference, device="cuda" )
+
+	for k, sample_dir in samples_dirs.items():
+		if not sample_dir.exists():
+			continue
+
+		speakers = [ dir for dir in sample_dir.iterdir() if dir.is_dir() ]
+		sources = [ "ms_valle", "yourtts" ]
+
+		samples = []
+
+		# generate demo output
+		for dir in tqdm(speakers, desc=f"Generating demo for speaker"):
+			text = open(dir / "prompt.txt").read()
+			language = open(dir / "language.txt").read() if (dir / "language.txt").exists() else "en"
+			prompt = dir / "prompt.wav"
+			out_path = dir / "out" / "ours.wav"
+
+			extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else []
+
 			samples.append((
 				text,
-				[ prompt, reference, out_path ]
+				[ prompt, dir / "reference.wav", out_path ] + extra_sources
 			))

 			if args.skip_existing and out_path.exists():
 				continue

-			decode_to_file( batch["proms"].to("cuda"), prompt, device="cuda" )
-			decode_to_file( batch["resps"].to("cuda"), reference, device="cuda" )
-
 			tts.inference(
 				text=text,
 				references=[prompt],
-				language=args.language,
+				language=language,
 				out_path=out_path,
 				input_prompt_length=args.input_prompt_length,
 				max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
@@ -187,18 +177,21 @@ def main():
 				tqdm=False,
 			)

+		# collate entries into HTML
 		samples = [
-			f'<tr><td>{text}</td>'+
+			f'\n\t\t\t<tr>\n\t\t\t\t<td>{text}</td>'+
 			"".join( [
-				f'<td><audio controls="controls" preload="none"><source src="{str(audio).replace(str(args.demo_dir), args.audio_path_root) if args.audio_path_root else encode(audio)}"/></audio></td>'
+				f'\n\t\t\t\t<td><audio controls="controls" preload="none"><source src="{str(audio).replace(str(args.demo_dir), args.audio_path_root) if args.audio_path_root else encode(audio)}"/></audio></td>'
 				for audio in audios
 			] )+
-			'</tr>'
+			'\n\t\t\t</tr>'
 			for text, audios in samples
 		]

-		html = html.replace(r"${SAMPLES}", "\n".join(samples) )
+		# write audio into template
+		html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )

 	# write demo page
 	open( args.demo_dir / "index.html", "w", encoding="utf-8" ).write( html )

 if __name__ == "__main__":
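For clarity on the `src=` expression in the rows above: with `--audio-path-root` set, on-disk paths under the demo dir are rewritten to that root; otherwise the audio is inlined via `encode()`. A worked example with made-up values:

```python
demo_dir = "data/demo"
audio = "data/demo/librispeech/speaker/out/ours.wav"  # hypothetical path
audio_path_root = "https://example.com/demo"          # hypothetical --audio-path-root

src = audio.replace(demo_dir, audio_path_root)
print(src)  # https://example.com/demo/librispeech/speaker/out/ours.wav
# without --audio-path-root, src = encode(audio) inlines a base64 data URI instead
```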
vall_e/models/ar_nar.py

@@ -144,6 +144,7 @@ class AR_NAR(Base):
 			quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
 		else: # if p_rvq_levels == "auto":
 			# makes higher levels less likely
+			"""
 			def generate( lo=0, hi=8 ):
 				index = lo
 				p = random.random()
@@ -151,8 +152,17 @@ class AR_NAR(Base):
 				if p < 1.0 / (2 ** i):
 					index = i
 				return int(index)
+			"""

-			quant_levels = [ generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
+			# allow passing a specific distribution of RVQ levels
+			pool = p_rvq_levels if isinstance(p_rvq_levels, list) else []
+			if not pool:
+				lo, hi = quant_level_range[0], quant_level_range[1]
+				for i in range( lo, hi ):
+					rep = hi - i
+					pool += [i] * rep
+
+			quant_levels = [ random.choice( pool ) for i in range(batch_size) ]

 		# these two are technically equivalent if the audio embeddings handle things properly
 		resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
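To see what the default "auto" pool amounts to: with a quant level range of [0, 8), level i is repeated hi - i times, so level 0 is eight times as likely as level 7. A standalone sketch (not repo code) of the resulting distribution:

```python
import random
from collections import Counter

lo, hi = 0, 8  # assumed quant_level_range
pool = []
for i in range(lo, hi):
    pool += [i] * (hi - i)   # level i repeated (hi - i) times

print(len(pool))  # 36 entries: 8 zeros, 7 ones, ..., 1 seven

# sampling quant levels for a batch, as the training code now does
quant_levels = [random.choice(pool) for _ in range(100_000)]
print(Counter(quant_levels))  # counts fall off roughly linearly with level

# a list passed as p_rvq_levels becomes the pool directly, e.g.
# p_rvq_levels = [0, 0, 0, 1] would train the AR level (0) 75% of the time
```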
vall_e/models/nar.py

@@ -143,15 +143,25 @@ class NAR(Base):
 			quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
 		else: # if p_rvq_levels == "auto":
 			# makes higher levels less likely
-			def generate( lo=0, hi=8 ):
-				index = lo
-				p = random.random()
-				for i in range(lo, hi):
-					if p < 1.0 / (2 ** i):
-						index = i
-				return int(index)
+			"""
+			def generate( lo=0, hi=8 ):
+				index = lo
+				p = random.random()
+				for i in range(lo, hi):
+					if p < 1.0 / (2 ** i):
+						index = i
+				return int(index)
+			"""

-			quant_levels = [ generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
+			# allow passing a specific distribution of RVQ levels
+			pool = p_rvq_levels if isinstance(p_rvq_levels, list) else []
+			if not pool:
+				lo, hi = quant_level_range[0], quant_level_range[1]
+				for i in range( lo, hi ):
+					rep = hi - i
+					pool += [i] * rep
+
+			quant_levels = [ random.choice( pool ) for i in range(batch_size) ]

 		# clamp quant_levels because some of my audio was saved for only 8 out of 9 RVQ levels for DAC...
 		for i in range(batch_size):