vall-e/vall_e/data.py

603 lines
19 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# todo: clean this mess up
import copy
import h5py
import json
import logging
import numpy as np
import os
import random
import torch
from .config import cfg
from collections import defaultdict
from functools import cache, cached_property
from itertools import groupby, zip_longest
from pathlib import Path
from typing import Any
from torch import Tensor
from torch.utils.data import DataLoader, Dataset as _Dataset
from torch.utils.data.distributed import DistributedSampler
from tqdm.auto import tqdm
# torch.multiprocessing.set_sharing_strategy("file_system")
_logger = logging.getLogger(__name__)
def get_phone_symmap():
    """Return the mapping from phoneme symbols to integer token ids.

    When the HDF5 backend is enabled and the HDF5 file has a ``symmap``
    entry, that entry is decoded from JSON and returned. Otherwise a freshly
    built copy of the hard-coded table below is returned.
    """
    stored_in_hdf5 = cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5
    if stored_in_hdf5:
        return json.loads( cfg.hdf5['symmap'].asstr()[()] )

    # NOTE(review): the table is reproduced verbatim. '?' and 'p' both map to
    # 7, and several keys appear as empty strings here (repeated keys in a
    # dict literal keep only the last value). The empty keys may be an
    # artefact of how this copy of the file was obtained -- verify against
    # the canonical table before relying on those ids.
    return {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
def _replace_file_extension(path, suffix):
    """Return ``path`` with everything after the first dot of its file name
    replaced by ``suffix``.

    The base name is the part of ``path.name`` before its first ``"."``, so
    multi-part extensions such as ``.qnt.pt`` are dropped as a whole.
    """
    base_name, _, _ = path.name.partition(".")
    return path.parent.joinpath(base_name).with_suffix(suffix)
def _get_hdf5_path(path):
    """Turn a filesystem path into the string used as an HDF5 key.

    The path is stringified, prefixed with ``"./"`` when it does not already
    start with it, and every occurrence of ``cfg.cfg_path`` is removed.
    """
    text = str(path)
    if not text.startswith("./"):
        text = "./" + text
    return text.replace(cfg.cfg_path, "")
def _get_quant_path(path):
    """Return the ``.qnt.pt`` file path that corresponds to ``path``."""
    quant_suffix = ".qnt.pt"
    return _replace_file_extension(path, quant_suffix)
def _get_phone_path(path):
    """Return the ``.phn.txt`` file path that corresponds to ``path``."""
    phone_suffix = ".phn.txt"
    return _replace_file_extension(path, phone_suffix)
def _load_quants(path) -> Tensor:
    """Load the quantized codes stored in the ``.qnt.pt`` file for ``path``.

    Takes element ``[0]`` of the loaded object, keeps its first
    ``cfg.models.prom_levels`` rows, transposes the result and casts it to
    ``int16``.
    """
    loaded = torch.load(_get_quant_path(path))
    # presumably rows are quantizer levels and columns are frames -- TODO confirm
    kept_levels = loaded[0][:cfg.models.prom_levels, :]
    return kept_levels.t().to(torch.int16)
@cache
def _get_phones(path, lang_marker="en"):
    """Read the phoneme file for ``path`` and return its tokens.

    The file content is split on single spaces; empty pieces (produced by
    consecutive spaces) become a ``" "`` token. The result is wrapped in
    ``"<s>"`` / ``"</s>"`` markers. ``lang_marker`` is accepted but unused.

    The result is memoized, so callers share one list per argument set and
    should not mutate it.
    """
    phone_file = _get_phone_path(path)
    with open(phone_file, "r", encoding="utf8") as handle:
        pieces = handle.read().split(" ")

    phones = ["<s>"]
    phones.extend(piece if piece else " " for piece in pieces)
    phones.append("</s>")
    return phones
def _interleaved_reorder(l, fn):
    """Yield the elements of ``l`` round-robin across groups keyed by ``fn``.

    Elements are bucketed by ``fn(element)`` (keeping their relative order),
    the buckets are visited in sorted key order, and one element is taken
    from each bucket per round until all buckets are exhausted.

    Args:
        l: iterable of elements to reorder.
        fn: callable returning a sortable, hashable group key for an element.

    Yields:
        Every element of ``l`` exactly once, including elements that are
        ``None``.
    """
    groups = defaultdict(list)
    for e in l:
        groups[fn(e)].append(e)

    # A private sentinel marks the padding of shorter groups, so that real
    # elements equal to None are not mistaken for padding and dropped.
    padding = object()
    ordered_groups = [groups[k] for k in sorted(groups)]
    for interleaved in zip_longest(*ordered_groups, fillvalue=padding):
        for value in interleaved:
            if value is not padding:
                yield value
@cache
def _validate(path, min_phones, max_phones, min_duration, max_duration):
    """Return whether the utterance at ``path`` is usable (memoized).

    HDF5 mode: the entry must exist, and its ``phonemes`` and ``duration``
    attributes must lie within the given inclusive ranges.

    File mode: both the ``.phn.txt`` and ``.qnt.pt`` files must exist, the
    transcription must contain something other than spaces, and the phoneme
    count (including the ``<s>`` / ``</s>`` markers) must lie within
    ``[min_phones, max_phones]``. ``min_duration`` / ``max_duration`` are not
    checked in file mode.
    """
    if cfg.dataset.use_hdf5:
        key = _get_hdf5_path(path)
        if key not in cfg.hdf5:
            return False

        phones = cfg.hdf5[key].attrs['phonemes']
        duration = cfg.hdf5[key].attrs['duration']
        if phones < min_phones or phones > max_phones:
            return False
        if duration < min_duration or duration > max_duration:
            return False
        return True

    if not os.path.exists(_get_phone_path(path)) or not os.path.exists(_get_quant_path(path)):
        return False

    phones = _get_phones(path)
    # _get_phones always wraps the content in "<s>" ... "</s>", so the
    # emptiness checks must look at what lies between the markers; checking
    # the full list could never reject anything.
    spoken = set(phones[1:-1])
    if not spoken or spoken == {" "}:
        return False

    if len(phones) < min_phones or len(phones) > max_phones:
        return False
    return True
class Dataset(_Dataset):
    """Map-style dataset over utterance paths.

    Each item is a dict with the phoneme token tensor (``text``), an acoustic
    prompt (``proms``) and the codes of the utterance itself (``resps``),
    read either from ``cfg.hdf5`` or from the ``.phn.txt`` / ``.qnt.pt``
    files derived from the utterance path.
    """

    def __init__(
        self,
        paths,
        phone_symmap=None,
        spkr_symmap=None,
        min_phones=cfg.dataset.phones_range[0],
        max_phones=cfg.dataset.phones_range[1],
        min_duration=cfg.dataset.duration_range[0],
        max_duration=cfg.dataset.duration_range[1],
        training=False,
        extra_paths_by_spkr_name: "dict[str, list] | None" = None,
    ):
        """Filter ``paths`` and build the speaker/phoneme lookup tables.

        Args:
            paths: utterance paths to draw from.
            phone_symmap: phoneme -> id table; defaults to ``get_phone_symmap()``.
            spkr_symmap: speaker name -> id table; defaults to an enumeration
                of the sorted speaker names found in the paths.
            min_phones, max_phones, min_duration, max_duration: bounds passed
                to ``_validate`` when ``cfg.dataset.validate`` is set.
            training: stored on the instance; also makes an empty path list
                an error.
            extra_paths_by_spkr_name: additional per-speaker paths that are
                added to the prompt pool. ``None`` means no extras.

        Raises:
            ValueError: if ``training`` is true and no path survives filtering.
        """
        super().__init__()
        self._head = None
        self.min_phones = min_phones
        self.max_phones = max_phones
        self.min_duration = min_duration
        self.max_duration = max_duration
        self.sampler = None

        if cfg.dataset.validate:
            self.paths = [
                path for path in paths
                if _validate(path, self.min_phones, self.max_phones, self.min_duration, self.max_duration)
            ]
        else:
            self.paths = paths

        # NOTE: the statement order below matters. `spkrs` is a cached
        # property derived from self.paths, so it is frozen at first access.
        self.spkr_symmap = spkr_symmap or self._get_spkr_symmap()
        self.phone_symmap = phone_symmap or self._get_phone_symmap()
        self.training = training

        # assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
        self.text_dtype = torch.uint8 if len(self.phone_symmap) < 256 else torch.int16

        self.paths_by_spkr_name = self._get_paths_by_spkr_name(extra_paths_by_spkr_name)

        if cfg.dataset.validate:
            # Keep only utterances whose speaker has at least one other
            # utterance available in the prompt pool.
            self.paths = [
                p for p in self.paths if len(self.paths_by_spkr_name[cfg.get_spkr(p)]) > 1
            ]

        if cfg.dataset.sample_type == "path":
            self.paths = [*_interleaved_reorder(self.paths, cfg.get_spkr)]

        if len(self.paths) == 0 and training:
            raise ValueError("No valid path is found for training.")

        # Total and per-speaker-id duration; only populated in HDF5 mode,
        # where the duration is stored as an attribute of each entry.
        self.duration = 0
        self.durations = {}
        if cfg.dataset.use_hdf5:
            for path in self.paths:
                key = _get_hdf5_path(path)
                spkr_name = cfg.get_spkr(path)
                spkr_id = self.spkr_symmap[spkr_name]
                duration = cfg.hdf5[key].attrs['duration']

                self.duration += duration
                if spkr_id not in self.durations:
                    self.durations[spkr_id] = duration
                else:
                    self.durations[spkr_id] += duration

    def _get_paths_by_spkr_name(self, extra_paths_by_spkr_name: "dict[str, list] | None"):
        """Group ``self.paths`` by speaker name and append any extra paths."""
        ret = defaultdict(list)
        for path in self.paths:
            ret[cfg.get_spkr(path)].append(path)
        for k, v in (extra_paths_by_spkr_name or {}).items():
            ret[k].extend(v)
        # Return a plain dict so that unknown speakers raise KeyError.
        return {**ret}

    @cached_property
    def phones(self):
        """Sorted list of every distinct phoneme found across ``self.paths``."""
        return sorted(set().union(*[_get_phones(path) for path in self.paths]))

    def _get_phone_symmap(self):
        """Return the default phoneme -> id table."""
        return get_phone_symmap()

    @cached_property
    def spkrs(self):
        """Sorted list of distinct speaker names in ``self.paths`` (cached)."""
        return sorted({cfg.get_spkr(path) for path in self.paths})

    def _get_spkr_symmap(self):
        """Enumerate ``self.spkrs`` into a speaker name -> id table."""
        return {s: i for i, s in enumerate(self.spkrs)}

    def sample_speakers(self, ignore=()):
        """Return a random speaker name that is not listed in ``ignore``."""
        choices = set(self.spkrs) - set(ignore)
        return random.choice([*choices])

    def sample_prompts(self, spkr_name, ignore):
        """Assemble an acoustic prompt from utterances of ``spkr_name``.

        Utterances are drawn at random from the speaker's pool, excluding
        ``ignore`` unless it is the only one, concatenated, and trimmed to
        roughly ``cfg.dataset.prompt_duration`` when that value is positive.

        Args:
            spkr_name: key into ``self.paths_by_spkr_name``.
            ignore: path to exclude from the pool (normally the target
                utterance).

        Returns:
            The concatenated (and possibly trimmed) prompt tensor.
        """
        prom_list = []

        choices = set(self.paths_by_spkr_name[spkr_name]) - {ignore}
        choices = [*choices]

        # no other utterances, it'd make more sense to prune speakers with only one utterance in the validation step
        if len(choices) == 0:
            # Fall back to the full pool (which includes `ignore`) rather than failing.
            choices = [*set(self.paths_by_spkr_name[spkr_name])]

        # shuffle it up a bit: jitter the target length by up to 16 frames
        offset = random.randint(-16, 16)
        # presumably 75 is the codec frame rate in frames per second -- TODO confirm
        trim_length = int(cfg.dataset.prompt_duration * 75) + offset

        def trim( qnt ):
            """Cut a random window of ``trim_length`` frames out of ``qnt``."""
            length = qnt.shape[0]
            start = int(length * random.random())
            end = start + trim_length
            if end >= length:
                # Window would run past the end: take the last trim_length frames.
                start = length - trim_length
                end = length
            return qnt[start:end]

        total_qnt_length = 0
        for _ in range(cfg.dataset.max_prompts):
            path = random.choice(choices)
            if cfg.dataset.use_hdf5:
                key = _get_hdf5_path(path)
                qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
            else:
                qnt = _load_quants(path)

            if cfg.dataset.prompt_duration > 0 and trim_length < qnt.shape[0]:
                qnt = trim(qnt)

            prom_list.append(qnt)
            total_qnt_length += qnt.shape[0]

            # Stop once enough frames were collected, or at random.
            if total_qnt_length >= trim_length:
                break
            if random.random() > cfg.dataset.random_utterance:
                break

        prom = torch.cat(prom_list)

        if cfg.dataset.prompt_duration > 0 and trim_length < prom.shape[0]:
            prom = trim(prom)

        return prom

    @cached_property
    def tasks(self):
        """Tasks an item may be sampled for; only text-to-speech is enabled."""
        return ["tts"] # "ns", "sr", "tse", "cse", "nse"

    def __getitem__(self, index):
        """Return one sample as a dict.

        With ``cfg.dataset.sample_type == "speaker"`` the index selects a
        speaker and a random utterance of that speaker is used; otherwise the
        index selects a path directly.

        Returns:
            dict with keys ``index``, ``path``, ``spkr_name``, ``spkr_id``,
            ``task``, ``text``, ``proms`` and ``resps``.

        Raises:
            ValueError: if the sampled task is not supported.
        """
        if cfg.dataset.sample_type == "speaker":
            spkr_name = self.spkrs[index]
            spkr_id = self.spkr_symmap[spkr_name]
            path = random.choice([*set(self.paths_by_spkr_name[spkr_name])])
        else:
            path = self.paths[index]
            spkr_name = cfg.get_spkr(path)
            spkr_id = self.spkr_symmap[spkr_name]

        if cfg.dataset.use_hdf5:
            key = _get_hdf5_path(path)
            text = torch.from_numpy(cfg.hdf5[key]["text"][:]).to(self.text_dtype)
            resps = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
        else:
            text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype)
            resps = _load_quants(path)

        task = random.choice(self.tasks)

        # text-to-speech
        if task == "tts":
            # With probability cfg.dataset.random_utterance the prompt comes
            # from other utterances of the speaker; otherwise the target
            # itself is used as the prompt.
            proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
        else:
            # Planned but not implemented: noise suppression ("ns"), speech
            # removal ("sr"), target speaker extraction ("tse") and speech
            # editing ("cse", "nse"). The editing tasks would need
            # phoneme/word-level transcription timing to find clean cut points.
            raise ValueError(f"Unsupported task: {task}")

        return dict(
            index=index,
            path=path,
            spkr_name=spkr_name,
            spkr_id=spkr_id,
            task=task,
            text=text,
            proms=proms,
            resps=resps,
        )

    def head_(self, n):
        """Limit the reported length of the dataset to ``n`` items."""
        self._head = n

    def training_(self, value):
        """Set the ``training`` flag."""
        self.training = value

    def __len__(self):
        """Number of speakers or paths (per ``sample_type``), capped by ``head_``."""
        if cfg.dataset.sample_type == "speaker":
            return min(len(self.spkrs), self._head or len(self.spkrs))
        return min(len(self.paths), self._head or len(self.paths))

    def pin_memory(self):
        """Pin the ``text`` / ``proms`` / ``resps`` / ``resp`` tensors.

        NOTE(review): nothing in this class assigns these attributes, so
        calling this on a Dataset instance would raise AttributeError.
        Presumably left over from a batch-object implementation -- confirm
        before relying on it.
        """
        self.text = self.text.pin_memory()
        self.proms = self.proms.pin_memory()
        self.resps = self.resps.pin_memory()
        self.resp = self.resp.pin_memory()
        return self
