tested the training preparation for tasks ns, sr, and tse (I don't expect it to go well with only 2 RVQ bins)

This commit is contained in:
mrq 2023-08-18 23:55:40 -05:00
parent bbb0563b3d
commit 77292c42f9
3 changed files with 163 additions and 37 deletions

View File

@@ -111,6 +111,7 @@ class _Config:
class Dataset:
training: list[Path] = field(default_factory=lambda: [])
validation: list[Path] = field(default_factory=lambda: [])
noise: list[Path] = field(default_factory=lambda: [])
temp: list[Path] = field(default_factory=lambda: [])
@@ -393,6 +394,7 @@ class Trainer:
@dataclass()
class Inference:
normalize: bool = False # do NOT enable this unless you know exactly what you're doing
use_vocos: bool = True
@dataclass()
@@ -473,6 +475,8 @@ try:
if not cfg.dataset.use_hdf5:
cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
cfg.dataset.noise = [ Path(dir) for dir in cfg.dataset.noise ]
except Exception as e:
pass
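For context, the new noise list mirrors training/validation: plain directory strings in the config, coerced to Path objects here, which sample_noise() later scans for pre-quantized clips (presumably produced by the same quantizer CLI touched at the bottom of this commit). A minimal sketch of the expected layout, with a made-up directory name for illustration:

import random
from pathlib import Path

# hypothetical value for cfg.dataset.noise; any directory containing *.qnt.pt files works
noise_dirs = [ Path("./data/noise") ]

# what sample_noise() (in the dataset code below) effectively does with it
paths = []
for data_dir in noise_dirs:
    paths.extend(data_dir.rglob("*.qnt.pt"))
if paths:
    print("would sample:", random.choice(paths))
else:
    print("no quantized noise clips found")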

View File

@@ -10,7 +10,7 @@ import random
import torch
from .config import cfg
from .emb.qnt import trim_random, repeat_extend_audio, merge_audio
from .emb.qnt import trim_random, repeat_extend_audio, merge_audio, decode_to_file
from collections import defaultdict
from functools import cache, cached_property
@@ -77,7 +77,6 @@ def _get_phones(path, lang_marker="en"):
split = content.split(" ")
return [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
def _interleaved_reorder(l, fn):
groups = defaultdict(list)
for e in l:
@@ -216,11 +215,17 @@ class Dataset(_Dataset):
def _get_task_symmap(self):
return get_task_symmap()
def get_task_token( token ):
return torch.Tensor([[ self.tasks_symmap[f'<{token}>'] for _ in range(len(cfg.models.prom_levels)) ]], dtype=torch.int16)
def get_task_token( self, token ):
return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(cfg.models.prom_levels) ]]).to(dtype=torch.int16)
def sample_noise(self):
...
paths = []
print(cfg.dataset.noise)
for data_dir in cfg.dataset.noise:
paths.extend(data_dir.rglob("*.qnt.pt"))
path = random.choice(paths)
return _load_quants(path)
def sample_speakers(self, ignore=[]):
choices = set(self.spkrs) - set(ignore)
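The get_task_token fix above adds the missing self, corrects the attribute name to task_symmap, and builds the tensor with a valid .to(dtype=...) call (torch.Tensor() does not accept a dtype argument). The result is one frame of the task's id repeated across every RVQ level, so it can be concatenated with (frames, prom_levels) prompt codes. A rough standalone sketch with made-up symmap ids and prom_levels=2:

import torch

# stand-ins for get_task_symmap() and cfg.models.prom_levels (values are illustrative)
task_symmap = { "<tts>": 0, "<ns>": 1, "<sr>": 2, "<tse>": 3 }
prom_levels = 2

def get_task_token( token ):
    # one frame of the task id, repeated across all RVQ levels, stored as int16 codes
    return torch.Tensor([[ task_symmap[f'<{token}>'] for _ in range(prom_levels) ]]).to(dtype=torch.int16)

print( get_task_token("ns") )        # tensor([[1, 1]], dtype=torch.int16)
print( get_task_token("ns").shape )  # torch.Size([1, 2]), concatenates with (frames, prom_levels) prompts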
@@ -242,14 +247,13 @@ class Dataset(_Dataset):
"""
# shuffle it up a bit
offset = random.randint(-16, 16)
trim_length = int(cfg.dataset.prompt_duration * 75) + offset
total_qnt_length = 0
prom_length = 0
trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-16, 16)
for _ in range(cfg.dataset.max_prompts):
path = random.choice(choices)
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
#qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:]).to(torch.int16)
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
else:
qnt = _load_quants(path)
@@ -258,12 +262,9 @@ class Dataset(_Dataset):
qnt = trim_random( qnt, trim_length )
prom_list.append(qnt)
total_qnt_length += qnt.shape[0]
prom_length += qnt.shape[0]
if total_qnt_length >= trim_length:
break
if random.random() > cfg.dataset.random_utterance:
if prom_length >= trim_length or random.random() > cfg.dataset.random_utterance:
break
prom = torch.cat(prom_list)
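For context on the numbers above: the prompt budget is cfg.dataset.prompt_duration seconds at EnCodec's ~75 frames per second (24kHz model), jittered by ±16 frames, and utterances keep getting appended until that budget is filled or a roll against cfg.dataset.random_utterance stops it early. A quick worked example with a hypothetical 3-second prompt_duration:

import random

prompt_duration = 3.0    # seconds; hypothetical config value
frames_per_sec = 75      # EnCodec's frame rate for the 24kHz model
trim_length = int(prompt_duration * frames_per_sec) + random.randint(-16, 16)
print(trim_length)       # somewhere in [209, 241] frames, i.e. roughly 2.8-3.2 seconds of prompt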
@@ -296,32 +297,47 @@ class Dataset(_Dataset):
# text-to-speech
if task == "tts":
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
"""
# noise suppression || speech removal
elif task == "ns" or task == "sr":
# sample random noise
noise = self.sample_noise()
# extend the noise to fill the target audio
noise = repeat_extend_audio(noise, resps.shape[0])
# create the input prompt by merging the target audio with the noise
proms = merge_audio(resps, noise)
proms = merge_audio(resps, noise, scale=[1, 0.125])
# set the target to just be the noise if <sr>
if task == "sr":
resps = noise
# prepend the task token
proms = torch.cat( [self.get_task_token(task), proms] )
# set the text prompt to empty to train without a guided text prompt
if random.random() < 0.5:
text = torch.tensor([1, 2]).to(self.text_dtype)
# target speech extraction
elif task == "tse":
# sample a random, clean utterance for the target speaker
clean_proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
# sample a random, clean utterance from a different speaker
other_proms = self.sample_prompts(self.sample_speaker(ignore=[spkr_name]), ignore="")
other_proms = self.sample_prompts(self.sample_speakers(ignore=[spkr_name]), ignore="")
# overlay the random speaker over the target audio
noisy_proms = merge_audio(resps, other_proms)
# stitch together the prompts
smallest_size = min(resps.shape[0], other_proms.shape[0])
if other_proms.shape[0] == smallest_size:
noisy_proms = merge_audio( resps[:smallest_size, :], other_proms, scale=[1, random.uniform(0.5, 0.75)] )
noisy_proms = torch.cat( [ noisy_proms, resps[smallest_size:, :] ] )
else:
noisy_proms = merge_audio( resps, other_proms[:smallest_size, :], scale=[1, random.uniform(0.5, 0.75)] )
noisy_proms = torch.cat( [ noisy_proms, other_proms[smallest_size:, :] ] )
# stitch together the prompts
proms = torch.cat( [clean_proms, self.get_task_token(task), noisy_proms] )
"""
# set the text prompt to empty to train without a guided text prompt
if random.random() < 0.5:
text = torch.tensor([1, 2]).to(self.text_dtype)
# speech editing would require higher quality transcription data (phoneme level/word level) unfortunately
# as I need to get a good clean point to trim into
"""
@@ -332,6 +348,29 @@ class Dataset(_Dataset):
elif task == "nse":
...
"""
"""
# emulate SVC
# takes in an utterance of the target speaker and a reference clip (a target utterance) as the input prompt
# targets an utterance of the target speaker with the same tempo, pitch, etc. as the reference clip
# NOTE: I do not have a clue how to go about this. I *could* dynamically generate clips through RVC here, but I imagine the penalty would be astronomical
# ahead-of-time dataset preparation of a shit ton of RVC clips might be the key.
# aside from that, I have no clue how to go about training this, as this is entirely a proof of concept task.
elif task == "svc":
# sample a random, clean utterance for the target speaker
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
# sample a reference clip from a different speaker
ref_proms = self.sample_rvc(self.sample_speakers(ignore=[spkr_name]))
#
resps =
# stitch together the prompts
proms = torch.cat( [proms, self.get_task_token(task), ref_proms] )
# set the text prompt to empty to train without a guided text prompt
if random.random() < 0.5:
text = torch.tensor([1, 2]).to(self.text_dtype)
"""
return dict(
@@ -608,23 +647,98 @@ if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser("Save trained model to path.")
parser.add_argument("--create-hdf5", action="store_true")
parser.add_argument("--task", type=str)
args = parser.parse_args()
if args.create_hdf5:
task = args.task
if args.task == "hdf5":
create_dataset_hdf5()
elif args.task == "sample":
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
samples = {
"training": [ next(iter(train_dl)), next(iter(train_dl)) ],
"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
}
samples = {
"training": [ next(iter(train_dl)), next(iter(train_dl)) ],
"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
}
for k, v in samples.items():
for i in range(len(v)):
del v[i]['proms']
del v[i]['resps']
print(f'{k}:', v)
"""
elif args.task == "tasks":
index = 0
task = "ns"
for k, v in samples.items():
for i in range(len(v)):
del v[i]['proms']
del v[i]['resps']
print(f'{k}:', v)
train_dataset, val_dataset = create_datasets()
train_dataset.task_symmap = get_task_symmap()
if cfg.dataset.sample_type == "speaker":
spkr_name = train_dataset.spkrs[index]
spkr_id = train_dataset.spkr_symmap[spkr_name]
path = random.choice([*set(train_dataset.paths_by_spkr_name[spkr_name])])
else:
path = train_dataset.paths[index]
spkr_name = cfg.get_spkr(path)
spkr_id = train_dataset.spkr_symmap[spkr_name]
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
text = torch.from_numpy(cfg.hdf5[key]["text"][:]).to(train_dataset.text_dtype)
resps = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
else:
text = torch.tensor([*map(train_dataset.phone_symmap.get, _get_phones(path))]).to(train_dataset.text_dtype)
resps = _load_quants(path)
noise = None
if task == "ns" or task == "sr":
# sample random noise
noise = train_dataset.sample_noise()
decode_to_file( noise, "./.noise.wav", device="cpu" )
# extend the noise to fill the target audio
noise = repeat_extend_audio(noise, resps.shape[0])
# create the input prompt by merging the target audio with the noise
proms = merge_audio(resps, noise, scale=[1, 0.125])
# set the target to just be the noise if <sr>
if task == "sr":
resps = noise
# prepend the task token
proms = torch.cat( [train_dataset.get_task_token(task), proms] )
# set the text prompt to empty to train without a guided text prompt
if random.random() < 0.5:
text = torch.tensor([1, 2]).to(train_dataset.text_dtype)
# target speech extraction
elif task == "tse":
# sample a random, clean utterance for the target speaker
clean_proms = train_dataset.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
# sample a random, clean utterance from a different speaker
other_proms = train_dataset.sample_prompts(train_dataset.sample_speakers(ignore=[spkr_name]), ignore="")
# overlay the random speaker over the target audio
smallest_size = min(resps.shape[0], other_proms.shape[0])
if other_proms.shape[0] == smallest_size:
noisy_proms = merge_audio( resps[:smallest_size, :], other_proms )
noisy_proms = torch.cat( [ noisy_proms, resps[smallest_size:, :] ] )
else:
noisy_proms = merge_audio( resps, other_proms[:smallest_size, :] )
noisy_proms = torch.cat( [ noisy_proms, other_proms[smallest_size:, :] ] )
# stitch together the prompts
proms = torch.cat( [clean_proms, train_dataset.get_task_token(task), noisy_proms] )
# set the text prompt to empty to train without a guided text prompt
if random.random() < 0.5:
text = torch.tensor([1, 2]).to(train_dataset.text_dtype)
decode_to_file( proms, "./.proms.wav", device="cpu" )
decode_to_file( resps, "./.resps.wav", device="cpu" )
if noise is not None:
decode_to_file( noise, "./.noise-fill.wav", device="cpu" )
"""

View File

@@ -37,6 +37,7 @@ def _load_encodec_model(device="cuda"):
model.set_target_bandwidth(bandwidth_id)
model.bandwidth_id = bandwidth_id
model.sample_rate = cfg.sample_rate
model.normalize = cfg.inference.normalize
model.backend = "encodec"
return model
@@ -202,25 +203,32 @@ def repeat_extend_audio( qnt, target ):
# merges two quantized audios together
# I don't know if this works
def merge_audio( *args, device="cpu" ):
def merge_audio( *args, device="cpu", scale=[] ):
qnts = [*args]
decoded = [ decode_to_wave(qnt, device=device)[0] for qnt in qnts ]
if len(scale) == len(decoded):
for i in range(len(scale)):
decoded[i] = decoded[i] * scale[i]
combined = sum(decoded) / len(decoded)
return encode(combined, 24_000, device="cpu")
return encode(combined, 24_000, device="cpu")[0].t()
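On the new scale hook itself: the ns/sr task calls merge_audio(resps, noise, scale=[1, 0.125]), so after decoding, the noise sits roughly 18 dB below the clean target before the two are averaged back together and re-encoded; the trailing [0].t() then flips encode()'s output into the (frames, levels) orientation the rest of the dataset code works in. A tiny check of that level math:

import math

scale = [1, 0.125]   # [target, noise] weights used by the ns/sr task
rel_db = 20 * math.log10(scale[1] / scale[0])
print(f"noise is ~{rel_db:.1f} dB below the target")   # ~-18.1 dB

The final division by len(decoded) lowers both signals by the same 6 dB, so it only affects the absolute level, not the mix balance.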
def main():
parser = argparse.ArgumentParser()
parser.add_argument("folder", type=Path)
parser.add_argument("--suffix", default=".wav")
parser.add_argument("--device", default="cuda")
args = parser.parse_args()
device = args.device
paths = [*args.folder.rglob(f"*{args.suffix}")]
for path in tqdm(paths):
out_path = _replace_file_extension(path, ".qnt.pt")
if out_path.exists():
continue
qnt = encode_from_file(path)
qnt = encode_from_file(path, device=device)
torch.save(qnt.cpu(), out_path)