tested the training preparation for tasks ns, sr, and tse (I don't expect it to go well with only 2 RVQ bins)
commit 77292c42f9
parent bbb0563b3d
@@ -111,6 +111,7 @@ class _Config:
 class Dataset:
     training: list[Path] = field(default_factory=lambda: [])
     validation: list[Path] = field(default_factory=lambda: [])
+    noise: list[Path] = field(default_factory=lambda: [])

     temp: list[Path] = field(default_factory=lambda: [])

@@ -393,6 +394,7 @@ class Trainer:

 @dataclass()
 class Inference:
+    normalize: bool = False # do NOT enable this unless you know exactly what you're doing
     use_vocos: bool = True

 @dataclass()
@@ -473,6 +475,8 @@ try:
     if not cfg.dataset.use_hdf5:
         cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
         cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
+
+        cfg.dataset.noise = [ Path(dir) for dir in cfg.dataset.noise ]
 except Exception as e:
     pass

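Taken together with the sample_noise() implementation added further down in vall_e/data.py, the new noise field expects directories of already-quantized clips: the sampler only globs for *.qnt.pt files, so raw .wav noise has to go through the quantizer first. A small sanity-check sketch (the directory name here is only an example, not a path the repo defines):

    from pathlib import Path

    noise_dir = Path("./data/noise")  # hypothetical entry for cfg.dataset.noise
    paths = [*noise_dir.rglob("*.qnt.pt")]

    # sample_noise() picks uniformly from these, so an empty list here means the
    # ns/sr tasks would have nothing to mix in
    print(f"{len(paths)} quantized noise clips found under {noise_dir}")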
vall_e/data.py (182 changed lines)
@@ -10,7 +10,7 @@ import random
 import torch

 from .config import cfg
-from .emb.qnt import trim_random, repeat_extend_audio, merge_audio
+from .emb.qnt import trim_random, repeat_extend_audio, merge_audio, decode_to_file

 from collections import defaultdict
 from functools import cache, cached_property
@@ -77,7 +77,6 @@ def _get_phones(path, lang_marker="en"):
     split = content.split(" ")
     return [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]

-
 def _interleaved_reorder(l, fn):
     groups = defaultdict(list)
     for e in l:
@@ -216,11 +215,17 @@ class Dataset(_Dataset):
     def _get_task_symmap(self):
         return get_task_symmap()

-    def get_task_token( token ):
-        return torch.Tensor([[ self.tasks_symmap[f'<{token}>'] for _ in range(len(cfg.models.prom_levels)) ]], dtype=torch.int16)
+    def get_task_token( self, token ):
+        return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(cfg.models.prom_levels) ]]).to(dtype=torch.int16)

     def sample_noise(self):
-        ...
+        paths = []
+        print(cfg.dataset.noise)
+        for data_dir in cfg.dataset.noise:
+            paths.extend(data_dir.rglob("*.qnt.pt"))
+
+        path = random.choice(paths)
+        return _load_quants(path)

     def sample_speakers(self, ignore=[]):
         choices = set(self.spkrs) - set(ignore)
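The get_task_token() rewrite fixes several problems at once: the signature was missing self, it referenced tasks_symmap rather than task_symmap, and torch.Tensor(...) takes no dtype keyword, so the old call would raise; the cast now happens through .to(dtype=...). A minimal, self-contained sketch of the shape logic (the symmap entries and the 2-level prom_levels below are stand-ins, not the project's real values):

    import torch

    # hypothetical stand-ins for Dataset.task_symmap and cfg.models.prom_levels
    task_symmap = {"<ns>": 6, "<sr>": 7, "<tse>": 8}
    prom_levels = 2  # e.g. only 2 RVQ bins, as the commit message mentions

    def get_task_token(token: str) -> torch.Tensor:
        # one frame, repeated across every prompt quantization level, so it can be
        # torch.cat()'d in front of a (frames, levels) prom tensor
        row = [task_symmap[f"<{token}>"] for _ in range(prom_levels)]
        return torch.tensor([row], dtype=torch.int16)

    print(get_task_token("ns").shape)  # torch.Size([1, 2])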
@@ -242,14 +247,13 @@ class Dataset(_Dataset):
         """

         # shuffle it up a bit
-        offset = random.randint(-16, 16)
-        trim_length = int(cfg.dataset.prompt_duration * 75) + offset
-        total_qnt_length = 0
+        prom_length = 0
+        trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-16, 16)
+
         for _ in range(cfg.dataset.max_prompts):
             path = random.choice(choices)
             if cfg.dataset.use_hdf5:
                 key = _get_hdf5_path(path)
-                #qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:]).to(torch.int16)
                 qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
             else:
                 qnt = _load_quants(path)
@@ -258,12 +262,9 @@ class Dataset(_Dataset):
             qnt = trim_random( qnt, trim_length )

             prom_list.append(qnt)
-            total_qnt_length += qnt.shape[0]
+            prom_length += qnt.shape[0]

-            if total_qnt_length >= trim_length:
-                break
-
-            if random.random() > cfg.dataset.random_utterance:
+            if prom_length >= trim_length or random.random() > cfg.dataset.random_utterance:
                 break

         prom = torch.cat(prom_list)
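In other words, prompt assembly now keeps appending randomly chosen utterances until either the accumulated prom_length reaches the (jittered) trim_length or a draw against cfg.dataset.random_utterance says to stop, rather than checking those two conditions on separate iterations. A stripped-down sketch of the new stop condition with toy values (75 frames per second matches the EnCodec rate used in this file; everything else is illustrative):

    import random

    prompt_duration, random_utterance, max_prompts = 3.0, 1.0, 3   # stand-ins for the cfg.dataset values
    utterance_lengths = [120, 180, 90, 240]                        # toy utterance lengths, in frames

    prom_length = 0
    trim_length = int(prompt_duration * 75) + random.randint(-16, 16)

    picked = []
    for _ in range(max_prompts):
        length = random.choice(utterance_lengths)
        picked.append(length)
        prom_length += length
        # the single combined break, as in the new loop
        if prom_length >= trim_length or random.random() > random_utterance:
            break

    print(picked, prom_length, trim_length)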
@@ -296,32 +297,47 @@ class Dataset(_Dataset):
         # text-to-speech
         if task == "tts":
             proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps

-        """
         # noise suppression || speech removal
         elif task == "ns" or task == "sr":
             # sample random noise
             noise = self.sample_noise()

             # extend the noise to fill the target audio
             noise = repeat_extend_audio(noise, resps.shape[0])
             # create the input prompt by merging the target audio with the noise
-            proms = merge_audio(resps, noise)
+            proms = merge_audio(resps, noise, scale=[1, 0.125])
             # set the target to just be the noise if <sr>
             if task == "sr":
                 resps = noise
             # prepend the task token
             proms = torch.cat( [self.get_task_token(task), proms] )
+
+            # set the text prompt to empty to train without a guided text prompt
+            if random.random() < 0.5:
+                text = torch.tensor([1, 2]).to(self.text_dtype)
         # target speech extraction
         elif task == "tse":
             # sample a random, clean, utterance for the target speaker
             clean_proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
             # sample a random, clean utterance from a different speaker
-            other_proms = self.sample_prompts(self.sample_speaker(ignore=[spkr_name]), ignore="")
+            other_proms = self.sample_prompts(self.sample_speakers(ignore=[spkr_name]), ignore="")
             # overlay the random speaker over the target audio
-            noisy_proms = merge_audio(resps, other_proms)
-            # stitch together the promps
+            smallest_size = min(resps.shape[0], other_proms.shape[0])
+            if other_proms.shape[0] == smallest_size:
+                noisy_proms = merge_audio( resps[:smallest_size, :], other_proms, scale=[1, random.uniform(0.5, 0.75)] )
+                noisy_proms = torch.cat( [ noisy_proms, resps[smallest_size:, :] ] )
+            else:
+                noisy_proms = merge_audio( resps, other_proms[:smallest_size, :], scale=[1, random.uniform(0.5, 0.75)] )
+                noisy_proms = torch.cat( [ noisy_proms, other_proms[smallest_size:, :] ] )
+
+            # stitch together the promps
             proms = torch.cat( [clean_proms, self.get_task_token(task), noisy_proms] )
-        """
+
+            # set the text prompt to empty to train without a guided text prompt
+            if random.random() < 0.5:
+                text = torch.tensor([1, 2]).to(self.text_dtype)

         # speech editing would require higher quality transcription data (phoneme level/word level) unfortunately
         # as I need to get a good clean point to trim into
         """
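For the ns/sr branch the functional change is the scale=[1, 0.125] passed to merge_audio(): the decoded noise is attenuated to an eighth of its amplitude before the two signals are averaged, so the noise sits well under the target speech in the mixed prompt. A toy sketch of what that mixing amounts to on decoded audio (plain random tensors stand in for the EnCodec decode/encode round-trip, which is elided here):

    import torch

    speech = torch.randn(1, 24_000)  # 1 s of "clean" speech at 24 kHz (stand-in)
    noise  = torch.randn(1, 24_000)  # 1 s of noise, already repeat-extended to the same length

    scale = [1, 0.125]
    decoded = [speech, noise]
    for i in range(len(scale)):
        decoded[i] = decoded[i] * scale[i]

    # merge_audio() then averages the scaled signals before re-encoding them
    mixed = sum(decoded) / len(decoded)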
@@ -332,6 +348,29 @@ class Dataset(_Dataset):
         elif task == "nse":
             ...
         """

+        """
+        # emulate SVC
+        # takes in an utterance of the target speaker, a target utterenace as a reference clip as the input prompt
+        # targets an utterance of the target speaker with the same tempo + pitch + etc as the reference clip
+
+        # NOTE: I do not have a clue how to go about this. I *could* dynamically generate clips through RVC here, but I imagine the penalty would be astronomical
+        # ahead-of-time dataset preparation of a shit ton of RVC clips might be the key.
+        # aside from that, I have no clue how to go about training this, as this is entirely a proof of concept task.
+        elif task == "svc":
+            # sample a random, clean utterance for the target speaker
+            proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
+            # sample a reference clip from a different speaker
+            ref_proms = self.sample_rvc(self.sample_speakers(ignore=[spkr_name]))
+            #
+            resps =
+            # stitch together the promps
+            proms = torch.cat( [proms, self.get_task_token(task), ref_proms] )
+
+            # set the text prompt to empty to train without a guided text prompt
+            if random.random() < 0.5:
+                text = torch.tensor([1, 2]).to(self.text_dtype)
+        """

         return dict(
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
parser = argparse.ArgumentParser("Save trained model to path.")
|
parser = argparse.ArgumentParser("Save trained model to path.")
|
||||||
parser.add_argument("--create-hdf5", action="store_true")
|
parser.add_argument("--task", type=str)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.create_hdf5:
|
task = args.task
|
||||||
|
|
||||||
|
if args.task == "hdf5":
|
||||||
create_dataset_hdf5()
|
create_dataset_hdf5()
|
||||||
|
elif args.task == "sample":
|
||||||
|
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
||||||
|
|
||||||
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
samples = {
|
||||||
|
"training": [ next(iter(train_dl)), next(iter(train_dl)) ],
|
||||||
|
"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
|
||||||
|
"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
|
||||||
|
}
|
||||||
|
|
||||||
samples = {
|
for k, v in samples.items():
|
||||||
"training": [ next(iter(train_dl)), next(iter(train_dl)) ],
|
for i in range(len(v)):
|
||||||
"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
|
del v[i]['proms']
|
||||||
"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
|
del v[i]['resps']
|
||||||
}
|
print(f'{k}:', v)
|
||||||
|
"""
|
||||||
|
elif args.task == "tasks":
|
||||||
|
index = 0
|
||||||
|
task = "ns"
|
||||||
|
|
||||||
for k, v in samples.items():
|
train_dataset, val_dataset = create_datasets()
|
||||||
for i in range(len(v)):
|
train_dataset.task_symmap = get_task_symmap()
|
||||||
del v[i]['proms']
|
|
||||||
del v[i]['resps']
|
|
||||||
print(f'{k}:', v)
|
|
||||||
|
|
||||||
|
if cfg.dataset.sample_type == "speaker":
|
||||||
|
spkr_name = train_dataset.spkrs[index]
|
||||||
|
spkr_id = train_dataset.spkr_symmap[spkr_name]
|
||||||
|
path = random.choice([*set(train_dataset.paths_by_spkr_name[spkr_name])])
|
||||||
|
else:
|
||||||
|
path = train_dataset.paths[index]
|
||||||
|
spkr_name = cfg.get_spkr(path)
|
||||||
|
spkr_id = train_dataset.spkr_symmap[spkr_name]
|
||||||
|
|
||||||
|
if cfg.dataset.use_hdf5:
|
||||||
|
key = _get_hdf5_path(path)
|
||||||
|
text = torch.from_numpy(cfg.hdf5[key]["text"][:]).to(train_dataset.text_dtype)
|
||||||
|
resps = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
|
||||||
|
else:
|
||||||
|
text = torch.tensor([*map(train_dataset.phone_symmap.get, _get_phones(path))]).to(train_dataset.text_dtype)
|
||||||
|
resps = _load_quants(path)
|
||||||
|
|
||||||
|
noise = None
|
||||||
|
if task == "ns" or task == "sr":
|
||||||
|
# sample random noise
|
||||||
|
noise = train_dataset.sample_noise()
|
||||||
|
|
||||||
|
decode_to_file( noise, "./.noise.wav", device="cpu" )
|
||||||
|
|
||||||
|
# extend the noise to fill the target audio
|
||||||
|
noise = repeat_extend_audio(noise, resps.shape[0])
|
||||||
|
# create the input prompt by merging the target audio with the noise
|
||||||
|
proms = merge_audio(resps, noise, scale=[1, 0.125])
|
||||||
|
# set the target to just be the noise if <sr>
|
||||||
|
if task == "sr":
|
||||||
|
resps = noise
|
||||||
|
# prepend the task token
|
||||||
|
proms = torch.cat( [train_dataset.get_task_token(task), proms] )
|
||||||
|
|
||||||
|
# set the text prompt to empty to train without a guided text prompt
|
||||||
|
if random.random() < 0.5:
|
||||||
|
text = torch.tensor([1, 2]).to(train_dataset.text_dtype)
|
||||||
|
# target speech extraction
|
||||||
|
elif task == "tse":
|
||||||
|
# sample a random, clean, utterance for the target speaker
|
||||||
|
clean_proms = train_dataset.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
||||||
|
# sample a random, clean utterance from a different speaker
|
||||||
|
other_proms = train_dataset.sample_prompts(train_dataset.sample_speakers(ignore=[spkr_name]), ignore="")
|
||||||
|
# overlay the random speaker over the target audio
|
||||||
|
|
||||||
|
smallest_size = min(resps.shape[0], other_proms.shape[0])
|
||||||
|
if other_proms.shape[0] == smallest_size:
|
||||||
|
noisy_proms = merge_audio( resps[:smallest_size, :], other_proms )
|
||||||
|
noisy_proms = torch.cat( [ noisy_proms, resps[smallest_size:, :] ] )
|
||||||
|
else:
|
||||||
|
noisy_proms = merge_audio( resps, other_proms[:smallest_size, :] )
|
||||||
|
noisy_proms = torch.cat( [ noisy_proms, other_proms[smallest_size:, :] ] )
|
||||||
|
|
||||||
|
# stitch together the promps
|
||||||
|
proms = torch.cat( [clean_proms, train_dataset.get_task_token(task), noisy_proms] )
|
||||||
|
|
||||||
|
# set the text prompt to empty to train without a guided text prompt
|
||||||
|
if random.random() < 0.5:
|
||||||
|
text = torch.tensor([1, 2]).to(train_dataset.text_dtype)
|
||||||
|
|
||||||
|
decode_to_file( proms, "./.proms.wav", device="cpu" )
|
||||||
|
decode_to_file( resps, "./.resps.wav", device="cpu" )
|
||||||
|
|
||||||
|
if noise is not None:
|
||||||
|
decode_to_file( noise, "./.noise-fill.wav", device="cpu" )
|
||||||
|
"""
|
|
@@ -37,6 +37,7 @@ def _load_encodec_model(device="cuda"):
     model.set_target_bandwidth(bandwidth_id)
     model.bandwidth_id = bandwidth_id
     model.sample_rate = cfg.sample_rate
+    model.normalize = cfg.inference.normalize
     model.backend = "encodec"

     return model
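The new model.normalize line forwards cfg.inference.normalize straight to EnCodec's own volume-normalization switch, which rescales audio during encoding and undoes it at decode time. A minimal sketch against the public encodec package API (independent of this repo's loader helpers):

    from encodec import EncodecModel

    # load the stock 24 kHz model and pin a bandwidth, much like the loader above
    model = EncodecModel.encodec_model_24khz()
    model.set_target_bandwidth(6.0)

    # EncodecModel carries a `normalize` attribute; the new config flag simply forwards to it.
    # Per the config comment, leave it off unless you know exactly what you're doing.
    model.normalize = False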
@@ -202,25 +203,32 @@ def repeat_extend_audio( qnt, target ):

 # merges two quantized audios together
 # I don't know if this works
-def merge_audio( *args, device="cpu" ):
+def merge_audio( *args, device="cpu", scale=[] ):
     qnts = [*args]
     decoded = [ decode_to_wave(qnt, device=device)[0] for qnt in qnts ]

+    if len(scale) == len(decoded):
+        for i in range(len(scale)):
+            decoded[i] = decoded[i] * scale[i]
+
     combined = sum(decoded) / len(decoded)
-    return encode(combined, 24_000, device="cpu")
+    return encode(combined, 24_000, device="cpu")[0].t()

 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("folder", type=Path)
     parser.add_argument("--suffix", default=".wav")
+    parser.add_argument("--device", default="cuda")
     args = parser.parse_args()

+    device = args.device
     paths = [*args.folder.rglob(f"*{args.suffix}")]

     for path in tqdm(paths):
         out_path = _replace_file_extension(path, ".qnt.pt")
         if out_path.exists():
             continue
-        qnt = encode_from_file(path)
+        qnt = encode_from_file(path, device=device)
         torch.save(qnt.cpu(), out_path)

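One detail worth flagging in merge_audio(): the return value changes from the raw encode(...) output to encode(...)[0].t(). The dataset code slices prom/resp tensors as (frames, levels) (see the [:, :cfg.models.prom_levels] indexing earlier), while the encoder hands back a batched (1, levels, frames) code tensor, so the [0].t() appears to be there to drop the batch dimension and transpose into that expected layout. A shape-only sketch with a dummy tensor:

    import torch

    # dummy stand-in for the (batch, levels, frames) codes that encode(...) produces
    codes = torch.randint(0, 1024, (1, 2, 75))  # 1 s at 75 frames/s, 2 RVQ levels

    qnt = codes[0].t()   # -> (frames, levels)
    print(qnt.shape)     # torch.Size([75, 2])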