def collate_fn(samples: list[dict]):
    """Turn a list of sample dicts into a dict of lists.

    The keys are taken from the first sample; every sample must provide all
    of them. Values are collected as-is, without stacking or padding.

    Args:
        samples: the per-item dicts of one batch.

    Returns:
        ``{key: [sample[key] for each sample]}``, or ``{}`` for an empty batch.
    """
    if not samples:
        # No first sample to take the keys from.
        return {}
    batch: dict[str, Any] = {k: [s[k] for s in samples] for k in samples[0]}
    return batch
def _seed_worker(worker_id):
    """Seed Python's and NumPy's global RNGs from torch's initial seed.

    Used as the DataLoader ``worker_init_fn``. ``worker_id`` is not used;
    the seed is ``torch.initial_seed()`` reduced modulo 2**32.
    """
    seed = torch.initial_seed() % (1 << 32)
    random.seed(seed)
    np.random.seed(seed)
def _create_dataloader(dataset, training):
    """Wrap ``dataset`` in a DataLoader configured from ``cfg``.

    Args:
        dataset: the dataset to load from.
        training: selects the training batch size and ``drop_last``; with
            ``cfg.distributed`` it also switches to a ``DistributedSampler``.

    Returns:
        The configured ``DataLoader``.
    """
    sampler = None
    shuffle = True

    if cfg.distributed and training:
        # A sampler and shuffle=True are mutually exclusive in DataLoader.
        sampler = DistributedSampler(dataset)
        shuffle = False

    num_workers = cfg.dataset.workers

    return DataLoader(
        dataset=dataset,
        batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
        shuffle=shuffle,
        drop_last=training,
        num_workers=num_workers,
        collate_fn=collate_fn,
        # DataLoader raises ValueError for persistent_workers=True when
        # num_workers is 0, so only enable it when workers are used.
        persistent_workers=num_workers > 0,
        pin_memory=False, # True,
        worker_init_fn=_seed_worker,
        sampler=sampler,
    )
