-	<p>Below are some samples from my VALL-E implementation: https://git.ecker.tech/mrq/vall-e/. I do not consider these to be state of the art. Below are samples from LibriSpeech, comparing against the samples the original VALL-E demo sampled.</p>
-	<table>
-		<tr>
-			<th>Text</th>
-			<th>Prompt</th>
-			<th>Ground Truth</th>
-			<th>Our VALL-E</th>
-			<th>Original VALL-E</th>
-			<th>YourTTS</th>
-		</tr>
-		${ENTRIES}
-	</table>
-	<p>Below are some extra samples.</p>
-	<table>
-		<tr>
-			<th>Text</th>
-			<th>Prompt</th>
-			<th>Ground Truth</th>
-			<th>Our VALL-E</th>
-		</tr>
-		${SAMPLES}
-	</table>
+<!DOCTYPE html>
+<html>
+	<head>
+		<title>VALL-E Demo</title>
+	</head>
+	<body>
+		<p>${PREAMBLE}</p>
+		<table>
+			<tr>
+				<th>Text</th>
+				<th>Prompt</th>
+				<th>Ground Truth</th>
+				<th>Our VALL-E</th>
+				<th>Original VALL-E</th>
+				<th>YourTTS</th>
+			</tr>
+			${LIBRISPEECH_SAMPLES}
+		</table>
+		<table>
+			<tr>
+				<th>Text</th>
+				<th>Prompt</th>
+				<th>Ground Truth</th>
+				<th>Our VALL-E</th>
+			</tr>
+			${DATASET_SAMPLES}
+		</table>
+	</body>
+</html>
\ No newline at end of file
diff --git a/vall_e/config.py b/vall_e/config.py
index 721b05d..ced46cb 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -209,7 +209,7 @@ class ModelExperimentalSettings:
audio_embedding_sums: bool = False # whether each pass uses the previous RVQ codes or only the current level
audio_embedding_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings
kv_heads: int = 0 # MHA or GQA (for supported backends)
- p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
+ p_rvq_levels: str | list = "auto" # determines odds of selecting RVQ levels when training; "equal" makes each level equally likely, while a list supplies an explicit sampling pool of levels
rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range for LoRAs, isn't necessary
unified_position_ids: bool = True # False will generate position IDs partitioned for each section
tie_classifier_to_embedding: bool = False # Ties the classifier output to their respective embeddings, this does not seem to do anything good in testing
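
For context on the widened `p_rvq_levels: str | list` field above: it now accepts three forms, interpreted by the sampling change in `vall_e/models/ar_nar.py` further down in this diff. A minimal sketch of the semantics, with `lo`/`hi` standing in for `quant_level_range` (the function name here is illustrative, not the repo's):

```python
# Hedged sketch of the three accepted values for p_rvq_levels; the real
# selection logic lives in ar_nar.py / nar.py below.
import random

def pick_level(p_rvq_levels, lo: int = 0, hi: int = 8) -> int:
    if p_rvq_levels == "equal":
        # every RVQ level in [lo, hi) is equally likely
        return random.randint(lo, hi - 1)
    if isinstance(p_rvq_levels, list):
        # an explicit sampling pool, e.g. [0, 0, 0, 1, 2] to favor level 0
        return random.choice(p_rvq_levels)
    # "auto": level i is weighted by (hi - i), so lower levels dominate
    pool = [i for i in range(lo, hi) for _ in range(hi - i)]
    return random.choice(pool)
```
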
diff --git a/vall_e/data.py b/vall_e/data.py
index b57ba6d..ec13ea0 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -486,7 +486,7 @@ class Dataset(_Dataset):
duration = self.duration_map[path]
self.duration += duration
- # only calc duration if we're tot going to order by duration
+ # only calc duration if we're going to order by duration
if self.sampler_order != "duration":
continue
@@ -845,11 +845,6 @@ class Dataset(_Dataset):
# might be better to decode => concat waveforms with silence in between => reencode
# as you technically can't just append encodec sequences together like this without issues
resps = concat_audio( resps, qnt, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device )
-
- """
- resps = resps[:, :cfg.model.resp_levels]
- proms = proms[:, :cfg.model.resp_levels]
- """
task = random.choice(self.tasks)
@@ -1031,7 +1026,8 @@ class Dataset(_Dataset):
text=text,
proms=proms,
resps=resps,
- text_string=text_string,
+
+ metadata=metadata,
)
def head_(self, n):
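
With `text_string` dropped in the change above, downstream consumers read the batch's `metadata` dict instead. A hedged sketch of the access pattern, using only the two keys the demo.py hunk below relies on (any other contents of the dict are an assumption):

```python
# Hedged sketch: reading the new `metadata` field off a dataset batch.
# Only the "text" and "language" keys are confirmed by demo.py below;
# the dict here is a stand-in for dataloader.dataset[idx].
batch = {
    "metadata": {"text": "some transcription", "language": "en"},
    # "proms" / "resps" audio code tensors omitted in this sketch
}
metadata = batch["metadata"]
text = metadata["text"]          # replaces the old batch["text_string"]
language = metadata["language"]  # now drives per-sample inference language
```
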
diff --git a/vall_e/demo.py b/vall_e/demo.py
index 08d2dca..1e59ec5 100644
--- a/vall_e/demo.py
+++ b/vall_e/demo.py
@@ -26,7 +26,7 @@ from .config import cfg
from .data import create_train_dataloader, create_val_dataloader
from .emb.qnt import decode_to_file
-from tqdm import tqdm
+from tqdm import tqdm, trange
def encode(path):
return "data:audio/wav;base64," + base64.b64encode(open(path, "rb").read()).decode('utf-8')
@@ -42,6 +42,7 @@ def main():
parser.add_argument("--sample-from-dataset", action="store_true")
parser.add_argument("--dataset-samples", type=int, default=0)
parser.add_argument("--audio-path-root", type=str, default=None)
+ parser.add_argument("--preamble", type=str, default=None)
parser.add_argument("--language", type=str, default="en")
@@ -77,66 +78,28 @@ def main():
if not args.demo_dir:
args.demo_dir = Path("./data/demo/")
- entries = []
-
- # pull from provided samples
- sample_dir = args.demo_dir / "librispeech"
- if sample_dir.exists():
- speakers = [ dir for dir in sample_dir.iterdir() if dir.is_dir() ]
- sources = ["ms_valle", "yourtts"]
-
- # generate demo output
- for dir in tqdm(speakers, desc=f"Generating demo for speaker"):
- text = open(dir / "prompt.txt").read()
- prompt = dir / "prompt.wav"
- out_path = dir / "out" / "ours.wav"
-
- entries.append((
- text,
- [ prompt, dir / "reference.wav", out_path ] + [ dir / "out" / f"{source}.wav" for source in sources ]
- ))
-
- if args.skip_existing and out_path.exists():
- continue
-
- tts.inference(
- text=text,
- references=[prompt],
- language=args.language,
- out_path=out_path,
- input_prompt_length=args.input_prompt_length,
- max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
- ar_temp=args.ar_temp, nar_temp=args.nar_temp,
- min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
- top_p=args.top_p, top_k=args.top_k,
- repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
- length_penalty=args.length_penalty,
- beam_width=args.beam_width,
- mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
- seed=args.seed,
- tqdm=False,
- )
-
- entries = [
- f'<tr><td>{text}</td>'+
- "".join( [
- f'<td><audio controls src="{encode(audio)}"></audio></td>'
- for audio in audios
- ] )+
- '</tr>'
- for text, audios in entries
- ]
+ if not args.preamble:
+ args.preamble = " ".join([
+ 'Below are some samples from my VALL-E implementation: https://git.ecker.tech/mrq/vall-e/.',
+ 'Below are samples from LibriSpeech, comparing against the samples the original VALL-E demo sampled.',
+ 'I do not consider these to be state of the art, as the model does not follow the prompt as closely as I would like for general speakers.',
+ ])
# read html template
html = open(args.demo_dir / "index.template.html", "r", encoding="utf-8").read()
- # create html table, in one messy line
- # replace in our template
- html = html.replace(r"${ENTRIES}", "\n".join(entries) )
- samples = []
+ # replace values in our template
+ html = html.replace(r"${PREAMBLE}", args.preamble )
+
+ # pull from provided samples
+ samples_dirs = {
+ "librispeech": args.demo_dir / "librispeech",
+ }
# pull from dataset samples
if args.sample_from_dataset:
+ samples_dirs["dataset"] = args.demo_dir / "dataset"
+
print("Loading dataloader...")
dataloader = create_train_dataloader()
print("Loaded dataloader.")
@@ -144,35 +107,62 @@ def main():
num = args.dataset_samples if args.dataset_samples else cfg.evaluation.size
length = len( dataloader.dataset )
- for i in range( num ):
+ for i in trange( num, desc="Sampling dataset for samples" ):
- idx = random.randint( 0, length )
+ idx = random.randrange( length )
batch = dataloader.dataset[idx]
- dir = args.demo_dir / "samples" / f'{i}'
+ dir = args.demo_dir / "dataset" / f'{i}'
(dir / "out").mkdir(parents=True, exist_ok=True)
- text = batch["text_string"]
+ metadata = batch["metadata"]
+
+ text = metadata["text"]
+ language = metadata["language"]
prompt = dir / "prompt.wav"
reference = dir / "reference.wav"
out_path = dir / "out" / "ours.wav"
+ if args.skip_existing and out_path.exists():
+ continue
+
+ open( dir / "prompt.txt", "w", encoding="utf-8" ).write( text )
+ open( dir / "language.txt", "w", encoding="utf-8" ).write( language )
+
+ decode_to_file( batch["proms"].to("cuda"), prompt, device="cuda" )
+ decode_to_file( batch["resps"].to("cuda"), reference, device="cuda" )
+
+ for k, sample_dir in samples_dirs.items():
+ if not sample_dir.exists():
+ continue
+
+ speakers = [ dir for dir in sample_dir.iterdir() if dir.is_dir() ]
+ sources = [ "ms_valle", "yourtts" ]
+
+ samples = []
+
+ # generate demo output
+ for dir in tqdm(speakers, desc=f"Generating demo for speaker"):
+ text = open(dir / "prompt.txt").read()
+ language = open(dir / "language.txt").read() if (dir / "language.txt").exists() else "en"
+ prompt = dir / "prompt.wav"
+ out_path = dir / "out" / "ours.wav"
+
+ extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else []
+
samples.append((
text,
- [ prompt, reference, out_path ]
+ [ prompt, dir / "reference.wav", out_path ] + extra_sources
))
if args.skip_existing and out_path.exists():
continue
- decode_to_file( batch["proms"].to("cuda"), prompt, device="cuda" )
- decode_to_file( batch["resps"].to("cuda"), reference, device="cuda" )
-
tts.inference(
text=text,
references=[prompt],
- language=args.language,
+ language=language,
out_path=out_path,
input_prompt_length=args.input_prompt_length,
max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
@@ -187,18 +177,21 @@ def main():
tqdm=False,
)
+ # collate entries into HTML
samples = [
- f'<tr><td>{text}</td>'+
+ f'\n\t\t\t<tr>\n\t\t\t\t<td>{text}</td>'+
"".join( [
- f'<td><audio controls src="{encode(audio)}"></audio></td>'
+ f'\n\t\t\t\t<td><audio controls src="{encode(audio)}"></audio></td>'
for audio in audios
] )+
- '</tr>'
+ '\n\t\t\t</tr>'
for text, audios in samples
]
- html = html.replace(r"${SAMPLES}", "\n".join(samples) )
+ # write audio into template
+ html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
+ # write demo page
open( args.demo_dir / "index.html", "w", encoding="utf-8" ).write( html )
if __name__ == "__main__":
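
The per-directory placeholder scheme above generalizes to any number of sample tables: each key `k` in `samples_dirs` fills a matching `${K_SAMPLES}` slot in the template. A self-contained sketch of just that mechanism (the HTML here is illustrative, not the actual template):

```python
# Minimal sketch of the ${K_SAMPLES} placeholder replacement used above.
template = (
    "<table>${LIBRISPEECH_SAMPLES}</table>\n"
    "<table>${DATASET_SAMPLES}</table>"
)
rows = {
    "librispeech": ["<tr><td>example text</td></tr>"],
    "dataset": ["<tr><td>another example</td></tr>"],
}
html = template
for k, samples in rows.items():
    html = html.replace("${" + k.upper() + "_SAMPLES}", "\n".join(samples))
print(html)
```

Note that when a sample directory is missing (or `--sample-from-dataset` is not passed), its placeholder is never replaced and is left verbatim in the rendered page.
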
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index a598611..a9b8b76 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -144,6 +144,7 @@ class AR_NAR(Base):
quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
else: # if p_rvq_levels == "auto":
# makes higher levels less likely
+ """
def generate( lo=0, hi=8 ):
index = lo
p = random.random()
@@ -151,8 +152,17 @@ class AR_NAR(Base):
if p < 1.0 / (2 ** i):
index = i
return int(index)
+ """
- quant_levels = [ generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
+ # allow passing a specific distribution of RVQ levels
+ pool = p_rvq_levels if isinstance(p_rvq_levels, list) else []
+ if not pool:
+ lo, hi = quant_level_range[0], quant_level_range[1]
+ for i in range( lo, hi ):
+ rep = hi - i
+ pool += [i] * rep
+
+ quant_levels = [ random.choice( pool ) for i in range(batch_size) ]
# these two are technically equivalent if the audio embeddings handle things properly
resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
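
The pool construction above makes the implied "auto" distribution easy to verify: with `lo=0, hi=8` the pool holds 8+7+...+1 = 36 entries, so level 0 is drawn with probability 8/36 and level 7 with 1/36. A quick empirical check (the range values are assumed defaults):

```python
# Empirical check of the "auto" pool built above (lo=0, hi=8 assumed).
import random
from collections import Counter

lo, hi = 0, 8
pool = []
for i in range(lo, hi):
    pool += [i] * (hi - i)  # level i appears (hi - i) times out of 36

draws = Counter(random.choice(pool) for _ in range(36_000))
for level in range(lo, hi):
    print(level, draws[level] / 36_000)  # ~0.222 for level 0, ~0.028 for 7
```
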
diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index a251340..571ccdd 100644
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -143,15 +143,25 @@ class NAR(Base):
quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
else: # if p_rvq_levels == "auto":
# makes higher levels less likely
- def generate( lo=0, hi=8 ):
- index = lo
- p = random.random()
- for i in range(lo, hi):
- if p < 1.0 / (2 ** i):
- index = i
- return int(index)
+ """
+ def generate( lo=0, hi=8 ):
+ index = lo
+ p = random.random()
+ for i in range(lo, hi):
+ if p < 1.0 / (2 ** i):
+ index = i
+ return int(index)
+ """
- quant_levels = [ generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
+ # allow passing a specific distribution of RVQ levels
+ pool = p_rvq_levels if isinstance(p_rvq_levels, list) else []
+ if not pool:
+ lo, hi = quant_level_range[0], quant_level_range[1]
+ for i in range( lo, hi ):
+ rep = hi - i
+ pool += [i] * rep
+
+ quant_levels = [ random.choice( pool ) for i in range(batch_size) ]
# clamp quant_levels because some of my audio was saved for only 8 out of 9 RVQ levels for DAC...
for i in range(batch_size):
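
The hunk is truncated here, but going by the comment above, the loop presumably caps each sampled level at the number of RVQ levels actually stored for that item. A hypothetical sketch (the shape indexing and tensor layout are assumptions, not the repo's confirmed code):

```python
# Hypothetical clamp, assuming each resps tensor is [timesteps, rvq_levels]:
# the highest usable quant level is one less than the stored level count.
import torch

batch_size = 2
resps_list = [torch.zeros(100, 8), torch.zeros(80, 9)]  # stand-in codes
quant_levels = [8, 3]
for i in range(batch_size):
    quant_levels[i] = min(quant_levels[i], resps_list[i].shape[-1] - 1)
print(quant_levels)  # [7, 3]
```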