pseudocode polyfill stub some other flavor of working on adding the tasks
This commit is contained in:
parent
0b46c1e312
commit
bbb0563b3d
|
@ -311,8 +311,8 @@ class DeepSpeed:
|
|||
"quantization_period": 0
|
||||
},
|
||||
"modules": [
|
||||
"blocks",
|
||||
"retnet",
|
||||
"blocks", # for transformer-based models
|
||||
"retnet", # for RetNets-based models
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
104
vall_e/data.py
104
vall_e/data.py
|
@ -10,6 +10,7 @@ import random
|
|||
import torch
|
||||
|
||||
from .config import cfg
|
||||
from .emb.qnt import trim_random, repeat_extend_audio, merge_audio
|
||||
|
||||
from collections import defaultdict
|
||||
from functools import cache, cached_property
|
||||
|
@ -33,6 +34,20 @@ def get_phone_symmap():
|
|||
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
|
||||
return symmap
|
||||
|
||||
def get_task_symmap():
|
||||
start = 1024
|
||||
symmap = {
|
||||
"<tts>": -100,
|
||||
"<ns>": start + 0,
|
||||
"<sr>": start + 1,
|
||||
"<tse>": start + 2,
|
||||
"<soe>": start + 3,
|
||||
"<mask>": start + 4,
|
||||
"<eoe>": start + 5,
|
||||
"<svc>": start + 6,
|
||||
}
|
||||
return symmap
|
||||
|
||||
def _replace_file_extension(path, suffix):
|
||||
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
|
||||
|
||||
|
@ -112,6 +127,7 @@ class Dataset(_Dataset):
|
|||
paths,
|
||||
phone_symmap=None,
|
||||
spkr_symmap=None,
|
||||
task_symmap=None,
|
||||
min_phones=cfg.dataset.phones_range[0],
|
||||
max_phones=cfg.dataset.phones_range[1],
|
||||
min_duration=cfg.dataset.duration_range[0],
|
||||
|
@ -134,8 +150,9 @@ class Dataset(_Dataset):
|
|||
else:
|
||||
self.paths = paths
|
||||
|
||||
self.spkr_symmap = spkr_symmap or self._get_spkr_symmap()
|
||||
self.phone_symmap = phone_symmap or self._get_phone_symmap()
|
||||
self.spkr_symmap = spkr_symmap or self._get_spkr_symmap()
|
||||
self.task_symmap = get_task_symmap or self._get_task_symmap()
|
||||
self.training = training
|
||||
|
||||
# assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
|
||||
|
@ -169,6 +186,7 @@ class Dataset(_Dataset):
|
|||
self.durations[spkr_id] = duration
|
||||
else:
|
||||
self.durations[spkr_id] += duration
|
||||
|
||||
def _get_paths_by_spkr_name(self, extra_paths_by_spkr_name: dict[str, list]):
|
||||
ret = defaultdict(list)
|
||||
for path in self.paths:
|
||||
|
@ -181,16 +199,29 @@ class Dataset(_Dataset):
|
|||
def phones(self):
|
||||
return sorted(set().union(*[_get_phones(path) for path in self.paths]))
|
||||
|
||||
def _get_phone_symmap(self):
|
||||
return get_phone_symmap()
|
||||
|
||||
@cached_property
|
||||
def spkrs(self):
|
||||
return sorted({cfg.get_spkr(path) for path in self.paths})
|
||||
|
||||
@cached_property
|
||||
def tasks(self):
|
||||
return ["tts"] # "ns", "sr", "tse", "cse", "nse"
|
||||
|
||||
def _get_phone_symmap(self):
|
||||
return get_phone_symmap()
|
||||
|
||||
def _get_spkr_symmap(self):
|
||||
return {s: i for i, s in enumerate(self.spkrs)}
|
||||
|
||||
def _get_task_symmap(self):
|
||||
return get_task_symmap()
|
||||
|
||||
def get_task_token( token ):
|
||||
return torch.Tensor([[ self.tasks_symmap[f'<{token}>'] for _ in range(len(cfg.models.prom_levels)) ]], dtype=torch.int16)
|
||||
|
||||
def sample_noise(self):
|
||||
...
|
||||
|
||||
def sample_speakers(self, ignore=[]):
|
||||
choices = set(self.spkrs) - set(ignore)
|
||||
return random.choice([*choices])
|
||||
|
@ -212,17 +243,7 @@ class Dataset(_Dataset):
|
|||
|
||||
# shuffle it up a bit
|
||||
offset = random.randint(-16, 16)
|
||||
trim_length = int(cfg.dataset.prompt_duration * 75) + offset
|
||||
def trim( qnt ):
|
||||
length = qnt.shape[0]
|
||||
start = int(length * random.random())
|
||||
end = start + trim_length
|
||||
if end >= length:
|
||||
start = length - trim_length
|
||||
end = length
|
||||
|
||||
return qnt[start:end]
|
||||
|
||||
trim_length = int(cfg.dataset.prompt_duration * 75) + offset
|
||||
total_qnt_length = 0
|
||||
for _ in range(cfg.dataset.max_prompts):
|
||||
path = random.choice(choices)
|
||||
|
@ -234,7 +255,7 @@ class Dataset(_Dataset):
|
|||
qnt = _load_quants(path)
|
||||
|
||||
if cfg.dataset.prompt_duration > 0 and trim_length < qnt.shape[0]:
|
||||
qnt = trim(qnt)
|
||||
qnt = trim_random( qnt, trim_length )
|
||||
|
||||
prom_list.append(qnt)
|
||||
total_qnt_length += qnt.shape[0]
|
||||
|
@ -248,14 +269,10 @@ class Dataset(_Dataset):
|
|||
prom = torch.cat(prom_list)
|
||||
|
||||
if cfg.dataset.prompt_duration > 0 and trim_length < prom.shape[0]:
|
||||
prom = trim(prom)
|
||||
prom = trim_random( prom, trim_length )
|
||||
|
||||
return prom
|
||||
|
||||
@cached_property
|
||||
def tasks(self):
|
||||
return ["tts"] # "ns", "sr", "tse", "cse", "nse"
|
||||
|
||||
def __getitem__(self, index):
|
||||
if cfg.dataset.sample_type == "speaker":
|
||||
spkr_name = self.spkrs[index]
|
||||
|
@ -275,34 +292,45 @@ class Dataset(_Dataset):
|
|||
resps = _load_quants(path)
|
||||
|
||||
task = random.choice(self.tasks)
|
||||
|
||||
# text-to-speech
|
||||
if task == "tts":
|
||||
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
||||
# noise-suppression
|
||||
|
||||
"""
|
||||
elif task == "ns":
|
||||
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
||||
# noise suppression || speech removal
|
||||
elif task == "ns" or task == "sr":
|
||||
# sample random noise
|
||||
noise = self.sample_noise()
|
||||
noise = extend_audio(noise, proms.shape[0])
|
||||
proms = merge_audio(proms, noise)
|
||||
# something to prepend a ns token to the beginning of proms
|
||||
elif task == "sr":
|
||||
proms = resps
|
||||
resps = self.sample_noise()
|
||||
resps = extend_audio(resps, proms.shape[0])
|
||||
# something to prepend a sr token to the beginning of proms
|
||||
# extend the noise to fill the target audio
|
||||
noise = repeat_extend_audio(noise, resps.shape[0])
|
||||
# create the input prompt by merging the target audio with the noise
|
||||
proms = merge_audio(resps, noise)
|
||||
# set the target to just be the noise if <sr>
|
||||
if task == "sr":
|
||||
resps = noise
|
||||
# prepend the task token
|
||||
proms = torch.cat( [self.get_task_token(task), proms] )
|
||||
# target speech extraction
|
||||
elif task == "tse":
|
||||
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
||||
other_speaker = self.sample_speaker(ignore=[spkr_name])
|
||||
other_proms = self.sample_prompts(other_speaker, ignore="")
|
||||
proms = merge_audio(proms, other_proms)
|
||||
# something to prepend a tse token to the beginning of proms
|
||||
"""
|
||||
# sample a random, clean, utterance for the target speaker
|
||||
clean_proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
||||
# sample a random, clean utterance from a different speaker
|
||||
other_proms = self.sample_prompts(self.sample_speaker(ignore=[spkr_name]), ignore="")
|
||||
# overlay the random speaker over the target audio
|
||||
noisy_proms = merge_audio(resps, other_proms)
|
||||
# stitch together the promps
|
||||
proms = torch.cat( [clean_proms, self.get_task_token(task), noisy_proms] )
|
||||
"""
|
||||
# speech editing would require higher quality transcription data (phoneme level/word level) unfortunately
|
||||
# as I need to get a good clean point to trim into
|
||||
"""
|
||||
# clean speech editing
|
||||
elif task == "cse":
|
||||
...
|
||||
# noisy speech editing
|
||||
elif task == "nse":
|
||||
...
|
||||
"""
|
||||
|
||||
|
||||
|
|
|
@ -128,7 +128,7 @@ def _replace_file_extension(path, suffix):
|
|||
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode(wav: Tensor, sr: int, device="cuda"):
|
||||
def encode(wav: Tensor, sr: int = 24_000, device="cuda"):
|
||||
"""
|
||||
Args:
|
||||
wav: (t)
|
||||
|
@ -177,6 +177,36 @@ def encode_from_file(path, device="cuda"):
|
|||
|
||||
return qnt
|
||||
|
||||
# Helper Functions
|
||||
|
||||
# trims a random piece of audio, up to `target`
|
||||
def trim_random( qnt, target ):
|
||||
length = qnt.shape[0]
|
||||
start = int(length * random.random())
|
||||
end = start + target
|
||||
if end >= length:
|
||||
start = length - target
|
||||
end = length
|
||||
|
||||
return qnt[start:end]
|
||||
|
||||
# repeats the audio to fit the target size
|
||||
def repeat_extend_audio( qnt, target ):
|
||||
pieces = []
|
||||
length = 0
|
||||
while length < target:
|
||||
pieces.append(qnt)
|
||||
length += qnt.shape[0]
|
||||
|
||||
return trim_random(torch.cat(pieces), target)
|
||||
|
||||
# merges two quantized audios together
|
||||
# I don't know if this works
|
||||
def merge_audio( *args, device="cpu" ):
|
||||
qnts = [*args]
|
||||
decoded = [ decode_to_wave(qnt, device=device)[0] for qnt in qnts ]
|
||||
combined = sum(decoded) / len(decoded)
|
||||
return encode(combined, 24_000, device="cpu")
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
|
|
@ -5,6 +5,7 @@ import soundfile
|
|||
from einops import rearrange
|
||||
|
||||
from .emb import g2p, qnt
|
||||
from .emb.qnt import trim_random
|
||||
from .utils import to_device
|
||||
|
||||
from .config import cfg
|
||||
|
@ -12,17 +13,6 @@ from .models import get_models
|
|||
from .train import load_engines
|
||||
from .data import get_phone_symmap
|
||||
|
||||
import random
|
||||
|
||||
def trim( qnt, trim_length ):
|
||||
length = qnt.shape[0]
|
||||
start = int(length * random.random())
|
||||
end = start + trim_length
|
||||
if end >= length:
|
||||
start = length - trim_length
|
||||
end = length
|
||||
return qnt[start:end]
|
||||
|
||||
class TTS():
|
||||
def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device="cuda" ):
|
||||
self.loading = True
|
||||
|
@ -91,7 +81,7 @@ class TTS():
|
|||
enc = qnt.encode_from_file( path )
|
||||
res = enc[0].t().to(torch.int16)
|
||||
if trim:
|
||||
res = trim( res, int( 75 * cfg.dataset.duration_range[1] ) )
|
||||
res = trim_random( res, int( 75 * cfg.dataset.duration_range[1] ) )
|
||||
return res
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user