2023-08-02 21:53:35 +00:00
|
|
|
|
# todo: clean this mess up
|
|
|
|
|
|
|
|
|
|
import copy
|
|
|
|
|
import h5py
|
|
|
|
|
import json
|
|
|
|
|
import logging
|
|
|
|
|
import numpy as np
|
|
|
|
|
import os
|
|
|
|
|
import random
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
from .config import cfg
|
|
|
|
|
|
|
|
|
|
from collections import defaultdict
|
|
|
|
|
from functools import cache, cached_property
|
|
|
|
|
from itertools import groupby, zip_longest
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import Any
|
|
|
|
|
|
|
|
|
|
from torch import Tensor
|
|
|
|
|
from torch.utils.data import DataLoader, Dataset as _Dataset
|
2023-08-14 03:07:45 +00:00
|
|
|
|
from torch.utils.data.distributed import DistributedSampler
|
2023-08-02 21:53:35 +00:00
|
|
|
|
from tqdm.auto import tqdm
|
|
|
|
|
|
|
|
|
|
# torch.multiprocessing.set_sharing_strategy("file_system")
|
|
|
|
|
|
|
|
|
|
# Module-level logger, named after this module's import path.
_logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
def get_phone_symmap():
    """Return the phoneme -> token id table used to encode transcripts.

    If the HDF5 backend is enabled and the opened file contains a ``symmap``
    dataset, that dataset is decoded as JSON and returned. Otherwise the
    hard-coded fallback table below is returned.
    """
    stored_in_hdf5 = cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5
    if not stored_in_hdf5:
        # NOTE(review): '?' and 'p' both map to id 7 in this table. Left as-is
        # because the ids are presumably baked into existing datasets/models --
        # confirm before renumbering.
        fallback = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
        return fallback

    # The table is stored as a JSON string dataset inside the HDF5 file.
    encoded = cfg.hdf5['symmap'].asstr()[()]
    return json.loads(encoded)
|
|
|
|
|
|
|
|
|
|
def _replace_file_extension(path, suffix):
|
|
|
|
|
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
|
|
|
|
|
|
|
|
|
|
def _get_hdf5_path(path):
    """Turn a filesystem path into the key used inside the HDF5 file.

    The path is stringified, prefixed with ``./`` when it does not already
    start with it, and every occurrence of ``cfg.cfg_path`` is removed.
    """
    text = str(path)
    prefixed = text if text.startswith("./") else f'./{text}'
    return prefixed.replace(cfg.cfg_path, "")
|
|
|
|
|
|
|
|
|
|
def _get_quant_path(path):
    """Return the ``.qnt.pt`` (quantized audio) file that sits next to *path*."""
    quant_suffix = ".qnt.pt"
    return _replace_file_extension(path, quant_suffix)
|
|
|
|
|
|
|
|
|
|
def _get_phone_path(path):
    """Return the ``.phn.txt`` (phoneme transcript) file that sits next to *path*."""
    phone_suffix = ".phn.txt"
    return _replace_file_extension(path, phone_suffix)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_quants(path) -> Tensor:
    """Load the quantized audio codes stored next to *path*.

    Takes the first element of the object saved in the ``.qnt.pt`` file, keeps
    its first ``cfg.models.prom_levels`` rows, transposes it and casts it to
    ``int16``.
    """
    # presumably shaped (levels, frames) on disk, so the result is
    # (frames, levels) -- TODO confirm against the quantizer output
    stored = torch.load(_get_quant_path(path))
    codes = stored[0]
    kept_levels = codes[:cfg.models.prom_levels, :]
    return kept_levels.t().to(torch.int16)
