diff --git a/vall_e/config.py b/vall_e/config.py
index 267eb18..958361e 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -111,6 +111,7 @@ class _Config:
class Dataset:
training: list[Path] = field(default_factory=lambda: [])
validation: list[Path] = field(default_factory=lambda: [])
+ noise: list[Path] = field(default_factory=lambda: [])
temp: list[Path] = field(default_factory=lambda: [])
@@ -393,6 +394,7 @@ class Trainer:
@dataclass()
class Inference:
+ normalize: bool = False # do NOT enable this unless you know exactly what you're doing
use_vocos: bool = True
@dataclass()
@@ -473,6 +475,8 @@ try:
if not cfg.dataset.use_hdf5:
cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
+
+ cfg.dataset.noise = [ Path(dir) for dir in cfg.dataset.noise ]
except Exception as e:
pass
diff --git a/vall_e/data.py b/vall_e/data.py
index f87ad91..d746e2f 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -10,7 +10,7 @@ import random
import torch
from .config import cfg
-from .emb.qnt import trim_random, repeat_extend_audio, merge_audio
+from .emb.qnt import trim_random, repeat_extend_audio, merge_audio, decode_to_file
from collections import defaultdict
from functools import cache, cached_property
@@ -77,7 +77,6 @@ def _get_phones(path, lang_marker="en"):
split = content.split(" ")
return [f""] + [ " " if not p else p for p in split ] + [f""]
-
def _interleaved_reorder(l, fn):
groups = defaultdict(list)
for e in l:
@@ -216,11 +215,17 @@ class Dataset(_Dataset):
def _get_task_symmap(self):
return get_task_symmap()
- def get_task_token( token ):
- return torch.Tensor([[ self.tasks_symmap[f'<{token}>'] for _ in range(len(cfg.models.prom_levels)) ]], dtype=torch.int16)
+ def get_task_token( self, token ):
+ return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(cfg.models.prom_levels) ]]).to(dtype=torch.int16)
def sample_noise(self):
- ...
+ paths = []
+ print(cfg.dataset.noise)
+ for data_dir in cfg.dataset.noise:
+ paths.extend(data_dir.rglob("*.qnt.pt"))
+
+ path = random.choice(paths)
+ return _load_quants(path)
def sample_speakers(self, ignore=[]):
choices = set(self.spkrs) - set(ignore)
@@ -242,14 +247,13 @@ class Dataset(_Dataset):
"""
# shuffle it up a bit
- offset = random.randint(-16, 16)
- trim_length = int(cfg.dataset.prompt_duration * 75) + offset
- total_qnt_length = 0
+ prom_length = 0
+ trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-16, 16)
+
for _ in range(cfg.dataset.max_prompts):
path = random.choice(choices)
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
- #qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:]).to(torch.int16)
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
else:
qnt = _load_quants(path)
@@ -258,12 +262,9 @@ class Dataset(_Dataset):
qnt = trim_random( qnt, trim_length )
prom_list.append(qnt)
- total_qnt_length += qnt.shape[0]
+ prom_length += qnt.shape[0]
- if total_qnt_length >= trim_length:
- break
-
- if random.random() > cfg.dataset.random_utterance:
+ if prom_length >= trim_length or random.random() > cfg.dataset.random_utterance:
break
prom = torch.cat(prom_list)
@@ -296,32 +297,47 @@ class Dataset(_Dataset):
# text-to-speech
if task == "tts":
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
-
- """
# noise suppression || speech removal
elif task == "ns" or task == "sr":
# sample random noise
noise = self.sample_noise()
+
# extend the noise to fill the target audio
noise = repeat_extend_audio(noise, resps.shape[0])
# create the input prompt by merging the target audio with the noise
- proms = merge_audio(resps, noise)
+ proms = merge_audio(resps, noise, scale=[1, 0.125])
# set the target to just be the noise if
if task == "sr":
resps = noise
# prepend the task token
proms = torch.cat( [self.get_task_token(task), proms] )
+
+ # set the text prompt to empty to train without a guided text prompt
+ if random.random() < 0.5:
+ text = torch.tensor([1, 2]).to(self.text_dtype)
# target speech extraction
elif task == "tse":
# sample a random, clean, utterance for the target speaker
clean_proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
# sample a random, clean utterance from a different speaker
- other_proms = self.sample_prompts(self.sample_speaker(ignore=[spkr_name]), ignore="")
+ other_proms = self.sample_prompts(self.sample_speakers(ignore=[spkr_name]), ignore="")
# overlay the random speaker over the target audio
- noisy_proms = merge_audio(resps, other_proms)
- # stitch together the promps
+
+ smallest_size = min(resps.shape[0], other_proms.shape[0])
+ if other_proms.shape[0] == smallest_size:
+ noisy_proms = merge_audio( resps[:smallest_size, :], other_proms, scale=[1, random.uniform(0.5, 0.75)] )
+ noisy_proms = torch.cat( [ noisy_proms, resps[smallest_size:, :] ] )
+ else:
+ noisy_proms = merge_audio( resps, other_proms[:smallest_size, :], scale=[1, random.uniform(0.5, 0.75)] )
+ noisy_proms = torch.cat( [ noisy_proms, other_proms[smallest_size:, :] ] )
+
+ # stitch together the promps
proms = torch.cat( [clean_proms, self.get_task_token(task), noisy_proms] )
- """
+
+ # set the text prompt to empty to train without a guided text prompt
+ if random.random() < 0.5:
+ text = torch.tensor([1, 2]).to(self.text_dtype)
+
# speech editing would require higher quality transcription data (phoneme level/word level) unfortunately
# as I need to get a good clean point to trim into
"""
@@ -332,6 +348,29 @@ class Dataset(_Dataset):
elif task == "nse":
...
"""
+
+ """
+ # emulate SVC
+ # takes in an utterance of the target speaker, a target utterenace as a reference clip as the input prompt
+ # targets an utterance of the target speaker with the same tempo + pitch + etc as the reference clip
+
+ # NOTE: I do not have a clue how to go about this. I *could* dynamically generate clips through RVC here, but I imagine the penalty would be astronomical
+ # ahead-of-time dataset preparation of a shit ton of RVC clips might be the key.
+ # aside from that, I have no clue how to go about training this, as this is entirely a proof of concept task.
+ elif task == "svc":
+ # sample a random, clean utterance for the target speaker
+ proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
+ # sample a reference clip from a different speaker
+ ref_proms = self.sample_rvc(self.sample_speakers(ignore=[spkr_name]))
+ #
+ resps =
+ # stitch together the promps
+ proms = torch.cat( [proms, self.get_task_token(task), ref_proms] )
+
+ # set the text prompt to empty to train without a guided text prompt
+ if random.random() < 0.5:
+ text = torch.tensor([1, 2]).to(self.text_dtype)
+ """
return dict(
@@ -608,23 +647,98 @@ if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser("Save trained model to path.")
- parser.add_argument("--create-hdf5", action="store_true")
+ parser.add_argument("--task", type=str)
args = parser.parse_args()
- if args.create_hdf5:
+ task = args.task
+
+ if args.task == "hdf5":
create_dataset_hdf5()
+ elif args.task == "sample":
+ train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
- train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
+ samples = {
+ "training": [ next(iter(train_dl)), next(iter(train_dl)) ],
+ "evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
+ "validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
+ }
- samples = {
- "training": [ next(iter(train_dl)), next(iter(train_dl)) ],
- "evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
- "validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
- }
+ for k, v in samples.items():
+ for i in range(len(v)):
+ del v[i]['proms']
+ del v[i]['resps']
+ print(f'{k}:', v)
+ """
+ elif args.task == "tasks":
+ index = 0
+ task = "ns"
- for k, v in samples.items():
- for i in range(len(v)):
- del v[i]['proms']
- del v[i]['resps']
- print(f'{k}:', v)
+ train_dataset, val_dataset = create_datasets()
+ train_dataset.task_symmap = get_task_symmap()
+ if cfg.dataset.sample_type == "speaker":
+ spkr_name = train_dataset.spkrs[index]
+ spkr_id = train_dataset.spkr_symmap[spkr_name]
+ path = random.choice([*set(train_dataset.paths_by_spkr_name[spkr_name])])
+ else:
+ path = train_dataset.paths[index]
+ spkr_name = cfg.get_spkr(path)
+ spkr_id = train_dataset.spkr_symmap[spkr_name]
+
+ if cfg.dataset.use_hdf5:
+ key = _get_hdf5_path(path)
+ text = torch.from_numpy(cfg.hdf5[key]["text"][:]).to(train_dataset.text_dtype)
+ resps = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
+ else:
+ text = torch.tensor([*map(train_dataset.phone_symmap.get, _get_phones(path))]).to(train_dataset.text_dtype)
+ resps = _load_quants(path)
+
+ noise = None
+ if task == "ns" or task == "sr":
+ # sample random noise
+ noise = train_dataset.sample_noise()
+
+ decode_to_file( noise, "./.noise.wav", device="cpu" )
+
+ # extend the noise to fill the target audio
+ noise = repeat_extend_audio(noise, resps.shape[0])
+ # create the input prompt by merging the target audio with the noise
+ proms = merge_audio(resps, noise, scale=[1, 0.125])
+ # set the target to just be the noise if
+ if task == "sr":
+ resps = noise
+ # prepend the task token
+ proms = torch.cat( [train_dataset.get_task_token(task), proms] )
+
+ # set the text prompt to empty to train without a guided text prompt
+ if random.random() < 0.5:
+ text = torch.tensor([1, 2]).to(train_dataset.text_dtype)
+ # target speech extraction
+ elif task == "tse":
+ # sample a random, clean, utterance for the target speaker
+ clean_proms = train_dataset.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
+ # sample a random, clean utterance from a different speaker
+ other_proms = train_dataset.sample_prompts(train_dataset.sample_speakers(ignore=[spkr_name]), ignore="")
+ # overlay the random speaker over the target audio
+
+ smallest_size = min(resps.shape[0], other_proms.shape[0])
+ if other_proms.shape[0] == smallest_size:
+ noisy_proms = merge_audio( resps[:smallest_size, :], other_proms )
+ noisy_proms = torch.cat( [ noisy_proms, resps[smallest_size:, :] ] )
+ else:
+ noisy_proms = merge_audio( resps, other_proms[:smallest_size, :] )
+ noisy_proms = torch.cat( [ noisy_proms, other_proms[smallest_size:, :] ] )
+
+ # stitch together the promps
+ proms = torch.cat( [clean_proms, train_dataset.get_task_token(task), noisy_proms] )
+
+ # set the text prompt to empty to train without a guided text prompt
+ if random.random() < 0.5:
+ text = torch.tensor([1, 2]).to(train_dataset.text_dtype)
+
+ decode_to_file( proms, "./.proms.wav", device="cpu" )
+ decode_to_file( resps, "./.resps.wav", device="cpu" )
+
+ if noise is not None:
+ decode_to_file( noise, "./.noise-fill.wav", device="cpu" )
+ """
\ No newline at end of file
diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py
index 0c74da6..40e115b 100755
--- a/vall_e/emb/qnt.py
+++ b/vall_e/emb/qnt.py
@@ -37,6 +37,7 @@ def _load_encodec_model(device="cuda"):
model.set_target_bandwidth(bandwidth_id)
model.bandwidth_id = bandwidth_id
model.sample_rate = cfg.sample_rate
+ model.normalize = cfg.inference.normalize
model.backend = "encodec"
return model
@@ -202,25 +203,32 @@ def repeat_extend_audio( qnt, target ):
# merges two quantized audios together
# I don't know if this works
-def merge_audio( *args, device="cpu" ):
+def merge_audio( *args, device="cpu", scale=[] ):
qnts = [*args]
decoded = [ decode_to_wave(qnt, device=device)[0] for qnt in qnts ]
+
+ if len(scale) == len(decoded):
+ for i in range(len(scale)):
+ decoded[i] = decoded[i] * scale[i]
+
combined = sum(decoded) / len(decoded)
- return encode(combined, 24_000, device="cpu")
+ return encode(combined, 24_000, device="cpu")[0].t()
def main():
parser = argparse.ArgumentParser()
parser.add_argument("folder", type=Path)
parser.add_argument("--suffix", default=".wav")
+ parser.add_argument("--device", default="cuda")
args = parser.parse_args()
+ device = args.device
paths = [*args.folder.rglob(f"*{args.suffix}")]
for path in tqdm(paths):
out_path = _replace_file_extension(path, ".qnt.pt")
if out_path.exists():
continue
- qnt = encode_from_file(path)
+ qnt = encode_from_file(path, device=device)
torch.save(qnt.cpu(), out_path)