def _load_dataset_paths():
    """Collect training and validation utterance paths from the HDF5 file.

    For every configured data directory, the ids of the entries under
    ``/<split><hdf5 path>`` are turned into ``Path`` objects. Within each
    split, paths are grouped by speaker and each speaker's paths are shuffled
    after re-seeding the global ``random`` module with 0.

    Returns:
        ``(training_paths, validation_paths)``.
    """
    # Unused, kept so that cfg.hdf5 is still accessed up front as before.
    hf = cfg.hdf5

    found = {"training": [], "validation": []}
    result = {"training": [], "validation": []}

    def collect(data_dir, split="training"):
        key = f"/{split}{_get_hdf5_path(data_dir)}"
        if key in cfg.hdf5:
            entries = [f"{key}/{child.attrs['id']}" for child in cfg.hdf5[key].values()]
            found[split].extend(entries)

    for data_dir in cfg.dataset.training:
        collect(data_dir, "training")
    for data_dir in cfg.dataset.validation:
        collect(data_dir, "validation")

    for split, raw in found.items():
        if not raw:
            continue

        by_speaker = sorted([(cfg.get_spkr(p), p) for p in map(Path, raw)])
        for _, members in groupby(by_speaker, lambda pair: pair[0]):
            speaker_paths = [p for _, p in members]
            speaker_paths.sort()
            # Re-seeding for every speaker makes each shuffle reproducible
            # (and resets the global RNG state as a side effect).
            random.seed(0)
            random.shuffle(speaker_paths)
            result[split].extend(speaker_paths)

    return result["training"], result["validation"]