|
2023-08-02 21:53:35 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@cache
def _get_phones(path, lang_marker="en"):
    """Read the phoneme transcript next to *path* as a list of symbols.

    The file content is split on single spaces; empty pieces (produced by
    consecutive spaces) become a literal ``" "`` symbol. The result is wrapped
    in ``<s>`` / ``</s>`` markers. ``lang_marker`` is accepted but unused.

    Results are memoised per argument tuple, so the returned list is shared
    between callers and must not be mutated.
    """
    with open(_get_phone_path(path), "r", encoding="utf8") as handle:
        pieces = handle.read().split(" ")

    symbols = ["<s>"]
    symbols.extend(piece if piece else " " for piece in pieces)
    symbols.append("</s>")
    return symbols
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _interleaved_reorder(l, fn):
|
|
|
|
|
groups = defaultdict(list)
|
|
|
|
|
for e in l:
|
|
|
|
|
groups[fn(e)].append(e)
|
|
|
|
|
groups = {k: groups[k] for k in sorted(groups)}
|
|
|
|
|
for interleaved in zip_longest(*groups.values()):
|
|
|
|
|
for value in interleaved:
|
|
|
|
|
if value is not None:
|
|
|
|
|
yield value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@cache
def _validate(path, min_phones, max_phones, min_duration, max_duration):
    """Decide whether the utterance at *path* may be used.

    With the HDF5 backend the entry must exist and its stored ``phonemes`` and
    ``duration`` attributes must lie inside the given inclusive ranges.

    With the file backend both the ``.phn.txt`` and ``.qnt.pt`` files must
    exist, the transcript must contain something other than blanks, and the
    phoneme count (including the ``<s>``/``</s>`` markers) must lie inside
    ``[min_phones, max_phones]``. The duration bounds are not checked in this
    branch.

    Results are memoised per argument tuple.
    """
    if cfg.dataset.use_hdf5:
        key = _get_hdf5_path(path)
        if key not in cfg.hdf5:
            return False

        attrs = cfg.hdf5[key].attrs
        phones = attrs['phonemes']
        duration = attrs['duration']

        if phones < min_phones or phones > max_phones:
            return False

        if duration < min_duration or duration > max_duration:
            return False

        return True

    if not os.path.exists(_get_phone_path(path)) or not os.path.exists(_get_quant_path(path)):
        return False

    phones = _get_phones(path)

    # _get_phones always wraps the transcript in <s> ... </s>, so the
    # emptiness checks have to look at what is between the two markers;
    # otherwise they can never trigger.
    spoken = set(phones[1:-1])
    if not spoken or spoken == {" "}:
        return False

    if len(phones) < min_phones or len(phones) > max_phones:
        return False

    return True
