diff --git a/vall_e/config.py b/vall_e/config.py
index 267eb18..958361e 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -111,6 +111,7 @@ class _Config:
 class Dataset:
     training: list[Path] = field(default_factory=lambda: [])
     validation: list[Path] = field(default_factory=lambda: [])
+    noise: list[Path] = field(default_factory=lambda: [])

     temp: list[Path] = field(default_factory=lambda: [])
@@ -393,6 +394,7 @@ class Trainer:

 @dataclass()
 class Inference:
+    normalize: bool = False # do NOT enable this unless you know exactly what you're doing
     use_vocos: bool = True

 @dataclass()
@@ -473,6 +475,8 @@ try:
     if not cfg.dataset.use_hdf5:
         cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
         cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
+
+        cfg.dataset.noise = [ Path(dir) for dir in cfg.dataset.noise ]
 except Exception as e:
     pass
diff --git a/vall_e/data.py b/vall_e/data.py
index f87ad91..d746e2f 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -10,7 +10,7 @@ import random
 import torch

 from .config import cfg
-from .emb.qnt import trim_random, repeat_extend_audio, merge_audio
+from .emb.qnt import trim_random, repeat_extend_audio, merge_audio, decode_to_file

 from collections import defaultdict
 from functools import cache, cached_property
@@ -77,7 +77,6 @@ def _get_phones(path, lang_marker="en"):
     split = content.split(" ")
     return [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]

-
 def _interleaved_reorder(l, fn):
     groups = defaultdict(list)
     for e in l:
@@ -216,11 +215,17 @@ class Dataset(_Dataset):
     def _get_task_symmap(self):
         return get_task_symmap()

-    def get_task_token( token ):
-        return torch.Tensor([[ self.tasks_symmap[f'<{token}>'] for _ in range(len(cfg.models.prom_levels)) ]], dtype=torch.int16)
+    def get_task_token( self, token ):
+        return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(cfg.models.prom_levels) ]]).to(dtype=torch.int16)

     def sample_noise(self):
-        ...
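+        # gather every pre-quantized (*.qnt.pt) noise clip under the
+        # configured noise directories and pick one at random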
+        paths = []
+        for data_dir in cfg.dataset.noise:
+            paths.extend(data_dir.rglob("*.qnt.pt"))
+
+        path = random.choice(paths)
+        return _load_quants(path)

     def sample_speakers(self, ignore=[]):
         choices = set(self.spkrs) - set(ignore)
@@ -242,14 +247,13 @@ class Dataset(_Dataset):
         """

         # shuffle it up a bit
-        offset = random.randint(-16, 16)
-        trim_length = int(cfg.dataset.prompt_duration * 75) + offset
-
-        total_qnt_length = 0
+        prom_length = 0
+        trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-16, 16)
+
         for _ in range(cfg.dataset.max_prompts):
             path = random.choice(choices)
             if cfg.dataset.use_hdf5:
                 key = _get_hdf5_path(path)
-                #qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:]).to(torch.int16)
                 qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
             else:
                 qnt = _load_quants(path)
@@ -258,12 +262,9 @@ class Dataset(_Dataset):
                 qnt = trim_random( qnt, trim_length )

             prom_list.append(qnt)
-            total_qnt_length += qnt.shape[0]
+            prom_length += qnt.shape[0]

-            if total_qnt_length >= trim_length:
-                break
-
-            if random.random() > cfg.dataset.random_utterance:
+            if prom_length >= trim_length or random.random() > cfg.dataset.random_utterance:
                 break

         prom = torch.cat(prom_list)
@@ -296,32 +297,47 @@ class Dataset(_Dataset):

         # text-to-speech
         if task == "tts":
             proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
-
-        """
         # noise suppression || speech removal
         elif task == "ns" or task == "sr":
             # sample random noise
             noise = self.sample_noise()
+            # extend the noise to fill the target audio
             noise = repeat_extend_audio(noise, resps.shape[0])
             # create the input prompt by merging the target audio with the noise
-            proms = merge_audio(resps, noise)
+            proms = merge_audio(resps, noise, scale=[1, 0.125])
             # set the target to just be the noise if the task is speech removal
             if task == "sr":
                 resps = noise
             # prepend the task token
             proms = torch.cat( [self.get_task_token(task), proms] )
+
+            # set the text prompt to empty to train without a guided text prompt
+            if random.random() < 0.5:
+                text = torch.tensor([1, 2]).to(self.text_dtype)
         # target speech extraction
         elif task == "tse":
             # sample a random, clean utterance for the target speaker
             clean_proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
             # sample a random, clean utterance from a different speaker
-            other_proms = self.sample_prompts(self.sample_speaker(ignore=[spkr_name]), ignore="")
+            other_proms = self.sample_prompts(self.sample_speakers(ignore=[spkr_name]), ignore="")
             # overlay the random speaker over the target audio
-            noisy_proms = merge_audio(resps, other_proms)
-            # stitch together the promps
+
+            smallest_size = min(resps.shape[0], other_proms.shape[0])
+            if other_proms.shape[0] == smallest_size:
+                noisy_proms = merge_audio( resps[:smallest_size, :], other_proms, scale=[1, random.uniform(0.5, 0.75)] )
+                noisy_proms = torch.cat( [ noisy_proms, resps[smallest_size:, :] ] )
+            else:
+                noisy_proms = merge_audio( resps, other_proms[:smallest_size, :], scale=[1, random.uniform(0.5, 0.75)] )
+                noisy_proms = torch.cat( [ noisy_proms, other_proms[smallest_size:, :] ] )
+
+            # stitch together the prompts
             proms = torch.cat( [clean_proms, self.get_task_token(task), noisy_proms] )
-            """
+
+            # set the text prompt to empty to train without a guided text prompt
+            if random.random() < 0.5:
+                text = torch.tensor([1, 2]).to(self.text_dtype)
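+                # NOTE: assuming tokens [1, 2] map to the <s>/</s> markers in the
+                # phoneme symmap, this amounts to an "empty" transcription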
+
+        # speech editing would require higher quality transcription data (phoneme level/word level) unfortunately
+        # as I need to get a good clean point to trim into
         """
@@ -332,6 +348,29 @@ class Dataset(_Dataset):
         elif task == "nse":
             ...
         """
+
+        """
+        # emulate SVC
+        # takes in an utterance of the target speaker, a target utterance as a reference clip as the input prompt
+        # targets an utterance of the target speaker with the same tempo + pitch + etc as the reference clip
+
+        # NOTE: I do not have a clue how to go about this. I *could* dynamically generate clips through RVC here, but I imagine the penalty would be astronomical
+        # ahead-of-time dataset preparation of a shit ton of RVC clips might be the key.
+        # aside from that, I have no clue how to go about training this, as this is entirely a proof of concept task.
+        elif task == "svc":
+            # sample a random, clean utterance for the target speaker
+            proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
+            # sample a reference clip from a different speaker
+            ref_proms = self.sample_rvc(self.sample_speakers(ignore=[spkr_name]))
+            #
+            resps = 
+            # stitch together the prompts
+            proms = torch.cat( [proms, self.get_task_token(task), ref_proms] )
+
+            # set the text prompt to empty to train without a guided text prompt
+            if random.random() < 0.5:
+                text = torch.tensor([1, 2]).to(self.text_dtype)
+        """

         return dict(
@@ -608,23 +647,98 @@ if __name__ == "__main__":
     import argparse

     parser = argparse.ArgumentParser("Save trained model to path.")
-    parser.add_argument("--create-hdf5", action="store_true")
+    parser.add_argument("--task", type=str)
     args = parser.parse_args()

-    if args.create_hdf5:
+    task = args.task
+
+    if args.task == "hdf5":
         create_dataset_hdf5()
+    elif args.task == "sample":
+        train_dl, subtrain_dl, val_dl = create_train_val_dataloader()

-    train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
+        samples = {
+            "training": [ next(iter(train_dl)), next(iter(train_dl)) ],
+            "evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
+            "validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
+        }

-    samples = {
-        "training": [ next(iter(train_dl)), next(iter(train_dl)) ],
-        "evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
-        "validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
-    }
+        for k, v in samples.items():
+            for i in range(len(v)):
+                del v[i]['proms']
+                del v[i]['resps']
+            print(f'{k}:', v)
+    """
+    elif args.task == "tasks":
+        index = 0
+        task = "ns"

-    for k, v in samples.items():
-        for i in range(len(v)):
-            del v[i]['proms']
-            del v[i]['resps']
-        print(f'{k}:', v)
+        train_dataset, val_dataset = create_datasets()
+        train_dataset.task_symmap = get_task_symmap()
+
+        if cfg.dataset.sample_type == "speaker":
+            spkr_name = train_dataset.spkrs[index]
+            spkr_id = train_dataset.spkr_symmap[spkr_name]
+            path = random.choice([*set(train_dataset.paths_by_spkr_name[spkr_name])])
+        else:
+            path = train_dataset.paths[index]
+            spkr_name = cfg.get_spkr(path)
+            spkr_id = train_dataset.spkr_symmap[spkr_name]
+
+        if cfg.dataset.use_hdf5:
+            key = _get_hdf5_path(path)
+            text = torch.from_numpy(cfg.hdf5[key]["text"][:]).to(train_dataset.text_dtype)
+            resps = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
+        else:
+            text = torch.tensor([*map(train_dataset.phone_symmap.get, _get_phones(path))]).to(train_dataset.text_dtype)
+            resps = _load_quants(path)
+
+        noise = None
+        if task == "ns" or task == "sr":
+            # sample random noise
+            noise = train_dataset.sample_noise()
+
+            decode_to_file( noise, "./.noise.wav", device="cpu" )
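+            # (debug dump: write the raw sampled noise out as a wav before it
+            # gets extended and mixed into the prompt below)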
+
+            # extend the noise to fill the target audio
+            noise = repeat_extend_audio(noise, resps.shape[0])
+            # create the input prompt by merging the target audio with the noise
+            proms = merge_audio(resps, noise, scale=[1, 0.125])
+            # set the target to just be the noise if the task is speech removal
+            if task == "sr":
+                resps = noise
+            # prepend the task token
+            proms = torch.cat( [train_dataset.get_task_token(task), proms] )
+
+            # set the text prompt to empty to train without a guided text prompt
+            if random.random() < 0.5:
+                text = torch.tensor([1, 2]).to(train_dataset.text_dtype)
+        # target speech extraction
+        elif task == "tse":
+            # sample a random, clean utterance for the target speaker
+            clean_proms = train_dataset.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
+            # sample a random, clean utterance from a different speaker
+            other_proms = train_dataset.sample_prompts(train_dataset.sample_speakers(ignore=[spkr_name]), ignore="")
+            # overlay the random speaker over the target audio
+
+            smallest_size = min(resps.shape[0], other_proms.shape[0])
+            if other_proms.shape[0] == smallest_size:
+                noisy_proms = merge_audio( resps[:smallest_size, :], other_proms )
+                noisy_proms = torch.cat( [ noisy_proms, resps[smallest_size:, :] ] )
+            else:
+                noisy_proms = merge_audio( resps, other_proms[:smallest_size, :] )
+                noisy_proms = torch.cat( [ noisy_proms, other_proms[smallest_size:, :] ] )
+
+            # stitch together the prompts
+            proms = torch.cat( [clean_proms, train_dataset.get_task_token(task), noisy_proms] )
+
+            # set the text prompt to empty to train without a guided text prompt
+            if random.random() < 0.5:
+                text = torch.tensor([1, 2]).to(train_dataset.text_dtype)
+
+        decode_to_file( proms, "./.proms.wav", device="cpu" )
+        decode_to_file( resps, "./.resps.wav", device="cpu" )
+
+        if noise is not None:
+            decode_to_file( noise, "./.noise-fill.wav", device="cpu" )
+    """
\ No newline at end of file
diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py
index 0c74da6..40e115b 100755
--- a/vall_e/emb/qnt.py
+++ b/vall_e/emb/qnt.py
@@ -37,6 +37,7 @@ def _load_encodec_model(device="cuda"):
     model.set_target_bandwidth(bandwidth_id)
     model.bandwidth_id = bandwidth_id
     model.sample_rate = cfg.sample_rate
+    model.normalize = cfg.inference.normalize
     model.backend = "encodec"

     return model
@@ -202,25 +203,32 @@ def repeat_extend_audio( qnt, target ):

 # merges two quantized audios together
 # I don't know if this works
-def merge_audio( *args, device="cpu" ):
+def merge_audio( *args, device="cpu", scale=[] ):
     qnts = [*args]
     decoded = [ decode_to_wave(qnt, device=device)[0] for qnt in qnts ]
+
+    if len(scale) == len(decoded):
+        for i in range(len(scale)):
+            decoded[i] = decoded[i] * scale[i]
+
     combined = sum(decoded) / len(decoded)
-    return encode(combined, 24_000, device="cpu")
+    return encode(combined, 24_000, device="cpu")[0].t()

 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("folder", type=Path)
     parser.add_argument("--suffix", default=".wav")
+    parser.add_argument("--device", default="cuda")
     args = parser.parse_args()

+    device = args.device
+
     paths = [*args.folder.rglob(f"*{args.suffix}")]

     for path in tqdm(paths):
         out_path = _replace_file_extension(path, ".qnt.pt")
         if out_path.exists():
             continue
-        qnt = encode_from_file(path)
+        qnt = encode_from_file(path, device=device)
         torch.save(qnt.cpu(), out_path)
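# Assumed usage after this change (hypothetical invocations, not part of the diff;
# both entry points run as modules since they rely on package-relative imports):
#   python -m vall_e.data --task hdf5                # build the HDF5 dataset
#   python -m vall_e.data --task sample              # dump a batch from each dataloader
#   python -m vall_e.emb.qnt ./data --device cuda    # pre-quantize *.wav files to *.qnt.pt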