cleaned up demo page creation, added an option to pass in an RVQ level sampling distribution for training
parent ba7ee8c0ee
commit e19aa643a6
@@ -4,7 +4,7 @@
 	</head>
 	<body>
 		<h1>VALL-E Demo</h1>
-		<p>Below are some samples from my VALL-E implementation: <a href="https://git.ecker.tech/mrq/vall-e/">https://git.ecker.tech/mrq/vall-e/</a>. I do not consider these to be state of the art. Below are samples from LibriSpeech, comparing against the samples the original VALL-E demo sampled.</p>
+		<p>${PREAMBLE}</p>
 		<table>
 			<tr>
 				<th>Text</th>
@@ -13,18 +13,15 @@
 				<th>Our VALL-E</th>
 				<th>Original VALL-E</th>
 				<th>YourTTS</th>
-			</tr>
-			${ENTRIES}
+			</tr>${LIBRISPEECH_SAMPLES}
 		</table>
-		<p>Below are some extra samples.</p>
 		<table>
 			<tr>
 				<th>Text</th>
 				<th>Prompt</th>
 				<th>Ground Truth</th>
 				<th>Our VALL-E</th>
-			</tr>
-			${SAMPLES}
+			</tr>${DATASET_SAMPLES}
 		</table>
 	</body>
 </html>
@@ -209,7 +209,7 @@ class ModelExperimentalSettings:
 	audio_embedding_sums: bool = False # whether each pass uses the previous RVQ codes or only the current level
 	audio_embedding_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings
 	kv_heads: int = 0 # MHA or GQA (for supported backends)
-	p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
+	p_rvq_levels: str | list = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
 	rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range for LoRAs, isn't necesary
 	unified_position_ids: bool = True # False will generate position IDs partitioned for each section
 	tie_classifier_to_embedding: bool = False # Ties the classifier output to their respective embeddings, this does not seem to do anything good in testing
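With `p_rvq_levels` widened from `str` to `str | list`, the config can now carry an explicit sampling pool of RVQ levels alongside the existing `"auto"` and `"equal"` string modes (the AR_NAR/NAR changes below consume a list directly). A minimal sketch of the accepted values, using a trimmed stand-in for the dataclass above:

from dataclasses import dataclass

@dataclass
class ModelExperimentalSettings:
	# trimmed stand-in; the real dataclass carries many more fields
	p_rvq_levels: str | list = "auto"

ModelExperimentalSettings(p_rvq_levels="auto")           # lower levels more likely (default)
ModelExperimentalSettings(p_rvq_levels="equal")          # every level equally likely
ModelExperimentalSettings(p_rvq_levels=[0, 0, 0, 1, 2])  # explicit pool: level 0 drawn 3x as often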
@@ -486,7 +486,7 @@ class Dataset(_Dataset):
 			duration = self.duration_map[path]
 			self.duration += duration

-			# only calc duration if we're tot going to order by duration
+			# only calc duration if we're going to order by duration
 			if self.sampler_order != "duration":
				continue

@@ -846,11 +846,6 @@ class Dataset(_Dataset):
 			# as you technically can't just append encodec sequences together like this without issues
 			resps = concat_audio( resps, qnt, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device )

-		"""
-		resps = resps[:, :cfg.model.resp_levels]
-		proms = proms[:, :cfg.model.resp_levels]
-		"""
-
 		task = random.choice(self.tasks)

 		if f'<{task}>' not in self.task_symmap:
@@ -1031,7 +1026,8 @@ class Dataset(_Dataset):
 			text=text,
 			proms=proms,
 			resps=resps,
-			text_string=text_string,
+			metadata=metadata,
 		)

 	def head_(self, n):
vall_e/demo.py (125 changed lines)
@@ -26,7 +26,7 @@ from .config import cfg
 from .data import create_train_dataloader, create_val_dataloader
 from .emb.qnt import decode_to_file

-from tqdm import tqdm
+from tqdm import tqdm, trange

 def encode(path):
 	return "data:audio/wav;base64," + base64.b64encode(open(path, "rb").read()).decode('utf-8')
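For context, the `encode` helper above inlines each audio file into the page as a base64 data URI, so the generated demo page plays audio without hosting the WAV files separately (the `--audio-path-root` flag seen later bypasses this by rewriting paths instead). A small usage sketch; the path is hypothetical:

import base64

def encode(path):
	# wrap the raw WAV bytes in a data URI an <audio> <source> tag can play
	return "data:audio/wav;base64," + base64.b64encode(open(path, "rb").read()).decode('utf-8')

src = encode("./data/demo/librispeech/speaker/prompt.wav")  # hypothetical path
print(src[:27])  # a RIFF/WAV file always starts "data:audio/wav;base64,UklGR..."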
@@ -42,6 +42,7 @@ def main():
 	parser.add_argument("--sample-from-dataset", action="store_true")
 	parser.add_argument("--dataset-samples", type=int, default=0)
 	parser.add_argument("--audio-path-root", type=str, default=None)
+	parser.add_argument("--preamble", type=str, default=None)

 	parser.add_argument("--language", type=str, default="en")

@@ -77,66 +78,28 @@ def main():
 	if not args.demo_dir:
 		args.demo_dir = Path("./data/demo/")

-	entries = []
-	# pull from provided samples
-	sample_dir = args.demo_dir / "librispeech"
-	if sample_dir.exists():
-		speakers = [ dir for dir in sample_dir.iterdir() if dir.is_dir() ]
-		sources = ["ms_valle", "yourtts"]
-
-		# generate demo output
-		for dir in tqdm(speakers, desc=f"Generating demo for speaker"):
-			text = open(dir / "prompt.txt").read()
-			prompt = dir / "prompt.wav"
-			out_path = dir / "out" / "ours.wav"
-
-			entries.append((
-				text,
-				[ prompt, dir / "reference.wav", out_path ] + [ dir / "out" / f"{source}.wav" for source in sources ]
-			))
-
-			if args.skip_existing and out_path.exists():
-				continue
-
-			tts.inference(
-				text=text,
-				references=[prompt],
-				language=args.language,
-				out_path=out_path,
-				input_prompt_length=args.input_prompt_length,
-				max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
-				ar_temp=args.ar_temp, nar_temp=args.nar_temp,
-				min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
-				top_p=args.top_p, top_k=args.top_k,
-				repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
-				length_penalty=args.length_penalty,
-				beam_width=args.beam_width,
-				mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
-				seed=args.seed,
-				tqdm=False,
-			)
-
-		entries = [
-			f'<tr><td>{text}</td>'+
-			"".join( [
-				f'<td><audio controls="controls" preload="none"><source src="{str(audio).replace(str(args.demo_dir), args.audio_path_root) if args.audio_path_root else encode(audio)}"/></audio></td>'
-				for audio in audios
-			] )+
-			'</tr>'
-			for text, audios in entries
-		]
+	if not args.preamble:
+		args.preamble = "<br>".join([
+			'Below are some samples from my VALL-E implementation: <a href="https://git.ecker.tech/mrq/vall-e/">https://git.ecker.tech/mrq/vall-e/</a>.',
+			'LibriSpeech, comparing against the samples the original VALL-E demo sampled.',
+			'I do not consider these to be state of the art, as the model does not follow close to the prompt as I would like for general speakers.',
+		])

 	# read html template
 	html = open(args.demo_dir / "index.template.html", "r", encoding="utf-8").read()
-	# create html table, in one messy line
-	# replace in our template
-	html = html.replace(r"${ENTRIES}", "\n".join(entries) )
-
-	samples = []
+	# replace values in our template
+	html = html.replace(r"${PREAMBLE}", args.preamble )
+
+	# pull from provided samples
+	samples_dirs = {
+		"librispeech": args.demo_dir / "librispeech",
+	}

 	# pull from dataset samples
 	if args.sample_from_dataset:
+		samples_dirs["dataset"] = args.demo_dir / "dataset"
+
 		print("Loading dataloader...")
 		dataloader = create_train_dataloader()
 		print("Loaded dataloader.")
@@ -144,35 +107,62 @@ def main():
 		num = args.dataset_samples if args.dataset_samples else cfg.evaluation.size

 		length = len( dataloader.dataset )
-		for i in range( num ):
+		for i in trange( num, desc="Sampling dataset for samples" ):
 			idx = random.randint( 0, length )
 			batch = dataloader.dataset[idx]

-			dir = args.demo_dir / "samples" / f'{i}'
+			dir = args.demo_dir / "dataset" / f'{i}'

 			(dir / "out").mkdir(parents=True, exist_ok=True)

-			text = batch["text_string"]
+			metadata = batch["metadata"]

+			text = metadata["text"]
+			language = metadata["language"]
+
 			prompt = dir / "prompt.wav"
 			reference = dir / "reference.wav"
 			out_path = dir / "out" / "ours.wav"

+			if args.skip_existing and out_path.exists():
+				continue
+
+			open( dir / "prompt.txt", "w", encoding="utf-8" ).write( text )
+			open( dir / "language.txt", "w", encoding="utf-8" ).write( language )
+
+			decode_to_file( batch["proms"].to("cuda"), prompt, device="cuda" )
+			decode_to_file( batch["resps"].to("cuda"), reference, device="cuda" )
+
+	for k, sample_dir in samples_dirs.items():
+		if not sample_dir.exists():
+			continue
+
+		speakers = [ dir for dir in sample_dir.iterdir() if dir.is_dir() ]
+		sources = [ "ms_valle", "yourtts" ]
+
+		samples = []
+
+		# generate demo output
+		for dir in tqdm(speakers, desc=f"Generating demo for speaker"):
+			text = open(dir / "prompt.txt").read()
+			language = open(dir / "language.txt").read() if (dir / "language.txt").exists() else "en"
+			prompt = dir / "prompt.wav"
+			out_path = dir / "out" / "ours.wav"
+
+			extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else []
+
 			samples.append((
 				text,
-				[ prompt, reference, out_path ]
+				[ prompt, dir / "reference.wav", out_path ] + extra_sources
 			))

 			if args.skip_existing and out_path.exists():
 				continue

-			decode_to_file( batch["proms"].to("cuda"), prompt, device="cuda" )
-			decode_to_file( batch["resps"].to("cuda"), reference, device="cuda" )
-
 			tts.inference(
 				text=text,
 				references=[prompt],
-				language=args.language,
+				language=language,
 				out_path=out_path,
 				input_prompt_length=args.input_prompt_length,
 				max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
@@ -187,18 +177,21 @@ def main():
 				tqdm=False,
 			)

+		# collate entries into HTML
 		samples = [
-			f'<tr><td>{text}</td>'+
+			f'\n\t\t\t<tr>\n\t\t\t\t<td>{text}</td>'+
 			"".join( [
-				f'<td><audio controls="controls" preload="none"><source src="{str(audio).replace(str(args.demo_dir), args.audio_path_root) if args.audio_path_root else encode(audio)}"/></audio></td>'
+				f'\n\t\t\t\t<td><audio controls="controls" preload="none"><source src="{str(audio).replace(str(args.demo_dir), args.audio_path_root) if args.audio_path_root else encode(audio)}"/></audio></td>'
 				for audio in audios
 			] )+
-			'</tr>'
+			'\n\t\t\t</tr>'
 			for text, audios in samples
 		]

-	html = html.replace(r"${SAMPLES}", "\n".join(samples) )
+		# write audio into template
+		html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )

+	# write demo page
 	open( args.demo_dir / "index.html", "w", encoding="utf-8" ).write( html )

 if __name__ == "__main__":
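The per-key substitution above is what replaces the old single `${SAMPLES}` scheme: each key in `samples_dirs` collates its own rows and fills its own `${KEY_SAMPLES}` marker, matching the `${LIBRISPEECH_SAMPLES}` and `${DATASET_SAMPLES}` placeholders added to the template. A standalone sketch of the mechanism (the template and row strings here are stand-ins):

# each samples_dirs key maps to one "${KEY_SAMPLES}" marker in the template
html = "<table>...</tr>${LIBRISPEECH_SAMPLES}</table><table>...</tr>${DATASET_SAMPLES}</table>"

rendered = {
	"librispeech": ["\n\t\t\t<tr>...</tr>"],  # rows collated per LibriSpeech speaker
	"dataset": ["\n\t\t\t<tr>...</tr>"],      # rows collated per sampled dataset batch
}

for k, samples in rendered.items():
	html = html.replace("${" + k.upper() + "_SAMPLES}", "\n".join(samples))

print(html)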
@@ -144,6 +144,7 @@ class AR_NAR(Base):
 			quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
 		else: # if p_rvq_levels == "auto":
 			# makes higher levels less likely
+			"""
 			def generate( lo=0, hi=8 ):
 				index = lo
 				p = random.random()
@@ -151,8 +152,17 @@
 					if p < 1.0 / (2 ** i):
 						index = i
 				return int(index)
+			"""

-			quant_levels = [ generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
+			# allow passing a specific distribution of RVQ levels
+			pool = p_rvq_levels if isinstance(p_rvq_levels, list) else []
+			if not pool:
+				lo, hi = quant_level_range[0], quant_level_range[1]
+				for i in range( lo, hi ):
+					rep = hi - i
+					pool += [i] * rep
+
+			quant_levels = [ random.choice( pool ) for i in range(batch_size) ]

 		# these two are techinically equivalent if the audio embeddings handle things properly
 		resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
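Stripped of the surrounding model code, the new level selection reduces to one pool-based rule; a standalone sketch of the logic added above (the helper name is mine, not the repo's):

import random

def sample_rvq_levels(p_rvq_levels, quant_level_range, batch_size):
	# an explicit list is used directly as the sampling pool; otherwise build
	# the "auto" pool, where level i is repeated (hi - i) times so lower RVQ
	# levels are proportionally more likely
	pool = p_rvq_levels if isinstance(p_rvq_levels, list) else []
	if not pool:
		lo, hi = quant_level_range
		for i in range(lo, hi):
			pool += [i] * (hi - i)
	return [random.choice(pool) for _ in range(batch_size)]

print(sample_rvq_levels("auto", (0, 8), 4))        # level 0 is 8x as likely as level 7
print(sample_rvq_levels([0, 0, 0, 1], (0, 8), 4))  # explicit pool: only levels 0 and 1

The hunks below apply the identical change to the NAR model.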
@@ -143,6 +143,7 @@ class NAR(Base):
 			quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
 		else: # if p_rvq_levels == "auto":
 			# makes higher levels less likely
+			"""
 			def generate( lo=0, hi=8 ):
 				index = lo
 				p = random.random()
@@ -150,8 +151,17 @@
 					if p < 1.0 / (2 ** i):
 						index = i
 				return int(index)
+			"""

-			quant_levels = [ generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
+			# allow passing a specific distribution of RVQ levels
+			pool = p_rvq_levels if isinstance(p_rvq_levels, list) else []
+			if not pool:
+				lo, hi = quant_level_range[0], quant_level_range[1]
+				for i in range( lo, hi ):
+					rep = hi - i
+					pool += [i] * rep
+
+			quant_levels = [ random.choice( pool ) for i in range(batch_size) ]

 		# clamp quant_levels because some of my audio was saved for only 8 out of 9 RVQ levels for DAC...
 		for i in range(batch_size):