|
|
|
|
|
|
|
|
|
|
class Dataset(_Dataset):
    """Utterance dataset producing phoneme text, an acoustic prompt and target codes.

    Depending on ``cfg.dataset.sample_type`` an index addresses either a
    speaker (``"speaker"``: a random utterance of that speaker is drawn) or a
    single utterance path (anything else). Data is read from the HDF5 file
    when ``cfg.dataset.use_hdf5`` is set, otherwise from ``.phn.txt`` /
    ``.qnt.pt`` files next to each path.
    """

    def __init__(
        self,
        paths,
        phone_symmap=None,
        spkr_symmap=None,
        min_phones=cfg.dataset.phones_range[0],
        max_phones=cfg.dataset.phones_range[1],
        min_duration=cfg.dataset.duration_range[0],
        max_duration=cfg.dataset.duration_range[1],
        training=False,
        extra_paths_by_spkr_name: "dict[str, list] | None" = None,
    ):
        """Filter *paths*, build the symbol maps and gather duration statistics.

        Args:
            paths: Candidate utterance paths.
            phone_symmap: Phoneme -> id table; built via ``get_phone_symmap``
                when falsy.
            spkr_symmap: Speaker name -> id table; derived from the paths
                when falsy.
            min_phones, max_phones: Inclusive phoneme-count bounds used for
                validation.
            min_duration, max_duration: Inclusive duration bounds used for
                validation.
            training: Whether this dataset is used for training.
            extra_paths_by_spkr_name: Additional per-speaker paths that may be
                drawn as prompts. ``None`` means no extras.

        Raises:
            ValueError: If no path survives filtering while ``training`` is set.
        """
        super().__init__()
        self._head = None
        self.min_phones = min_phones
        self.max_phones = max_phones
        self.min_duration = min_duration
        self.max_duration = max_duration
        self.sampler = None

        if cfg.dataset.validate:
            self.paths = [
                path for path in paths
                if _validate(path, self.min_phones, self.max_phones, self.min_duration, self.max_duration)
            ]
        else:
            self.paths = paths

        # NOTE(review): when spkr_symmap is not supplied, `self.spkrs` is
        # computed (and cached) here, i.e. before the single-utterance
        # speakers are pruned below -- confirm this is intended.
        self.spkr_symmap = spkr_symmap or self._get_spkr_symmap()
        self.phone_symmap = phone_symmap or self._get_phone_symmap()
        self.training = training

        # assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
        self.text_dtype = torch.uint8 if len(self.phone_symmap) < 256 else torch.int16

        # A None default avoids sharing one mutable dict between instances.
        self.paths_by_spkr_name = self._get_paths_by_spkr_name(extra_paths_by_spkr_name or {})

        if cfg.dataset.validate:
            # Drop utterances whose speaker has no second utterance available.
            self.paths = [
                p for p in self.paths if len(self.paths_by_spkr_name[cfg.get_spkr(p)]) > 1
            ]

        if cfg.dataset.sample_type == "path":
            # Interleave speakers so consecutive indices cycle through them.
            self.paths = [*_interleaved_reorder(self.paths, cfg.get_spkr)]

        if len(self.paths) == 0 and training:
            raise ValueError("No valid path is found for training.")

        # Total duration and per-speaker-id durations; only filled from the
        # attributes stored in the HDF5 file.
        self.duration = 0
        self.durations = {}
        if cfg.dataset.use_hdf5:
            for path in self.paths:
                key = _get_hdf5_path(path)
                spkr_name = cfg.get_spkr(path)
                spkr_id = self.spkr_symmap[spkr_name]
                duration = cfg.hdf5[key].attrs['duration']

                self.duration += duration

                if spkr_id not in self.durations:
                    self.durations[spkr_id] = duration
                else:
                    self.durations[spkr_id] += duration

    def _get_paths_by_spkr_name(self, extra_paths_by_spkr_name: dict[str, list]):
        """Group ``self.paths`` by speaker name and append the extra paths."""
        ret = defaultdict(list)
        for path in self.paths:
            ret[cfg.get_spkr(path)].append(path)
        for k, v in extra_paths_by_spkr_name.items():
            ret[k].extend(v)
        # Return a plain dict so that unknown speakers raise KeyError.
        return {**ret}

    @cached_property
    def phones(self):
        """Sorted list of every distinct phoneme symbol found in ``self.paths``."""
        return sorted(set().union(*[_get_phones(path) for path in self.paths]))

    def _get_phone_symmap(self):
        """Return the shared phoneme -> id table."""
        return get_phone_symmap()

    @cached_property
    def spkrs(self):
        """Sorted list of the distinct speaker names in ``self.paths`` (cached on first use)."""
        return sorted({cfg.get_spkr(path) for path in self.paths})

    def _get_spkr_symmap(self):
        """Map each speaker name to its index in ``self.spkrs``."""
        return {s: i for i, s in enumerate(self.spkrs)}

    def sample_speakers(self, ignore=()):
        """Return one random speaker name that is not listed in *ignore*."""
        choices = set(self.spkrs) - set(ignore)
        return random.choice([*choices])

    def sample_prompts(self, spkr_name, ignore):
        """Build an acoustic prompt from random utterances of *spkr_name*.

        Up to ``cfg.dataset.max_prompts`` utterances other than *ignore* are
        drawn (with replacement), optionally trimmed, and concatenated. When
        the speaker has no other utterance, *ignore* itself may be drawn.

        Returns:
            The concatenated ``int16`` code tensor.
        """
        prom_list = []

        choices = set(self.paths_by_spkr_name[spkr_name]) - {ignore}
        choices = [*choices]

        # no other utterances, it'd make more sense to prune speakers with only one utterance in the validation step
        if len(choices) == 0:
            choices = [*set(self.paths_by_spkr_name[spkr_name])]
            # Previously this case raised:
            # raise ValueError(
            #     f"Failed to find another different utterance for {spkr_name}."
            # )

        # shuffle it up a bit
        offset = random.randint(-16, 16)
        # presumably 75 code frames per second (the HDF5 writer stores
        # duration as frames / 75) -- TODO confirm
        # NOTE(review): for very small prompt_duration values this can be <= 0.
        trim_length = int(cfg.dataset.prompt_duration * 75) + offset

        def trim(qnt):
            """Cut a random window of ``trim_length`` frames out of *qnt*."""
            length = qnt.shape[0]
            start = int(length * random.random())
            end = start + trim_length
            if end >= length:
                # Window would run past the end: take the last trim_length frames.
                start = length - trim_length
                end = length

            return qnt[start:end]

        total_qnt_length = 0
        for _ in range(cfg.dataset.max_prompts):
            path = random.choice(choices)
            if cfg.dataset.use_hdf5:
                key = _get_hdf5_path(path)
                #qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:]).to(torch.int16)
                qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
            else:
                qnt = _load_quants(path)

            if cfg.dataset.prompt_duration > 0 and trim_length < qnt.shape[0]:
                qnt = trim(qnt)

            prom_list.append(qnt)
            total_qnt_length += qnt.shape[0]

            # Enough material gathered.
            if total_qnt_length >= trim_length:
                break

            # Otherwise keep adding utterances with probability random_utterance.
            if random.random() > cfg.dataset.random_utterance:
                break

        prom = torch.cat(prom_list)

        if cfg.dataset.prompt_duration > 0 and trim_length < prom.shape[0]:
            prom = trim(prom)

        return prom

    @cached_property
    def tasks(self):
        """Task names a sample may be drawn for."""
        return ["tts"]  # "ns", "sr", "tse", "cse", "nse"

    def __getitem__(self, index):
        """Return one sample as a dict.

        Keys: ``index``, ``path``, ``spkr_name``, ``spkr_id``, ``task``,
        ``text``, ``proms`` and ``resps``.
        """
        if cfg.dataset.sample_type == "speaker":
            # The index selects a speaker; the utterance is drawn at random.
            spkr_name = self.spkrs[index]
            spkr_id = self.spkr_symmap[spkr_name]
            path = random.choice([*set(self.paths_by_spkr_name[spkr_name])])
        else:
            path = self.paths[index]
            spkr_name = cfg.get_spkr(path)
            spkr_id = self.spkr_symmap[spkr_name]

        if cfg.dataset.use_hdf5:
            key = _get_hdf5_path(path)
            text = torch.from_numpy(cfg.hdf5[key]["text"][:]).to(self.text_dtype)
            resps = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
        else:
            text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype)
            resps = _load_quants(path)

        task = random.choice(self.tasks)
        # text-to-speech
        if task == "tts":
            # With probability random_utterance the prompt comes from other
            # utterances of the speaker; otherwise the target itself is used.
            proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps

        # Disabled drafts for further tasks, kept for reference:
        #
        # noise-suppression
        # elif task == "ns":
        #     proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
        #     noise = self.sample_noise()
        #     noise = extend_audio(noise, proms.shape[0])
        #     proms = merge_audio(proms, noise)
        #     # something to prepend a ns token to the beginning of proms
        # elif task == "sr":
        #     proms = resps
        #     resps = self.sample_noise()
        #     resps = extend_audio(resps, proms.shape[0])
        #     # something to prepend a sr token to the beginning of proms
        # elif task == "tse":
        #     proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
        #     other_speaker = self.sample_speaker(ignore=[spkr_name])
        #     other_proms = self.sample_prompts(other_speaker, ignore="")
        #     proms = merge_audio(proms, other_proms)
        #     # something to prepend a tse token to the beginning of proms
        #
        # speech editing would require higher quality transcription data (phoneme level/word level) unfortunately
        # as I need to get a good clean point to trim into
        # elif task == "cse":
        # elif task == "nse":

        return dict(
            index=index,
            path=path,
            spkr_name=spkr_name,
            spkr_id=spkr_id,
            task=task,
            text=text,
            proms=proms,
            resps=resps,
        )

    def head_(self, n):
        """Limit the reported length of the dataset to at most *n* items."""
        self._head = n

    def training_(self, value):
        """Set the ``training`` flag in place."""
        self.training = value

    def __len__(self):
        """Number of speakers or paths (per sample type), capped by ``head_``."""
        if cfg.dataset.sample_type == "speaker":
            return min(len(self.spkrs), self._head or len(self.spkrs))
        return min(len(self.paths), self._head or len(self.paths))

    def pin_memory(self):
        """Pin the ``text``/``proms``/``resps``/``resp`` tensors of this object.

        NOTE(review): nothing in this class assigns these attributes, so as
        written this raises AttributeError unless a caller sets them first --
        looks like a leftover; confirm whether it can be removed.
        """
        self.text = self.text.pin_memory()
        self.proms = self.proms.pin_memory()
        self.resps = self.resps.pin_memory()
        self.resp = self.resp.pin_memory()
        return self
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def collate_fn(samples: list[dict]):
    """Collate per-sample dicts into one dict of lists.

    Args:
        samples: Sample dicts sharing the keys of the first sample.

    Returns:
        ``{key: [sample[key] for each sample]}``; an empty dict when
        *samples* is empty. No stacking or padding is performed.
    """
    if not samples:
        # Nothing to take the key set from.
        return {}

    batch: dict[str, Any] = {k: [s[k] for s in samples] for k in samples[0]}
    return batch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _seed_worker(worker_id):
