tested the training preparation for tasks ns, sr, and tse (I don't expect it to go well with only 2 RVQ bins)
This commit is contained in:
parent
bbb0563b3d
commit
77292c42f9
|
@ -111,6 +111,7 @@ class _Config:
|
|||
class Dataset:
|
||||
training: list[Path] = field(default_factory=lambda: [])
|
||||
validation: list[Path] = field(default_factory=lambda: [])
|
||||
noise: list[Path] = field(default_factory=lambda: [])
|
||||
|
||||
temp: list[Path] = field(default_factory=lambda: [])
|
||||
|
||||
|
@ -393,6 +394,7 @@ class Trainer:
|
|||
|
||||
@dataclass()
|
||||
class Inference:
|
||||
normalize: bool = False # do NOT enable this unless you know exactly what you're doing
|
||||
use_vocos: bool = True
|
||||
|
||||
@dataclass()
|
||||
|
@ -473,6 +475,8 @@ try:
|
|||
if not cfg.dataset.use_hdf5:
|
||||
cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
|
||||
cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
|
||||
|
||||
cfg.dataset.noise = [ Path(dir) for dir in cfg.dataset.noise ]
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
|
182
vall_e/data.py
182
vall_e/data.py
|
@ -10,7 +10,7 @@ import random
|
|||
import torch
|
||||
|
||||
from .config import cfg
|
||||
from .emb.qnt import trim_random, repeat_extend_audio, merge_audio
|
||||
from .emb.qnt import trim_random, repeat_extend_audio, merge_audio, decode_to_file
|
||||
|
||||
from collections import defaultdict
|
||||
from functools import cache, cached_property
|
||||
|
@ -77,7 +77,6 @@ def _get_phones(path, lang_marker="en"):
|
|||
split = content.split(" ")
|
||||
return [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
|
||||
|
||||
|
||||
def _interleaved_reorder(l, fn):
|
||||
groups = defaultdict(list)
|
||||
for e in l:
|
||||
|
@ -216,11 +215,17 @@ class Dataset(_Dataset):
|
|||
def _get_task_symmap(self):
|
||||
return get_task_symmap()
|
||||
|
||||
def get_task_token( token ):
|
||||
return torch.Tensor([[ self.tasks_symmap[f'<{token}>'] for _ in range(len(cfg.models.prom_levels)) ]], dtype=torch.int16)
|
||||
def get_task_token( self, token ):
|
||||
return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(cfg.models.prom_levels) ]]).to(dtype=torch.int16)
|
||||
|
||||
def sample_noise(self):
|
||||
...
|
||||
paths = []
|
||||
print(cfg.dataset.noise)
|
||||
for data_dir in cfg.dataset.noise:
|
||||
paths.extend(data_dir.rglob("*.qnt.pt"))
|
||||
|
||||
path = random.choice(paths)
|
||||
return _load_quants(path)
|
||||
|
||||
def sample_speakers(self, ignore=[]):
|
||||
choices = set(self.spkrs) - set(ignore)
|
||||
|
@ -242,14 +247,13 @@ class Dataset(_Dataset):
|
|||
"""
|
||||
|
||||
# shuffle it up a bit
|
||||
offset = random.randint(-16, 16)
|
||||
trim_length = int(cfg.dataset.prompt_duration * 75) + offset
|
||||
total_qnt_length = 0
|
||||
prom_length = 0
|
||||
trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-16, 16)
|
||||
|
||||
for _ in range(cfg.dataset.max_prompts):
|
||||
path = random.choice(choices)
|
||||
if cfg.dataset.use_hdf5:
|
||||
key = _get_hdf5_path(path)
|
||||
#qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:]).to(torch.int16)
|
||||
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
|
||||
else:
|
||||
qnt = _load_quants(path)
|
||||
|
@ -258,12 +262,9 @@ class Dataset(_Dataset):
|
|||
qnt = trim_random( qnt, trim_length )
|
||||
|
||||
prom_list.append(qnt)
|
||||
total_qnt_length += qnt.shape[0]
|
||||
prom_length += qnt.shape[0]
|
||||
|
||||
if total_qnt_length >= trim_length:
|
||||
break
|
||||
|
||||
if random.random() > cfg.dataset.random_utterance:
|
||||
if prom_length >= trim_length or random.random() > cfg.dataset.random_utterance:
|
||||
break
|
||||
|
||||
prom = torch.cat(prom_list)
|
||||
|
@ -296,32 +297,47 @@ class Dataset(_Dataset):
|
|||
# text-to-speech
|
||||
if task == "tts":
|
||||
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
||||
|
||||
"""
|
||||
# noise suppression || speech removal
|
||||
elif task == "ns" or task == "sr":
|
||||
# sample random noise
|
||||
noise = self.sample_noise()
|
||||
|
||||
# extend the noise to fill the target audio
|
||||
noise = repeat_extend_audio(noise, resps.shape[0])
|
||||
# create the input prompt by merging the target audio with the noise
|
||||
proms = merge_audio(resps, noise)
|
||||
proms = merge_audio(resps, noise, scale=[1, 0.125])
|
||||
# set the target to just be the noise if <sr>
|
||||
if task == "sr":
|
||||
resps = noise
|
||||
# prepend the task token
|
||||
proms = torch.cat( [self.get_task_token(task), proms] )
|
||||
|
||||
# set the text prompt to empty to train without a guided text prompt
|
||||
if random.random() < 0.5:
|
||||
text = torch.tensor([1, 2]).to(self.text_dtype)
|
||||
# target speech extraction
|
||||
elif task == "tse":
|
||||
# sample a random, clean, utterance for the target speaker
|
||||
clean_proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
||||
# sample a random, clean utterance from a different speaker
|
||||
other_proms = self.sample_prompts(self.sample_speaker(ignore=[spkr_name]), ignore="")
|
||||
other_proms = self.sample_prompts(self.sample_speakers(ignore=[spkr_name]), ignore="")
|
||||
# overlay the random speaker over the target audio
|
||||
noisy_proms = merge_audio(resps, other_proms)
|
||||
# stitch together the promps
|
||||
|
||||
smallest_size = min(resps.shape[0], other_proms.shape[0])
|
||||
if other_proms.shape[0] == smallest_size:
|
||||
noisy_proms = merge_audio( resps[:smallest_size, :], other_proms, scale=[1, random.uniform(0.5, 0.75)] )
|
||||
noisy_proms = torch.cat( [ noisy_proms, resps[smallest_size:, :] ] )
|
||||
else:
|
||||
noisy_proms = merge_audio( resps, other_proms[:smallest_size, :], scale=[1, random.uniform(0.5, 0.75)] )
|
||||
noisy_proms = torch.cat( [ noisy_proms, other_proms[smallest_size:, :] ] )
|
||||
|
||||
# stitch together the promps
|
||||
proms = torch.cat( [clean_proms, self.get_task_token(task), noisy_proms] )
|
||||
"""
|
||||
|
||||
# set the text prompt to empty to train without a guided text prompt
|
||||
if random.random() < 0.5:
|
||||
text = torch.tensor([1, 2]).to(self.text_dtype)
|
||||
|
||||
# speech editing would require higher quality transcription data (phoneme level/word level) unfortunately
|
||||
# as I need to get a good clean point to trim into
|
||||
"""
|
||||
|
@ -332,6 +348,29 @@ class Dataset(_Dataset):
|
|||
elif task == "nse":
|
||||
...
|
||||
"""
|
||||
|
||||
"""
|
||||
# emulate SVC
|
||||
# takes in an utterance of the target speaker, a target utterenace as a reference clip as the input prompt
|
||||
# targets an utterance of the target speaker with the same tempo + pitch + etc as the reference clip
|
||||
|
||||
# NOTE: I do not have a clue how to go about this. I *could* dynamically generate clips through RVC here, but I imagine the penalty would be astronomical
|
||||
# ahead-of-time dataset preparation of a shit ton of RVC clips might be the key.
|
||||
# aside from that, I have no clue how to go about training this, as this is entirely a proof of concept task.
|
||||
elif task == "svc":
|
||||
# sample a random, clean utterance for the target speaker
|
||||
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
||||
# sample a reference clip from a different speaker
|
||||
ref_proms = self.sample_rvc(self.sample_speakers(ignore=[spkr_name]))
|
||||
#
|
||||
resps =
|
||||
# stitch together the promps
|
||||
proms = torch.cat( [proms, self.get_task_token(task), ref_proms] )
|
||||
|
||||
# set the text prompt to empty to train without a guided text prompt
|
||||
if random.random() < 0.5:
|
||||
text = torch.tensor([1, 2]).to(self.text_dtype)
|
||||
"""
|
||||
|
||||
|
||||
return dict(
|
||||
|
@ -608,23 +647,98 @@ if __name__ == "__main__":
|
|||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser("Save trained model to path.")
|
||||
parser.add_argument("--create-hdf5", action="store_true")
|
||||
parser.add_argument("--task", type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.create_hdf5:
|
||||
task = args.task
|
||||
|
||||
if args.task == "hdf5":
|
||||
create_dataset_hdf5()
|
||||
elif args.task == "sample":
|
||||
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
||||
|
||||
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
||||
samples = {
|
||||
"training": [ next(iter(train_dl)), next(iter(train_dl)) ],
|
||||
"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
|
||||
"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
|
||||
}
|
||||
|
||||
samples = {
|
||||
"training": [ next(iter(train_dl)), next(iter(train_dl)) ],
|
||||
"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
|
||||
"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
|
||||
}
|
||||
for k, v in samples.items():
|
||||
for i in range(len(v)):
|
||||
del v[i]['proms']
|
||||
del v[i]['resps']
|
||||
print(f'{k}:', v)
|
||||
"""
|
||||
elif args.task == "tasks":
|
||||
index = 0
|
||||
task = "ns"
|
||||
|
||||
for k, v in samples.items():
|
||||
for i in range(len(v)):
|
||||
del v[i]['proms']
|
||||
del v[i]['resps']
|
||||
print(f'{k}:', v)
|
||||
train_dataset, val_dataset = create_datasets()
|
||||
train_dataset.task_symmap = get_task_symmap()
|
||||
|
||||
if cfg.dataset.sample_type == "speaker":
|
||||
spkr_name = train_dataset.spkrs[index]
|
||||
spkr_id = train_dataset.spkr_symmap[spkr_name]
|
||||
path = random.choice([*set(train_dataset.paths_by_spkr_name[spkr_name])])
|
||||
else:
|
||||
path = train_dataset.paths[index]
|
||||
spkr_name = cfg.get_spkr(path)
|
||||
spkr_id = train_dataset.spkr_symmap[spkr_name]
|
||||
|
||||
if cfg.dataset.use_hdf5:
|
||||
key = _get_hdf5_path(path)
|
||||
text = torch.from_numpy(cfg.hdf5[key]["text"][:]).to(train_dataset.text_dtype)
|
||||
resps = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
|
||||
else:
|
||||
text = torch.tensor([*map(train_dataset.phone_symmap.get, _get_phones(path))]).to(train_dataset.text_dtype)
|
||||
resps = _load_quants(path)
|
||||
|
||||
noise = None
|
||||
if task == "ns" or task == "sr":
|
||||
# sample random noise
|
||||
noise = train_dataset.sample_noise()
|
||||
|
||||
decode_to_file( noise, "./.noise.wav", device="cpu" )
|
||||
|
||||
# extend the noise to fill the target audio
|
||||
noise = repeat_extend_audio(noise, resps.shape[0])
|
||||
# create the input prompt by merging the target audio with the noise
|
||||
proms = merge_audio(resps, noise, scale=[1, 0.125])
|
||||
# set the target to just be the noise if <sr>
|
||||
if task == "sr":
|
||||
resps = noise
|
||||
# prepend the task token
|
||||
proms = torch.cat( [train_dataset.get_task_token(task), proms] )
|
||||
|
||||
# set the text prompt to empty to train without a guided text prompt
|
||||
if random.random() < 0.5:
|
||||
text = torch.tensor([1, 2]).to(train_dataset.text_dtype)
|
||||
# target speech extraction
|
||||
elif task == "tse":
|
||||
# sample a random, clean, utterance for the target speaker
|
||||
clean_proms = train_dataset.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
||||
# sample a random, clean utterance from a different speaker
|
||||
other_proms = train_dataset.sample_prompts(train_dataset.sample_speakers(ignore=[spkr_name]), ignore="")
|
||||
# overlay the random speaker over the target audio
|
||||
|
||||
smallest_size = min(resps.shape[0], other_proms.shape[0])
|
||||
if other_proms.shape[0] == smallest_size:
|
||||
noisy_proms = merge_audio( resps[:smallest_size, :], other_proms )
|
||||
noisy_proms = torch.cat( [ noisy_proms, resps[smallest_size:, :] ] )
|
||||
else:
|
||||
noisy_proms = merge_audio( resps, other_proms[:smallest_size, :] )
|
||||
noisy_proms = torch.cat( [ noisy_proms, other_proms[smallest_size:, :] ] )
|
||||
|
||||
# stitch together the promps
|
||||
proms = torch.cat( [clean_proms, train_dataset.get_task_token(task), noisy_proms] )
|
||||
|
||||
# set the text prompt to empty to train without a guided text prompt
|
||||
if random.random() < 0.5:
|
||||
text = torch.tensor([1, 2]).to(train_dataset.text_dtype)
|
||||
|
||||
decode_to_file( proms, "./.proms.wav", device="cpu" )
|
||||
decode_to_file( resps, "./.resps.wav", device="cpu" )
|
||||
|
||||
if noise is not None:
|
||||
decode_to_file( noise, "./.noise-fill.wav", device="cpu" )
|
||||
"""
|
|
@ -37,6 +37,7 @@ def _load_encodec_model(device="cuda"):
|
|||
model.set_target_bandwidth(bandwidth_id)
|
||||
model.bandwidth_id = bandwidth_id
|
||||
model.sample_rate = cfg.sample_rate
|
||||
model.normalize = cfg.inference.normalize
|
||||
model.backend = "encodec"
|
||||
|
||||
return model
|
||||
|
@ -202,25 +203,32 @@ def repeat_extend_audio( qnt, target ):
|
|||
|
||||
# merges two quantized audios together
|
||||
# I don't know if this works
|
||||
def merge_audio( *args, device="cpu" ):
|
||||
def merge_audio( *args, device="cpu", scale=[] ):
|
||||
qnts = [*args]
|
||||
decoded = [ decode_to_wave(qnt, device=device)[0] for qnt in qnts ]
|
||||
|
||||
if len(scale) == len(decoded):
|
||||
for i in range(len(scale)):
|
||||
decoded[i] = decoded[i] * scale[i]
|
||||
|
||||
combined = sum(decoded) / len(decoded)
|
||||
return encode(combined, 24_000, device="cpu")
|
||||
return encode(combined, 24_000, device="cpu")[0].t()
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("folder", type=Path)
|
||||
parser.add_argument("--suffix", default=".wav")
|
||||
parser.add_argument("--device", default="cuda")
|
||||
args = parser.parse_args()
|
||||
|
||||
device = args.device
|
||||
paths = [*args.folder.rglob(f"*{args.suffix}")]
|
||||
|
||||
for path in tqdm(paths):
|
||||
out_path = _replace_file_extension(path, ".qnt.pt")
|
||||
if out_path.exists():
|
||||
continue
|
||||
qnt = encode_from_file(path)
|
||||
qnt = encode_from_file(path, device=device)
|
||||
torch.save(qnt.cpu(), out_path)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user