# to-do: mirror the hdf5-based load function
def _load_train_val_paths():
    """Find ``*.qnt.pt`` files for the training and validation splits.

    Each configured directory is searched recursively. When no validation
    file is found, the first training path is used as the validation set.

    Returns:
        ``(train_paths, val_paths)``, both sorted.

    Raises:
        RuntimeError: if no training file is found.
    """

    def _collect(data_dirs):
        # Gather the files of one split into a list of their own. The
        # original reused a single `paths` variable for both splits, so the
        # last training speaker's files leaked into the validation set.
        found = []
        for data_dir in data_dirs:
            found.extend(data_dir.rglob("*.qnt.pt"))

        collected = []
        pairs = sorted([(cfg.get_spkr(p), p) for p in found])
        for _, group in groupby(pairs, lambda pair: pair[0]):
            speaker_paths = sorted([p for _, p in group])
            # The result is sorted again by the caller, so this shuffle does
            # not affect the returned order; it is kept so that the global
            # RNG is still re-seeded here as before.
            random.seed(0)
            random.shuffle(speaker_paths)
            collected.extend(speaker_paths)
        return collected

    train_paths = sorted(_collect(cfg.dataset.training))
    val_paths = sorted(_collect(cfg.dataset.validation))

    if len(train_paths) == 0:
        raise RuntimeError("Failed to find any .qnt.pt file in specified training dataset.")

    # to-do: allow setting aside a fixed portion of the training dataset for validation
    # something like the last X percent of each speaker is set aside
    if len(val_paths) == 0:
        val_paths = [ train_paths[0] ]

    return train_paths, val_paths
