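"""Dataset utilities: pairs quantized audio codes (*.qnt.pt) with phoneme
transcripts (*.phn.txt), groups utterances by speaker, and builds the
train/val/test DataLoaders."""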
import copy
import logging
import random
from collections import defaultdict
from functools import cache, cached_property
from itertools import groupby, zip_longest
from typing import Any

import numpy as np
import torch
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from .config import cfg
from .sampler import Sampler

torch.multiprocessing.set_sharing_strategy("file_system")

_logger = logging.getLogger(__name__)


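# Path helpers: an utterance's quantized codes (.qnt.pt) and phoneme transcript
# (.phn.txt) share the same stem and live in a per-speaker directory.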
def _replace_file_extension(path, suffix):
    return (path.parent / path.name.split(".")[0]).with_suffix(suffix)


def _get_quant_path(path):
    return _replace_file_extension(path, ".qnt.pt")


def _load_quants(path) -> Tensor:
    """
    Returns:
        quants: (t q), i.e. time frames by quantizer levels.
    """
    path = _get_quant_path(path)
    return torch.load(path)[0].t()


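# Reads the phoneme transcript and wraps it in sentence boundary tokens.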
@cache
def _get_phones(path):
    path = _replace_file_extension(path, ".phn.txt")
    with open(path, "r", encoding="utf8") as f:
        content = f.read()
    return ["<s>"] + content.split() + ["</s>"]


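# Round-robins over the groups produced by `fn` so that consecutive items come
# from different groups (e.g. different speakers) whenever possible.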
def _interleaved_reorder(l, fn):
    groups = defaultdict(list)
    for e in l:
        groups[fn(e)].append(e)
    groups = {k: groups[k] for k in sorted(groups)}
    for interleaved in zip_longest(*groups.values()):
        for value in interleaved:
            if value is not None:
                yield value


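# A transcript qualifies only if it is non-empty, not made up solely of "_",
# and its length falls within [min_phones, max_phones].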
@cache
def _validate(path, min_phones, max_phones):
    phones = _get_phones(path)
    unique_phones = list(set(phones))
    if len(unique_phones) == 0:
        return False
    if len(unique_phones) == 1 and unique_phones[0] == "_":
        return False
    if len(phones) < min_phones:
        return False
    if len(phones) > max_phones:
        return False
    return True


def _get_spkr_name(path) -> str:
    return path.parts[-2]  # spkr/*.wav


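# Dataset of (text, prompt, response) samples. In training mode, items are drawn
# through a Sampler keyed on speaker name rather than by index.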
class VALLEDatset(Dataset):
    def __init__(
        self,
        paths,
        phone_symmap=None,
        spkr_symmap=None,
        min_phones=10,
        max_phones=100,
        training=False,
        extra_paths_by_spkr_name: dict[str, list] = {},
    ):
        super().__init__()
        self._head = None
        self.min_phones = min_phones
        self.max_phones = max_phones
        self.paths = [
            path for path in paths if _validate(path, self.min_phones, self.max_phones)
        ]
        self.spkr_symmap = spkr_symmap or self._get_spkr_symmap()
        self.phone_symmap = phone_symmap or self._get_phone_symmap()
        self.training = training

        self.paths_by_spkr_name = self._get_paths_by_spkr_name(extra_paths_by_spkr_name)

        if training:
            self.sampler = Sampler(self.paths, [_get_spkr_name])
        else:
            self.sampler = None

    def _get_paths_by_spkr_name(self, extra_paths_by_spkr_name: dict[str, list]):
        ret = defaultdict(list)
        for path in self.paths:
            if _get_quant_path(path).exists():
                ret[_get_spkr_name(path)].append(path)
        for k, v in extra_paths_by_spkr_name.items():
            ret[k].extend(v)
        return {**ret}

    @cached_property
    def phones(self):
        return sorted(set().union(*[_get_phones(path) for path in self.paths]))

    def _get_phone_symmap(self):
        # Note: the phone symmap starts at 1 so that 0 can safely be used for padding.
        return {s: i for i, s in enumerate(self.phones, 1)}

    @cached_property
    def spkrs(self):
        return sorted({_get_spkr_name(path) for path in self.paths})

    def _get_spkr_symmap(self):
        return {s: i for i, s in enumerate(self.spkrs)}

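    # Builds an acoustic prompt by concatenating quantized utterances from the
    # same speaker; extra utterances are appended with probability
    # cfg.p_additional_prompt, up to 10 in total.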
    def sample_prompts(self, spkr_name):
        prom_list = []

        while len(prom_list) == 0 or (
            random.random() < cfg.p_additional_prompt and len(prom_list) < 10
        ):
            path = random.choice(self.paths_by_spkr_name[spkr_name])
            prom_list.append(_load_quants(path))

        prom = torch.cat(prom_list)

        return prom

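    # In training mode the index is ignored and a path is drawn from the sampler.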
    def __getitem__(self, index):
        if self.training:
            assert self.sampler is not None
            path = self.sampler.sample()
        else:
            path = self.paths[index]

        spkr_name = _get_spkr_name(path)
        text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))])
        proms = self.sample_prompts(spkr_name)
        resps = _load_quants(path)
        resp = resps[..., 0]

        return dict(
            path=path,
            spkr_name=spkr_name,
            text=text,
            proms=proms,
            resps=resps,
            resp=resp,
        )

    def head_(self, n):
        self._head = n

    def training_(self, value):
        self.training = value

    def interleaved_reorder_(self, fn):
        self.paths = [*_interleaved_reorder(self.paths, fn)]

    def __len__(self):
        return min(len(self.paths), self._head or len(self.paths))


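# Collates a batch by keeping each field as a plain list (no padding or stacking here).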
def collate_fn(samples: list[dict]):
    batch: dict[str, Any] = {k: [s[k] for s in samples] for k in samples[0]}
    return batch


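# Gives every DataLoader worker its own numpy/random seed derived from torch's.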
def _seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


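# Builds a DataLoader using the batch size, worker count, and collate function
# from the config; shuffling and last-batch dropping only apply during training.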
def _create_dl(dataset, training):
    return DataLoader(
        dataset=dataset,
        batch_size=cfg.batch_size if training else cfg.eval_batch_size,
        shuffle=training,
        drop_last=training,
        num_workers=cfg.nj,
        collate_fn=collate_fn,
        persistent_workers=True,
        worker_init_fn=_seed_worker,
    )


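# Collects all *.qnt.pt files under cfg.data_dirs and splits them 95/5 into
# train/val per speaker, with a fixed shuffle seed for reproducibility.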
def _load_train_val_paths():
    paths = []
    train_paths = []
    val_paths = []

    for data_dir in cfg.data_dirs:
        paths.extend(tqdm(data_dir.rglob("**/*.qnt.pt")))

    if len(paths) == 0:
        raise RuntimeError(f"Failed to find any .qnt.pt file in {cfg.data_dirs}.")

    pairs = sorted([(_get_spkr_name(p), p) for p in paths])
    del paths

    for _, group in groupby(pairs, lambda pair: pair[0]):
        paths = sorted([p for _, p in group])
        random.seed(0)
        random.shuffle(paths)
        n = round(len(paths) * 0.95)
        train_paths.extend(paths[:n])
        val_paths.extend(paths[n:])

    train_paths, val_paths = map(sorted, [train_paths, val_paths])

    return train_paths, val_paths


def _load_test_paths():
    test_paths = []
    for data_dir in cfg.test_data_dirs:
        test_paths.extend(data_dir.rglob("**/*.asr.txt"))
    test_paths = sorted(test_paths)
    return test_paths


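# Wrapped in cfg.diskcache() so the constructed datasets can be reused across runs;
# the val/test datasets share the training symbol maps and speaker-to-path index.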
@cfg.diskcache()
def create_datasets():
    train_paths, val_paths = _load_train_val_paths()
    test_paths = _load_test_paths()

    train_dataset = VALLEDatset(train_paths, training=True)

    val_dataset = VALLEDatset(
        val_paths,
        train_dataset.phone_symmap,
        train_dataset.spkr_symmap,
        extra_paths_by_spkr_name=train_dataset.paths_by_spkr_name,
    )

    val_dataset.interleaved_reorder_(_get_spkr_name)
    val_dataset.head_(200)

    test_dataset = VALLEDatset(
        test_paths,
        train_dataset.phone_symmap,
        train_dataset.spkr_symmap,
        extra_paths_by_spkr_name=train_dataset.paths_by_spkr_name,
    )

    return train_dataset, val_dataset, test_dataset


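# Returns four loaders: the full training loader, a 200-sample snapshot of the
# training set in eval mode (train200), plus the val and test loaders.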
def create_train_val_dataloader():
    train_dataset, val_dataset, test_dataset = create_datasets()

    train_dl = _create_dl(train_dataset, training=True)
    val_dl = _create_dl(val_dataset, training=False)
    test_dl = _create_dl(test_dataset, training=False)

    _logger.info(str(train_dataset.phone_symmap))
    _logger.info(str(train_dataset.spkr_symmap))

    _logger.info(f"#samples (train): {len(train_dataset)}.")
    _logger.info(f"#samples (val): {len(val_dataset)}.")
    _logger.info(f"#samples (test): {len(test_dataset)}.")

    train200_dataset = copy.deepcopy(train_dataset)
    train200_dataset.interleaved_reorder_(_get_spkr_name)
    train200_dataset.head_(200)
    train200_dataset.training_(False)
    train200_dl = _create_dl(train200_dataset, training=False)
    assert isinstance(train200_dl.dataset, VALLEDatset)

    return train_dl, train200_dl, val_dl, test_dl


if __name__ == "__main__":
    train_dl, train200_dl, val_dl, test_dl = create_train_val_dataloader()
    sample = train_dl.dataset[0]
    print(sample)