diff --git a/vall_e/config.py b/vall_e/config.py
index 640859d..267eb18 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -311,8 +311,8 @@ class DeepSpeed:
                 "quantization_period": 0
             },
             "modules": [
-                "blocks",
-                "retnet",
+                "blocks", # for transformer-based models
+                "retnet", # for RetNet-based models
             ]
         }
     }
diff --git a/vall_e/data.py b/vall_e/data.py
index 9bb507c..f87ad91 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -10,6 +10,7 @@ import random
 import torch
 
 from .config import cfg
+from .emb.qnt import trim_random, repeat_extend_audio, merge_audio
 
 from collections import defaultdict
 from functools import cache, cached_property
@@ -33,6 +34,20 @@ def get_phone_symmap():
 	symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
 	return symmap
 
+def get_task_symmap():
+	start = 1024
+	symmap = {
+		"<tts>": -100,
+		"<ns>": start + 0,
+		"<sr>": start + 1,
+		"<tse>": start + 2,
+		"<soe>": start + 3,
+		"<mask>": start + 4,
+		"<eoe>": start + 5,
+		"<svc>": start + 6,
+	}
+	return symmap
+
 def _replace_file_extension(path, suffix):
 	return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
 
@@ -112,6 +127,7 @@ class Dataset(_Dataset):
 		paths,
 		phone_symmap=None,
 		spkr_symmap=None,
+		task_symmap=None,
 		min_phones=cfg.dataset.phones_range[0],
 		max_phones=cfg.dataset.phones_range[1],
 		min_duration=cfg.dataset.duration_range[0],
@@ -134,8 +150,9 @@ class Dataset(_Dataset):
 		else:
 			self.paths = paths
 
-		self.spkr_symmap = spkr_symmap or self._get_spkr_symmap()
 		self.phone_symmap = phone_symmap or self._get_phone_symmap()
+		self.spkr_symmap = spkr_symmap or self._get_spkr_symmap()
+		self.task_symmap = task_symmap or self._get_task_symmap()
 
 		self.training = training
 		# assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
@@ -169,6 +186,7 @@ class Dataset(_Dataset):
 				self.durations[spkr_id] = duration
 			else:
 				self.durations[spkr_id] += duration
+
 	def _get_paths_by_spkr_name(self, extra_paths_by_spkr_name: dict[str, list]):
 		ret = defaultdict(list)
 		for path in self.paths:
@@ -181,16 +199,29 @@ class Dataset(_Dataset):
 	def phones(self):
 		return sorted(set().union(*[_get_phones(path) for path in self.paths]))
 
-	def _get_phone_symmap(self):
-		return get_phone_symmap()
-
 	@cached_property
 	def spkrs(self):
 		return sorted({cfg.get_spkr(path) for path in self.paths})
 
+	@cached_property
+	def tasks(self):
+		return ["tts"] # "ns", "sr", "tse", "cse", "nse"
+
+	def _get_phone_symmap(self):
+		return get_phone_symmap()
+
 	def _get_spkr_symmap(self):
 		return {s: i for i, s in enumerate(self.spkrs)}
 
+	def _get_task_symmap(self):
+		return get_task_symmap()
+
+	def get_task_token( self, token ):
+		return torch.tensor([[ self.task_symmap[f'<{token}>'] for _ in range(cfg.models.prom_levels) ]], dtype=torch.int16)
+
+	def sample_noise(self):
+		...
+
 	def sample_speakers(self, ignore=[]):
 		choices = set(self.spkrs) - set(ignore)
 		return random.choice([*choices])
@@ -212,17 +243,7 @@ class Dataset(_Dataset):
 
 		# shuffle it up a bit
 		offset = random.randint(-16, 16)
-		trim_length = int(cfg.dataset.prompt_duration * 75) + offset
-
-		def trim( qnt ):
-			length = qnt.shape[0]
-			start = int(length * random.random())
-			end = start + trim_length
-			if end >= length:
-				start = length - trim_length
-				end = length
-
-			return qnt[start:end]
+		trim_length = int(cfg.dataset.prompt_duration * 75) + offset
 
 		total_qnt_length = 0
 		for _ in range(cfg.dataset.max_prompts):
 			path = random.choice(choices)
@@ -234,7 +255,7 @@ class Dataset(_Dataset):
 			qnt = _load_quants(path)
 
 			if cfg.dataset.prompt_duration > 0 and trim_length < qnt.shape[0]:
-				qnt = trim(qnt)
+				qnt = trim_random( qnt, trim_length )
 
 			prom_list.append(qnt)
 			total_qnt_length += qnt.shape[0]
@@ -248,14 +269,10 @@ class Dataset(_Dataset):
 		prom = torch.cat(prom_list)
 
 		if cfg.dataset.prompt_duration > 0 and trim_length < prom.shape[0]:
-			prom = trim(prom)
+			prom = trim_random( prom, trim_length )
 
 		return prom
 
-	@cached_property
-	def tasks(self):
-		return ["tts"] # "ns", "sr", "tse", "cse", "nse"
-
 	def __getitem__(self, index):
 		if cfg.dataset.sample_type == "speaker":
 			spkr_name = self.spkrs[index]
@@ -275,34 +292,45 @@ class Dataset(_Dataset):
 		resps = _load_quants(path)
 
 		task = random.choice(self.tasks)
+
+		# text-to-speech
 		if task == "tts":
 			proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
-		# noise-suppression
+
 		"""
-		elif task == "ns":
-			proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
+		# noise suppression || speech removal
+		elif task == "ns" or task == "sr":
+			# sample random noise
 			noise = self.sample_noise()
-			noise = extend_audio(noise, proms.shape[0])
-			proms = merge_audio(proms, noise)
-			# something to prepend a ns token to the beginning of proms
-		elif task == "sr":
-			proms = resps
-			resps = self.sample_noise()
-			resps = extend_audio(resps, proms.shape[0])
-			# something to prepend a sr token to the beginning of proms
+			# extend the noise to fill the target audio
+			noise = repeat_extend_audio(noise, resps.shape[0])
+			# create the input prompt by merging the target audio with the noise
+			proms = merge_audio(resps, noise)
+			# set the target to just be the noise if <sr>
+			if task == "sr":
+				resps = noise
+			# prepend the task token
+			proms = torch.cat( [self.get_task_token(task), proms] )
+		# target speech extraction
 		elif task == "tse":
-			proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
-			other_speaker = self.sample_speaker(ignore=[spkr_name])
-			other_proms = self.sample_prompts(other_speaker, ignore="")
-			proms = merge_audio(proms, other_proms)
-			# something to prepend a tse token to the beginning of proms
-		"""
+			# sample a random, clean utterance from the target speaker
+			clean_proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
+			# sample a random, clean utterance from a different speaker
+			other_proms = self.sample_prompts(self.sample_speakers(ignore=[spkr_name]), ignore="")
+			# overlay the random speaker over the target audio
+			noisy_proms = merge_audio(resps, other_proms)
+			# stitch together the prompts
+			proms = torch.cat( [clean_proms, self.get_task_token(task), noisy_proms] )
 		"""
+
 		# speech editing would require higher quality transcription data (phoneme level/word level) unfortunately
 		# as I need to get a good clean point to trim into
+		"""
+		# clean speech editing
 		elif task == "cse":
+			...
+		# noisy speech editing
 		elif task == "nse":
+			...
 		"""
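The new non-TTS branches stay block-commented for now (`sample_noise` is still a stub), but the task-token plumbing itself reduces to a few tensor operations. A minimal sketch of what it does (the level count stands in for `cfg.models.prom_levels`, assumed to be 8 here; `start = 1024` because EnCodec uses 1024 codes per codebook, so task ids sit just past the audio-code range):

```python
import torch

prom_levels = 8   # stand-in for cfg.models.prom_levels
start = 1024      # task ids begin right after the audio codes [0, 1023]

# one timestep whose code at every quantizer level is the task id,
# mirroring what Dataset.get_task_token("ns") returns
task_token = torch.tensor([[start + 0] * prom_levels], dtype=torch.int16)  # (1, 8)

# a stand-in prompt: 100 EnCodec frames of codes
proms = torch.randint(0, 1024, (100, prom_levels), dtype=torch.int16)

# prepend along the time axis, as the "ns"/"sr" branch does
proms = torch.cat([task_token, proms])
print(proms.shape)  # torch.Size([101, 8])
```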
diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py
index 4b425a5..0c74da6 100755
--- a/vall_e/emb/qnt.py
+++ b/vall_e/emb/qnt.py
@@ -128,7 +128,7 @@ def _replace_file_extension(path, suffix):
 
 @torch.inference_mode()
-def encode(wav: Tensor, sr: int, device="cuda"):
+def encode(wav: Tensor, sr: int = 24_000, device="cuda"):
 	"""
 	Args:
 		wav: (t)
@@ -177,6 +177,36 @@ def encode_from_file(path, device="cuda"):
 
 	return qnt
 
+# Helper Functions
+
+# trims a random piece of audio, up to `target`
+def trim_random( qnt, target ):
+	length = qnt.shape[0]
+	start = int(length * random.random())
+	end = start + target
+	if end >= length:
+		start = length - target
+		end = length
+
+	return qnt[start:end]
+
+# repeats the audio to fit the target size
+def repeat_extend_audio( qnt, target ):
+	pieces = []
+	length = 0
+	while length < target:
+		pieces.append(qnt)
+		length += qnt.shape[0]
+
+	return trim_random(torch.cat(pieces), target)
+
+# merges two quantized audios together
+# I don't know if this works
+def merge_audio( *args, device="cpu" ):
+	qnts = [*args]
+	decoded = [ decode_to_wave(qnt, device=device)[0] for qnt in qnts ]
+	combined = sum(decoded) / len(decoded)
+	return encode(combined, 24_000, device="cpu")
+
 def main():
 	parser = argparse.ArgumentParser()
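The new helpers are easy to sanity-check on dummy code tensors. A minimal sketch, assuming the `(t, quantizer levels)` layout that `sample_prompts` slices along dim 0:

```python
import torch
from vall_e.emb.qnt import trim_random, repeat_extend_audio

qnt = torch.zeros(100, 8, dtype=torch.int16)  # 100 frames, 8 levels

# trim_random returns exactly `target` frames while target <= length...
assert trim_random(qnt, 25).shape[0] == 25
# ...but comes up short once target exceeds the clip length,
assert trim_random(qnt, 150).shape[0] < 150
# which is why repeat_extend_audio tiles the clip before trimming
assert repeat_extend_audio(qnt, 150).shape[0] == 150
```

`merge_audio` is the odd one out: it round-trips through EnCodec (decode every clip to a waveform, average them, re-encode at 24kHz), so it is lossy and needs the decoder loaded, which is presumably why it carries the "I don't know if this works" note.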
diff --git a/vall_e/inference.py b/vall_e/inference.py
index 62371f6..e7f2ac4 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -5,6 +5,7 @@ import soundfile
 from einops import rearrange
 
 from .emb import g2p, qnt
+from .emb.qnt import trim_random
 from .utils import to_device
 
 from .config import cfg
@@ -12,17 +13,6 @@ from .models import get_models
 from .train import load_engines
 from .data import get_phone_symmap
 
-import random
-
-def trim( qnt, trim_length ):
-	length = qnt.shape[0]
-	start = int(length * random.random())
-	end = start + trim_length
-	if end >= length:
-		start = length - trim_length
-		end = length
-	return qnt[start:end]
-
 class TTS():
 	def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device="cuda" ):
 		self.loading = True
@@ -91,7 +81,7 @@ class TTS():
 		enc = qnt.encode_from_file( path )
 		res = enc[0].t().to(torch.int16)
 		if trim:
-			res = trim( res, int( 75 * cfg.dataset.duration_range[1] ) )
+			res = trim_random( res, int( 75 * cfg.dataset.duration_range[1] ) )
 		return res
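Worth noting: in the old code the `if trim:` guard and the `trim( res, ... )` call shared one name, so the boolean argument appears to have shadowed the module-level helper; routing through the renamed `trim_random` sidesteps that. Spelled out, the edited body now does roughly the following (the wav path is a stand-in; 75 is EnCodec's frame rate at 24kHz):

```python
import torch

from vall_e.config import cfg
from vall_e.emb import qnt
from vall_e.emb.qnt import trim_random

enc = qnt.encode_from_file("reference.wav")  # stand-in path
res = enc[0].t().to(torch.int16)             # (t, quantizer levels)
# cap the prompt at duration_range[1] seconds' worth of frames
res = trim_random(res, int(75 * cfg.dataset.duration_range[1]))
```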