@cfg.diskcache()
def create_datasets():
    """Build the training and validation datasets.

    Paths come from the HDF5 file or from the filesystem depending on
    ``cfg.dataset.use_hdf5``. The validation dataset reuses the training
    phoneme table and is capped at ``cfg.evaluation.size`` items.

    Returns:
        ``(train_dataset, val_dataset)``.
    """
    if cfg.dataset.use_hdf5:
        train_paths, val_paths = _load_dataset_paths()
    else:
        train_paths, val_paths = _load_train_val_paths()

    train_dataset = Dataset(train_paths, training=True)

    # The speaker table and the training prompt pool are deliberately not shared:
    #   spkr_symmap=train_dataset.spkr_symmap,
    #   extra_paths_by_spkr_name=train_dataset.paths_by_spkr_name,
    val_dataset = Dataset(val_paths, train_dataset.phone_symmap)
    val_dataset.head_(cfg.evaluation.size)

    return train_dataset, val_dataset
def create_train_val_dataloader():
    """Create the train, sub-train and validation dataloaders.

    The sub-train dataset is a deep copy of the training dataset; with
    path-based sampling it is capped at ``cfg.evaluation.size`` items.
    Symbol tables, sample counts and durations are logged.

    Returns:
        ``(train_dl, subtrain_dl, val_dl)``.
    """
    train_dataset, val_dataset = create_datasets()

    subtrain_dataset = copy.deepcopy(train_dataset)
    path_sampling = cfg.dataset.sample_type == "path"
    if path_sampling:
        subtrain_dataset.head_(cfg.evaluation.size)

    train_loader = _create_dataloader(train_dataset, training=True)
    val_loader = _create_dataloader(val_dataset, training=False)
    subtrain_loader = _create_dataloader(subtrain_dataset, training=False)

    _logger.info(str(train_dataset.phone_symmap))
    _logger.info(str(train_dataset.spkr_symmap))

    _logger.info(f"#samples (train): {len(train_dataset)}.")
    _logger.info(f"#samples (val): {len(val_dataset)}.")
    _logger.info(f"#samples (subtrain): {len(subtrain_dataset)}.")

    # Per-speaker durations are available but not logged:
    #   _logger.info(f"#durations (train): {str(train_dataset.durations)}.")
    #   _logger.info(f"#durations (val): {str(val_dataset.durations)}.")
    #   _logger.info(f"#durations (subtrain): {str(subtrain_dataset.durations)}.")

    _logger.info(f"#duration (train): {str(train_dataset.duration)}.")
    _logger.info(f"#duration (val): {str(val_dataset.duration)}.")
    _logger.info(f"#duration (subtrain): {str(subtrain_dataset.duration)}.")

    assert isinstance(subtrain_loader.dataset, Dataset)

    return train_loader, subtrain_loader, val_loader