|
|
|
|
|
worker_seed = torch.initial_seed() % 2**32
|
|
|
|
|
np.random.seed(worker_seed)
|
|
|
|
|
random.seed(worker_seed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_dataloader(dataset, training):
    """Wrap *dataset* in a ``DataLoader`` configured from ``cfg``.

    Args:
        dataset: The dataset to iterate.
        training: Selects the training batch size and ``drop_last``, and
            enables a ``DistributedSampler`` when ``cfg.distributed`` is set.

    Returns:
        The configured ``DataLoader``.
    """
    sampler = None
    shuffle = True

    if cfg.distributed and training:
        # A sampler and shuffle=True are mutually exclusive in DataLoader;
        # the distributed sampler takes over the ordering.
        sampler = DistributedSampler(dataset)
        shuffle = False

    num_workers = cfg.dataset.workers

    return DataLoader(
        dataset=dataset,
        batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
        shuffle=shuffle,
        drop_last=training,
        num_workers=num_workers,
        collate_fn=collate_fn,
        # DataLoader rejects persistent_workers=True when there are no
        # worker processes, so only request it when workers exist.
        persistent_workers=num_workers > 0,
        pin_memory=False,  # True,
        worker_init_fn=_seed_worker,
        sampler=sampler,
    )
|
|
|
|
|
|
|
|
|
|
def _load_dataset_paths():
    """Collect the training and validation entry keys stored in the HDF5 file.

    For every configured data directory the children of the matching
    ``/<split>/<dir>`` group are listed. Within each split the keys are
    grouped by speaker and each group is shuffled with a fixed seed.

    Returns:
        ``(training_paths, validation_paths)`` as lists of ``Path``.
    """
    found = {
        "training": [],
        "validation": [],
    }

    ordered = {
        "training": [],
        "validation": [],
    }

    def get_paths(data_dir, split="training"):
        """Append the keys of every entry stored under *data_dir* for *split*."""
        key = f"/{split}{_get_hdf5_path(data_dir)}"
        if key not in cfg.hdf5:
            return

        found[split].extend([f"{key}/{child.attrs['id']}" for child in cfg.hdf5[key].values()])

    for data_dir in cfg.dataset.training:
        get_paths(data_dir, "training")

    for data_dir in cfg.dataset.validation:
        get_paths(data_dir, "validation")

    for split, keys in found.items():
        if len(keys) == 0:
            continue

        entries = [Path(p) for p in keys]

        # groupby needs its input sorted by the grouping key.
        pairs = sorted([(cfg.get_spkr(p), p) for p in entries])
        for _, group in groupby(pairs, lambda pair: pair[0]):
            shuffled = sorted([p for _, p in group])
            # Re-seeding per speaker keeps each group's order reproducible
            # (this also resets the global `random` state).
            random.seed(0)
            random.shuffle(shuffled)

            ordered[split].extend(shuffled)

    return ordered["training"], ordered["validation"]
|
|
|
|
|
|
2023-08-17 23:56:37 +00:00
|
|
|
|
# to-do: mirror the hdf5-based load function
|
2023-08-02 21:53:35 +00:00
|
|
|
|
def _load_train_val_paths():
    """Collect the ``*.qnt.pt`` files of the training and validation dirs.

    Returns:
        ``(train_paths, val_paths)``, both sorted.

    Raises:
        RuntimeError: If the training directories contain no ``.qnt.pt`` file.
    """

    def _collect_groups(data_dirs):
        """Return the ``*.qnt.pt`` files below *data_dirs*, one list per speaker."""
        found = []
        for data_dir in data_dirs:
            found.extend(data_dir.rglob("*.qnt.pt"))

        groups = []
        # groupby needs its input sorted by the grouping key.
        pairs = sorted([(cfg.get_spkr(p), p) for p in found])
        for _, group in groupby(pairs, lambda pair: pair[0]):
            group_paths = sorted([p for _, p in group])
            # Kept from the original flow; the final sort below makes the
            # order deterministic regardless, but this also resets the global
            # `random` state as before.
            random.seed(0)
            random.shuffle(group_paths)
            groups.append(group_paths)
        return groups

    # Training and validation are gathered into separate lists. Previously the
    # validation files were appended to a leftover list that still held the
    # last training speaker's utterances, leaking training data into
    # validation.
    train_groups = _collect_groups(cfg.dataset.training)
    val_groups = _collect_groups(cfg.dataset.validation)

    train_paths = sorted(p for group in train_groups for p in group)
    val_paths = sorted(p for group in val_groups for p in group)

    if len(train_paths) == 0:
        raise RuntimeError("Failed to find any .qnt.pt file in specified training dataset.")

    # to-do: allow setting aside a fixed portion of the training dataset for validation
    # something like the last X percent of each speaker is set aside
    if len(val_paths) == 0:
        # No validation data available: fall back to the utterances of the
        # last training speaker. This is what effectively happened before;
        # a single-path fallback would be dropped again by Dataset's
        # "speaker needs more than one utterance" filter.
        val_paths = sorted(train_groups[-1])

    return train_paths, val_paths
|
|
|
|
|
|
|
|
|
|
@cfg.diskcache()
def create_datasets():
    """Build the training and validation ``Dataset`` objects.

    Paths come from the HDF5 file when ``cfg.dataset.use_hdf5`` is set,
    otherwise from the filesystem. The validation dataset reuses the training
    phoneme table and is capped at ``cfg.evaluation.size`` items.

    Returns:
        ``(train_dataset, val_dataset)``.
    """
    load_paths = _load_dataset_paths if cfg.dataset.use_hdf5 else _load_train_val_paths
    train_paths, val_paths = load_paths()

    train_dataset = Dataset(train_paths, training=True)

    # The speaker table and extra prompt paths are deliberately not shared:
    #   train_dataset.spkr_symmap,
    #   extra_paths_by_spkr_name=train_dataset.paths_by_spkr_name,
    val_dataset = Dataset(val_paths, train_dataset.phone_symmap)
    val_dataset.head_(cfg.evaluation.size)

    return train_dataset, val_dataset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_train_val_dataloader():
    """Create the train, sub-train and validation data loaders.

    The sub-train loader iterates a deep copy of the training dataset in
    evaluation mode; with path sampling the copy is capped at
    ``cfg.evaluation.size`` items. Symbol tables, sample counts and total
    durations are logged.

    Returns:
        ``(train_dl, subtrain_dl, val_dl)``.
    """
    train_dataset, val_dataset = create_datasets()

    subtrain_dataset = copy.deepcopy(train_dataset)
    if cfg.dataset.sample_type == "path":
        subtrain_dataset.head_(cfg.evaluation.size)

    train_dl = _create_dataloader(train_dataset, training=True)
    val_dl = _create_dataloader(val_dataset, training=False)
    subtrain_dl = _create_dataloader(subtrain_dataset, training=False)

    _logger.info(str(train_dataset.phone_symmap))
    _logger.info(str(train_dataset.spkr_symmap))

    labelled = (
        ("train", train_dataset),
        ("val", val_dataset),
        ("subtrain", subtrain_dataset),
    )

    for label, dataset in labelled:
        _logger.info(f"#samples ({label}): {len(dataset)}.")

    # Per-speaker durations are currently not logged:
    # _logger.info(f"#durations (train): {str(train_dataset.durations)}.")
    # _logger.info(f"#durations (val): {str(val_dataset.durations)}.")
    # _logger.info(f"#durations (subtrain): {str(subtrain_dataset.durations)}.")

    for label, dataset in labelled:
        _logger.info(f"#duration ({label}): {str(dataset.duration)}.")

    assert isinstance(subtrain_dl.dataset, Dataset)

    return train_dl, subtrain_dl, val_dl
|
|
|
|
|
|
|
|
|
|
# parse yaml to create an hdf5 file
|
|
|
|
|
def create_dataset_hdf5():
    """Populate the HDF5 file with every utterance of the configured dirs.

    For each ``<id>.qnt.pt`` / ``<id>.phn.txt`` pair a group
    ``<split>/<dir>/<id>`` is created holding the ``audio`` codes, the encoded
    ``text`` and metadata attributes. Entries that already exist are skipped,
    so the function can be re-run to add new data. The phoneme table is
    extended with unseen symbols and written to the ``symmap`` dataset, and the
    file is closed at the end.
    """
    symmap = get_phone_symmap()

    root = cfg.cfg_path
    hf = cfg.hdf5

    def add(data_dir, split="training"):
        """Add every complete utterance found in *data_dir* to *split*."""
        dir_str = "./" + str(data_dir)
        name = dir_str.replace(root, "")

        print(dir_str, name)

        folder = f'{root}/{name}/'
        if not os.path.isdir(folder):
            return
        # tqdm.write(f'{root}/{name}')
        files = os.listdir(folder)

        # grab IDs for every file (file name minus its two-part extension)
        ids = {".".join(file.split(".")[:-2]) for file in files}
        for utt_id in tqdm(ids, desc=f"Processing {name}"):
            qnt_file = f'{root}/{name}/{utt_id}.qnt.pt'
            phn_file = f'{root}/{name}/{utt_id}.phn.txt'
            if not os.path.exists(qnt_file) or not os.path.exists(phn_file):
                continue

            key = f'{split}/{name}/{utt_id}'
            if key in hf:
                # print("Skipping existing entry:", key)
                continue

            # Read both inputs before touching the HDF5 file, so a failing
            # read does not leave an empty group that later runs would skip.

            # audio
            qnt = torch.load(qnt_file)[0].t()

            # text
            with open(phn_file, "r", encoding="utf8") as f:
                content = f.read()
            pieces = content.split(" ")
            phones = ["<s>"] + [" " if not p else p for p in pieces] + ["</s>"]
            for s in set(phones):
                if s not in symmap:
                    # Next free id; unlike len(symmap) this cannot collide
                    # with an id that is already assigned.
                    symmap[s] = max(symmap.values(), default=0) + 1
            phn = [symmap[s] for s in phones]

            group = hf.create_group(key)
            group.create_dataset('audio', data=qnt.numpy(), compression='lzf')
            group.create_dataset('text', data=phn, compression='lzf', chunks=True)

            # metadata
            group.attrs['id'] = utt_id
            group.attrs['type'] = split
            group.attrs['speaker'] = name
            # presumably 75 code frames per second -- TODO confirm
            group.attrs['duration'] = qnt.shape[0] / 75
            group.attrs['phonemes'] = len(phn)

    # training
    for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
        add(data_dir, split="training")

    # validation
    for data_dir in tqdm(cfg.dataset.validation, desc='Processing Validation'):
        add(data_dir, split="validation")

    # write symmap; replace a table left by an earlier run, since creating a
    # dataset under an existing name fails
    if 'symmap' in hf:
        del hf['symmap']
    hf.create_dataset('symmap', data=json.dumps(symmap))

    hf.close()
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: optionally build the HDF5 file, then print two
    # batches (minus the bulky tensors) from each loader.
    import argparse

    # The text is a description; passed positionally it would replace the
    # program name in the usage line.
    parser = argparse.ArgumentParser(
        description="Optionally build the HDF5 dataset, then print sample batches."
    )
    parser.add_argument("--create-hdf5", action="store_true")
    args = parser.parse_args()

    if args.create_hdf5:
        create_dataset_hdf5()

    train_dl, subtrain_dl, val_dl = create_train_val_dataloader()

    samples = {
        "training": [next(iter(train_dl)), next(iter(train_dl))],
        "evaluation": [next(iter(subtrain_dl)), next(iter(subtrain_dl))],
        "validation": [next(iter(val_dl)), next(iter(val_dl))],
    }

    for k, v in samples.items():
        for batch in v:
            # Drop the audio code tensors to keep the printout readable.
            del batch['proms']
            del batch['resps']
        print(f'{k}:', v)
|
|
|
|
|
|