# vall-e/vall_e/data.py
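"""Data pipeline for VALL-E: datasets, symbol maps, and dataloaders.

Phoneme transcripts are read from `*.phn.txt` files and quantized audio codes
from the `*.qnt.pt` files stored alongside them.
"""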
import copy
import logging
import random
from collections import defaultdict
from functools import cache, cached_property
from itertools import groupby, zip_longest
from typing import Any
import numpy as np
import torch
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from .config import cfg
from .sampler import Sampler
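
# Share loaded tensors between dataloader workers through the file system to
# avoid "too many open files" errors when many workers are in flight.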
torch.multiprocessing.set_sharing_strategy("file_system")

_logger = logging.getLogger(__name__)


def _replace_file_extension(path, suffix):
    return (path.parent / path.name.split(".")[0]).with_suffix(suffix)


def _get_quant_path(path):
    return _replace_file_extension(path, ".qnt.pt")


def _load_quants(path) -> Tensor:
    """
    Returns:
        quants: (t q), i.e. time steps by quantizer levels.
    """
    path = _get_quant_path(path)
    return torch.load(path)[0].t()


@cache
def _get_phones(path):
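    """Load the phoneme transcript stored next to `path`, wrapped in <s>/</s>."""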
path = _replace_file_extension(path, ".phn.txt")
with open(path, "r", encoding="utf8") as f:
content = f.read()
return ["<s>"] + content.split() + ["</s>"]


def _interleaved_reorder(l, fn):
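    """Yield items of `l` round-robin across the groups induced by key `fn`."""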
groups = defaultdict(list)
for e in l:
groups[fn(e)].append(e)
groups = {k: groups[k] for k in sorted(groups)}
for interleaved in zip_longest(*groups.values()):
for value in interleaved:
if value is not None:
yield value


@cache
def _validate(path, min_phones, max_phones):
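    """Keep only utterances whose transcript is non-empty, not just "_",
    and whose phone count lies within [min_phones, max_phones]."""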
phones = _get_phones(path)
unique_phones = list(set(phones))
if len(unique_phones) == 0:
return False
if len(unique_phones) == 1 and unique_phones[0] == "_":
return False
if len(phones) < min_phones:
return False
if len(phones) > max_phones:
return False
return True


def _get_spkr_name(path) -> str:
return path.parts[-2] # spkr/*.wav


class VALLEDataset(Dataset):
def __init__(
self,
paths,
phone_symmap=None,
spkr_symmap=None,
min_phones=10,
max_phones=100,
training=False,
extra_paths_by_spkr_name: dict[str, list] = {},
):
super().__init__()
self._head = None
self.min_phones = min_phones
self.max_phones = max_phones
self.paths = [
path for path in paths if _validate(path, self.min_phones, self.max_phones)
]
self.spkr_symmap = spkr_symmap or self._get_spkr_symmap()
self.phone_symmap = phone_symmap or self._get_phone_symmap()
self.training = training
self.paths_by_spkr_name = self._get_paths_by_spkr_name(extra_paths_by_spkr_name)
if training:
self.sampler = Sampler(self.paths, [_get_spkr_name])
else:
self.sampler = None

    def _get_paths_by_spkr_name(self, extra_paths_by_spkr_name: dict[str, list]):
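        # Index utterances by speaker, keeping only those whose quantized codes
        # exist on disk; `extra_paths_by_spkr_name` lets val/test datasets draw
        # prompts from the training split's speakers.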
ret = defaultdict(list)
for path in self.paths:
if _get_quant_path(path).exists():
ret[_get_spkr_name(path)].append(path)
for k, v in extra_paths_by_spkr_name.items():
ret[k].extend(v)
return {**ret}

    @cached_property
def phones(self):
return sorted(set().union(*[_get_phones(path) for path in self.paths]))

    def _get_phone_symmap(self):
# Note that we use phone symmap starting from 1 so that we can safely pad 0.
return {s: i for i, s in enumerate(self.phones, 1)}

    @cached_property
def spkrs(self):
return sorted({_get_spkr_name(path) for path in self.paths})

    def _get_spkr_symmap(self):
return {s: i for i, s in enumerate(self.spkrs)}

    def sample_prompts(self, spkr_name):
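        """Build a prompt by concatenating quantized utterances of `spkr_name`.

        Always samples at least one utterance; with probability
        cfg.p_additional_prompt, keeps appending more, up to 10 in total.
        """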
prom_list = []
        while (
            len(prom_list) == 0
            or (random.random() < cfg.p_additional_prompt and len(prom_list) < 10)
        ):
path = random.choice(self.paths_by_spkr_name[spkr_name])
prom_list.append(_load_quants(path))
prom = torch.cat(prom_list)
return prom

    def __getitem__(self, index):
if self.training:
assert self.sampler is not None
path = self.sampler.sample()
else:
path = self.paths[index]
spkr_name = _get_spkr_name(path)
text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))])
proms = self.sample_prompts(spkr_name)
resps = _load_quants(path)
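        # `resps` carries all quantizer levels; `resp` keeps only the first.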
resp = resps[..., 0]
return dict(
path=path,
spkr_name=spkr_name,
text=text,
proms=proms,
resps=resps,
resp=resp,
)

    def head_(self, n):
self._head = n

    def training_(self, value):
self.training = value

    def interleaved_reorder_(self, fn):
self.paths = [*_interleaved_reorder(self.paths, fn)]

    def __len__(self):
return min(len(self.paths), self._head or len(self.paths))


def collate_fn(samples: list[dict]):
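    # No padding here: every key maps to a plain list of per-sample values,
    # so variable-length tensors are left for the model code to handle.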
batch: dict[str, Any] = {k: [s[k] for s in samples] for k in samples[0]}
return batch


def _seed_worker(worker_id):
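    # Derive per-worker seeds from the torch seed so the numpy/random streams
    # differ across dataloader workers instead of being forked copies.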
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
random.seed(worker_seed)


def _create_dl(dataset, training):
return DataLoader(
dataset=dataset,
batch_size=cfg.batch_size if training else cfg.eval_batch_size,
shuffle=training,
drop_last=training,
num_workers=cfg.nj,
collate_fn=collate_fn,
persistent_workers=True,
worker_init_fn=_seed_worker,
)


def _load_train_val_paths():
paths = []
train_paths = []
val_paths = []
for data_dir in cfg.data_dirs:
        paths.extend(tqdm(data_dir.rglob("*.qnt.pt")))
if len(paths) == 0:
raise RuntimeError(f"Failed to find any .qnt.pt file in {cfg.data_dirs}.")
pairs = sorted([(_get_spkr_name(p), p) for p in paths])
del paths
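
    # Deterministic 95/5 train/val split within each speaker.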
for _, group in groupby(pairs, lambda pair: pair[0]):
paths = sorted([p for _, p in group])
random.seed(0)
random.shuffle(paths)
n = round(len(paths) * 0.95)
train_paths.extend(paths[:n])
val_paths.extend(paths[n:])
train_paths, val_paths = map(sorted, [train_paths, val_paths])
return train_paths, val_paths


def _load_test_paths():
test_paths = []
for data_dir in cfg.test_data_dirs:
        test_paths.extend(data_dir.rglob("*.asr.txt"))
test_paths = sorted(test_paths)
return test_paths


@cfg.diskcache()
def create_datasets():
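    """Build the train/val/test datasets.

    Val/test reuse the training symbol maps and speaker-indexed paths so that
    prompts can be sampled from the training split.
    """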
train_paths, val_paths = _load_train_val_paths()
test_paths = _load_test_paths()
    train_dataset = VALLEDataset(train_paths, training=True)
    val_dataset = VALLEDataset(
val_paths,
train_dataset.phone_symmap,
train_dataset.spkr_symmap,
extra_paths_by_spkr_name=train_dataset.paths_by_spkr_name,
)
val_dataset.interleaved_reorder_(_get_spkr_name)
val_dataset.head_(200)
    test_dataset = VALLEDataset(
test_paths,
train_dataset.phone_symmap,
train_dataset.spkr_symmap,
extra_paths_by_spkr_name=train_dataset.paths_by_spkr_name,
)
return train_dataset, val_dataset, test_dataset


def create_train_val_dataloader():
train_dataset, val_dataset, test_dataset = create_datasets()
train_dl = _create_dl(train_dataset, training=True)
val_dl = _create_dl(val_dataset, training=False)
test_dl = _create_dl(test_dataset, training=False)
_logger.info(str(train_dataset.phone_symmap))
_logger.info(str(train_dataset.spkr_symmap))
_logger.info(f"#samples (train): {len(train_dataset)}.")
_logger.info(f"#samples (val): {len(val_dataset)}.")
_logger.info(f"#samples (test): {len(test_dataset)}.")
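
    # A small speaker-interleaved slice of the training set, run in eval mode,
    # to monitor fit on data the model has seen.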
train200_dataset = copy.deepcopy(train_dataset)
train200_dataset.interleaved_reorder_(_get_spkr_name)
train200_dataset.head_(200)
train200_dataset.training_(False)
train200_dl = _create_dl(train200_dataset, training=False)
    assert isinstance(train200_dl.dataset, VALLEDataset)
return train_dl, train200_dl, val_dl, test_dl


if __name__ == "__main__":
train_dl, train200_dl, val_dl, test_dl = create_train_val_dataloader()
sample = train_dl.dataset[0]
print(sample)
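
    # A minimal sketch of consuming one collated batch; assumes cfg points at
    # prepared data directories.
    for batch in train_dl:
        print(batch["text"][0].shape, batch["proms"][0].shape, batch["resps"][0].shape)
        break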