# parse yaml to create an hdf5 file
def create_dataset_hdf5():
    """Populate ``cfg.hdf5`` from the configured dataset directories.

    For every ``<id>.qnt.pt`` / ``<id>.phn.txt`` pair in a directory, a group
    ``<split>/<name>/<id>`` is created holding the codes (``audio``), the
    phoneme ids (``text``) and metadata attributes. Groups that already exist
    are skipped. Finally the phoneme table is stored as JSON under ``symmap``
    and the file is closed.
    """
    symmap = get_phone_symmap()

    root = cfg.cfg_path
    hf = cfg.hdf5

    def add(data_dir, split="training"):
        """Add every utterance of one directory to the given split."""
        data_dir = "./" + str(data_dir)
        name = data_dir.replace(root, "")

        print( str(data_dir), name )

        if not os.path.isdir(f'{root}/{name}/'):
            return

        files = os.listdir(f'{root}/{name}/')

        # grab IDs for every file: drop the last two dot-separated parts
        # ("qnt.pt" / "phn.txt") of each file name
        ids = { ".".join(file.split(".")[:-2]) for file in files }

        for utt_id in tqdm(ids, desc=f"Processing {name}"):
            qnt_path = f'{root}/{name}/{utt_id}.qnt.pt'
            phn_path = f'{root}/{name}/{utt_id}.phn.txt'
            if not os.path.exists(qnt_path) or not os.path.exists(phn_path):
                continue

            key = f'{split}/{name}/{utt_id}'
            if key in hf:
                continue

            # Load both inputs before creating the group, so that a failed
            # read does not leave an empty group behind that later runs
            # would skip as "already processed".
            # NOTE(review): torch.load unpickles the file; only use trusted data.
            qnt = torch.load(qnt_path)[0].t()
            with open(phn_path, "r", encoding="utf8") as f:
                content = f.read()

            split_content = content.split(" ")
            phones = ["<s>"] + [ " " if not p else p for p in split_content ] + ["</s>"]

            # Register unseen symbols in order of appearance (deterministic,
            # unlike set iteration) and give them an id above every existing
            # one; len(symmap) can collide with an id already in the table.
            for s in phones:
                if s not in symmap:
                    symmap[s] = max(symmap.values(), default=0) + 1
            phn = [ symmap[s] for s in phones ]

            group = hf.create_group(key)

            # audio
            group.create_dataset('audio', data=qnt.numpy(), compression='lzf')

            # text
            group.create_dataset('text', data=phn, compression='lzf', chunks=True)

            # metadata
            group.attrs['id'] = utt_id
            group.attrs['type'] = split
            group.attrs['speaker'] = name
            # presumably 75 is the codec frame rate (frames per second) -- TODO confirm
            group.attrs['duration'] = qnt.shape[0] / 75
            group.attrs['phonemes'] = len(phn)

    # training
    for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
        add( data_dir, split="training" )

    # validation
    for data_dir in tqdm(cfg.dataset.validation, desc='Processing Validation'):
        add( data_dir, split="validation" )

    # write symmap; replace a table left by a previous run, since
    # create_dataset fails when the name already exists
    if 'symmap' in hf:
        del hf['symmap']
    hf.create_dataset('symmap', data=json.dumps(symmap))

    hf.close()
if __name__ == "__main__":
    # Optionally build the HDF5 dataset, then pull two batches from each
    # dataloader and print them without the large tensors.
    import argparse

    # The text is passed as `description`; the first positional argument of
    # ArgumentParser is `prog`, which would put it in the usage line.
    parser = argparse.ArgumentParser(
        description="Create the HDF5 dataset and print sample batches from the dataloaders."
    )
    parser.add_argument("--create-hdf5", action="store_true")
    args = parser.parse_args()

    if args.create_hdf5:
        # NOTE(review): create_dataset_hdf5() closes cfg.hdf5 when it is
        # done; confirm the dataloaders below can still read from it.
        create_dataset_hdf5()

    train_dl, subtrain_dl, val_dl = create_train_val_dataloader()

    samples = {
        "training": [ next(iter(train_dl)), next(iter(train_dl)) ],
        "evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
        "validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
    }

    for k, v in samples.items():
        for sample in v:
            # Drop the tensors that would flood the output.
            del sample['proms']
            del sample['resps']
        print(f'{k}:', v)