overhauled dataloading code to be marginally faster and mostly cleaned up; it can now leverage a metadata JSON to help things out

mrq 2023-08-26 19:53:23 -05:00
parent 7b3be3d7bf
commit 78378ed1ce
6 changed files with 321 additions and 254 deletions

View File

@@ -27,6 +27,10 @@ class _Config:
 	def relpath(self):
 		return Path(self.cfg_path)
 
+	@property
+	def cache_dir(self):
+		return self.relpath / ".cache"
+
 	@property
 	def ckpt_dir(self):
 		return self.relpath / "ckpt"
@@ -119,6 +123,7 @@ class Dataset:
 	hdf5_name: str = "data.h5"
 	use_hdf5: bool = False
+	use_metadata: bool = False
 	hdf5_flag: str = "a"
 	validate: bool = True
 	workers: int = 8
@@ -135,6 +140,19 @@ class Dataset:
 	tasks_list: list[str] = field(default_factory=lambda: ["tts"])
 
+	@property
+	def min_phones(self):
+		return self.phones_range[0]
+	@property
+	def max_phones(self):
+		return self.phones_range[1]
+	@property
+	def min_duration(self):
+		return self.duration_range[0]
+	@property
+	def max_duration(self):
+		return self.duration_range[1]
+
 @dataclass()
 class Model:
 	name: str = ""
@@ -393,7 +411,7 @@ class Trainer:
 	weight_dtype: str = "float16"
 
-	backend: str = "deepspeed" if not sys.platform.startswith("win") else "local"
+	backend: str = "local" # "deepspeed" if not sys.platform.startswith("win") else "local"
 
 	deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
@@ -453,10 +471,6 @@ class Config(_Config):
 	def get_spkr(self):
 		return eval(self.dataset.speaker_name_getter)
 
-	@property
-	def cache_dir(self):
-		return ".cache" / self.relpath
-
 	@cached_property
 	def diskcache(self):
 		if self.cfg_path is not None and self.dataset.cache:
@@ -501,11 +515,10 @@ try:
 	if cfg.dataset.use_hdf5:
 		cfg.load_hdf5()
 
-	if not cfg.dataset.use_hdf5:
-		cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
-		cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
-
+	cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
+	cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
 	cfg.dataset.noise = [ Path(dir) for dir in cfg.dataset.noise ]
 except Exception as e:
 	pass

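Note on the relocated cache_dir: the old Config.cache_dir rooted the cache at ./.cache/<relpath>, while the new _Config.cache_dir nests it under the config directory itself. A minimal sketch of the difference, using a hypothetical config path:

from pathlib import Path

relpath = Path("./training/config-dir")  # hypothetical value of Config.relpath

old_cache_dir = ".cache" / relpath       # Path('.cache/training/config-dir')
new_cache_dir = relpath / ".cache"       # Path('training/config-dir/.cache')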
View File

@@ -8,6 +8,7 @@ import numpy as np
 import os
 import random
 import torch
+import itertools
 
 from .config import cfg
 from .emb.qnt import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file
@@ -31,7 +32,7 @@ def get_phone_symmap():
 	if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5:
 		return json.loads( cfg.hdf5['symmap'].asstr()[()] )
 
-	symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
+	symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178, '': 179, '': 180, '“ˈ': 181, '“ˌ': 182, ';ˈ': 183, '': 184, ':ˈ': 185}
 	return symmap
 
 def get_task_symmap():
@@ -51,24 +52,89 @@ def get_task_symmap():
 def _replace_file_extension(path, suffix):
 	return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
 
-def _get_hdf5_path(path):
-	path = str(path)
-	if path[:2] != "./":
-		path = f'./{path}'
-	return path.replace(cfg.cfg_path, "")
-
 def _get_quant_path(path):
 	return _replace_file_extension(path, ".qnt.pt")
 
 def _get_phone_path(path):
 	return _replace_file_extension(path, ".phn.txt")
 
+def _load_paths(dataset, type="training"):
+	return { cfg.get_spkr( data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
+
+"""
+def _load_paths_from_hdf5(dataset, type="training"):
+	return { cfg.get_spkr( data_dir / "dummy" ): _get_hdf5_paths( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
+
+def _load_paths_from_disk(dataset, type="training"):
+	return { cfg.get_spkr( data_dir / "dummy" ): _get_paths_of_extensions( data_dir, ".qnt.pt", validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
+"""
+
+def _load_paths_from_metadata(data_dir, type="training", validate=False):
+	_fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions
+
+	def _validate( entry ):
+		phones = entry['phones']
+		duration = entry['duration']
+		return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
+
+	metadata_path = data_dir / "metadata.json"
+	if not cfg.dataset.use_metadata or not metadata_path.exists():
+		return _fn( data_dir, type if cfg.dataset.use_hdf5 else ".qnt.pt", validate )
+
+	speaker = cfg.get_spkr( data_dir / "dummy" )
+	metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
+
+	def key( dir, id ):
+		if not cfg.dataset.use_hdf5:
+			return data_dir / id
+		return f"/{type}{_get_hdf5_path(data_dir)}/{id}"
+
+	return [ key(dir, id) for id in metadata.keys() if not validate or _validate(metadata[id]) ]
+
+def _get_hdf5_path(path):
+	path = str(path)
+	if path[:2] != "./":
+		path = f'./{path}'
+	return path.replace(cfg.cfg_path, "")
+
+def _get_hdf5_paths( data_dir, type="training", validate=False ):
+	data_dir = str(data_dir)
+
+	def _validate(child):
+		phones = child.attrs['phonemes']
+		duration = child.attrs['duration']
+		return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
+
+	key = f"/{type}{_get_hdf5_path(data_dir)}"
+	return [ Path(f"{key}/{child.attrs['id']}") for child in cfg.hdf5[key].values() if not validate or _validate(child) ] if key in cfg.hdf5 else []
+
+def _get_paths_of_extensions( path, extensions=".qnt.pt", validate=False ):
+	if isinstance(path, str):
+		path = Path(path)
+
+	def _validate(path):
+		if "".join(path.suffixes) not in extensions:
+			return False
+		if not _get_phone_path(path).exists() or not _get_quant_path(path).exists():
+			return False
+		if not validate:
+			return True
+		# to-do: find an easy way to determine size from pickled quants without loading
+		# to-do: find a consistent way to derive phoneme count from filesize (probably can't due to utf-8)
+		phones = len(_get_phones(_get_phone_path(path))) # _get_phone_path(path).stat().st_size // 2 + 1
+		return cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
+
+	return [ p for p in list(path.iterdir()) if _validate(p) ] if path.exists() and path.is_dir() else []
+
 def _load_quants(path) -> Tensor:
-	return torch.load(path)[0][:, :].t().to(torch.int16)
+	return torch.load(_get_quant_path(path))[0][:, :].t().to(torch.int16)
 
 @cache
 def _get_phones(path, language="en"):
-	content = open(_get_phone_path(path), "r", encoding="utf8").read().split(" ")
+	content = open(_get_phone_path(path), "r", encoding="utf-8").read().split(" ")
 	return ["<s>"] + [ " " if not p else p for p in content ] + ["</s>"]
 
 def _interleaved_reorder(l, fn):
@@ -81,114 +147,61 @@ def _interleaved_reorder(l, fn):
 		if value is not None:
 			yield value
 
-@cache
-def _validate(path, min_phones, max_phones, min_duration, max_duration):
-	if cfg.dataset.use_hdf5:
-		key = _get_hdf5_path(path)
-		if key not in cfg.hdf5:
-			return False
-		phones = cfg.hdf5[key].attrs['phonemes']
-		duration = cfg.hdf5[key].attrs['duration']
-		if phones < min_phones or phones > max_phones:
-			return False
-		if duration < min_duration or duration > max_duration:
-			return False
-		return True
-
-	if not os.path.exists(_get_phone_path(path)) or not os.path.exists(_get_quant_path(path)):
-		return False
-
-	phones = _get_phones(path)
-	unique_phones = list(set(phones))
-
-	if len(unique_phones) == 0:
-		return False
-	if len(unique_phones) == 1 and unique_phones[0] == " ":
-		return False
-	if len(phones) < min_phones or len(phones) > max_phones:
-		return False
-	return True
-
 class Dataset(_Dataset):
 	def __init__(
 		self,
-		paths,
 		phone_symmap=None,
 		training=False,
 		extra_paths_by_spkr_name: dict[str, list] = {},
 	):
 		super().__init__()
 		self._head = None
-		self.min_phones = cfg.dataset.phones_range[0]
-		self.max_phones = cfg.dataset.phones_range[1]
-		self.min_duration = cfg.dataset.duration_range[0]
-		self.max_duration = cfg.dataset.duration_range[1]
 		self.sampler = None
 
-		if cfg.dataset.validate:
-			self.paths = [
-				path for path in paths if _validate(path, self.min_phones, self.max_phones, self.min_duration, self.max_duration)
-			]
-		else:
-			self.paths = paths
+		self.paths = []
+
+		self.training = training
+		self.dataset_type = "training" if self.training else "validation"
+		self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
+
+		self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type)
+		self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
+
+		if cfg.dataset.sample_type == "path":
+			self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
+
+		self.noise_paths = _load_paths(cfg.dataset.noise, "noise")
+		self.noise_paths = list(itertools.chain.from_iterable(self.noise_paths.values()))
 
 		self.phone_symmap = phone_symmap or self._get_phone_symmap()
 		self.spkr_symmap = self._get_spkr_symmap()
 		self.task_symmap = self._get_task_symmap()
-		self.training = training
 
 		# assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
 		self.text_dtype = torch.uint8 if len(self.phone_symmap) < 256 else torch.int16
 
-		self.paths_by_spkr_name = self._get_paths_by_spkr_name(extra_paths_by_spkr_name)
-
-		if cfg.dataset.validate:
-			self.paths = [
-				p for p in self.paths if len(self.paths_by_spkr_name[cfg.get_spkr(p)]) > 1
-			]
-
-		if cfg.dataset.sample_type == "path":
-			self.paths = [*_interleaved_reorder(self.paths, cfg.get_spkr)]
-
 		if len(self.paths) == 0 and training:
 			raise ValueError("No valid path is found for training.")
 
+		# would be a better cost saving if we could fetch the duration during the validation pass but oh well
 		self.duration = 0
-		self.durations = {}
 		if cfg.dataset.use_hdf5:
 			for path in self.paths:
-				key = _get_hdf5_path(path)
-				spkr_name = cfg.get_spkr(path)
-				spkr_id = self.spkr_symmap[spkr_name]
-				duration = cfg.hdf5[key].attrs['duration']
-
-				self.duration += duration
-
-				if spkr_id not in self.durations:
-					self.durations[spkr_id] = duration
-				else:
-					self.durations[spkr_id] += duration
-
-	def _get_paths_by_spkr_name(self, extra_paths_by_spkr_name: dict[str, list]):
-		ret = defaultdict(list)
-		for path in self.paths:
-			ret[cfg.get_spkr(path)].append(path)
-		for k, v in extra_paths_by_spkr_name.items():
-			ret[k].extend(v)
-		return {**ret}
+				self.duration += cfg.hdf5[_get_hdf5_path(path)].attrs['duration']
 
 	@cached_property
 	def phones(self):
 		return sorted(set().union(*[_get_phones(path) for path in self.paths]))
 
+	def get_speaker(self, path):
+		if isinstance(path, str):
+			path = Path(path)
+		res = cfg.get_spkr(path)
+		return res
+
 	@cached_property
 	def spkrs(self):
-		return sorted({cfg.get_spkr(path) for path in self.paths})
+		return sorted({self.get_speaker(path) for path in self.paths})
 
 	@cached_property
 	def tasks(self):
@@ -209,13 +222,10 @@ class Dataset(_Dataset):
 		return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(levels) ]]).to(dtype=torch.int16)
 
 	def sample_noise(self):
-		paths = []
-		for data_dir in cfg.dataset.noise:
-			paths.extend(data_dir.rglob("*.qnt.pt"))
-		path = random.choice(paths)
+		path = random.choice(self.noise_paths)
 
-		if False and cfg.dataset.use_hdf5:
-			key = f'/noise/{_get_hdf5_path(path)}'
+		if cfg.dataset.use_hdf5:
+			key = _get_hdf5_path(path)
 			qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
 		else:
 			qnt = _load_quants(path)
@@ -275,7 +285,7 @@ class Dataset(_Dataset):
 			path = random.choice([*set(self.paths_by_spkr_name[spkr_name])])
 		else:
 			path = self.paths[index]
-			spkr_name = cfg.get_spkr(path)
+			spkr_name = self.get_speaker(path)
 			spkr_id = self.spkr_symmap[spkr_name]
 
 		if cfg.dataset.use_hdf5:
@@ -507,110 +517,10 @@ def _create_dataloader(dataset, training):
 		sampler=sampler,
 	)
 
-def _load_dataset_paths():
-	hf = cfg.hdf5
-
-	paths = {
-		"training": [],
-		"validation": [],
-	}
-
-	datasets = {
-		"training": [],
-		"validation": [],
-	}
-
-	def get_paths( data_dir, type="training" ):
-		key = f"/{type}{_get_hdf5_path(data_dir)}"
-		if key not in cfg.hdf5:
-			return
-		paths[type].extend([ f"{key}/{child.attrs['id']}" for child in cfg.hdf5[key].values() ])
-
-	for data_dir in cfg.dataset.training:
-		get_paths( data_dir, "training" )
-
-	for data_dir in cfg.dataset.validation:
-		get_paths( data_dir, "validation" )
-
-	for _, type in enumerate(paths):
-		dirs = paths[type]
-
-		if len(dirs) == 0:
-			continue
-
-		dirs = [ Path(p) for p in dirs ]
-
-		pairs = sorted([(cfg.get_spkr(p), p) for p in dirs])
-		for _, group in groupby(pairs, lambda pair: pair[0]):
-			shuffled = sorted([p for _, p in group])
-			random.seed(0)
-			random.shuffle(shuffled)
-			datasets[type].extend(shuffled)
-
-	return datasets["training"], datasets["validation"]
-
-# to-do: mirror the hdf5-based load function
-def _load_train_val_paths():
-	paths = []
-	train_paths = []
-	val_paths = []
-
-	for data_dir in cfg.dataset.training:
-		paths.extend(data_dir.rglob("*.qnt.pt"))
-
-	if len(paths) > 0:
-		pairs = sorted([(cfg.get_spkr(p), p) for p in paths])
-		del paths
-
-		for _, group in groupby(pairs, lambda pair: pair[0]):
-			paths = sorted([p for _, p in group])
-			random.seed(0)
-			random.shuffle(paths)
-			train_paths.extend(paths)
-
-	for data_dir in cfg.dataset.validation:
-		paths.extend(data_dir.rglob("*.qnt.pt"))
-
-	if len(paths) > 0:
-		pairs = sorted([(cfg.get_spkr(p), p) for p in paths])
-		del paths
-
-		for _, group in groupby(pairs, lambda pair: pair[0]):
-			paths = sorted([p for _, p in group])
-			random.seed(0)
-			random.shuffle(paths)
-			val_paths.extend(paths)
-
-	train_paths, val_paths = map(sorted, [train_paths, val_paths])
-
-	if len(train_paths) == 0:
-		raise RuntimeError(f"Failed to find any .qnt.pt file in specified training dataset.")
-
-	# to-do: allow setting aside a fixed portion of the training dataset for validation
-	# something like the last X percent of each speaker is set aside
-	if len(val_paths) == 0:
-		val_paths = [ train_paths[0] ]
-
-	return train_paths, val_paths
-
 @cfg.diskcache()
 def create_datasets():
-	train_paths, val_paths = _load_dataset_paths() if cfg.dataset.use_hdf5 else _load_train_val_paths()
-
-	train_dataset = Dataset(
-		train_paths,
-		training=True,
-	)
-
-	val_dataset = Dataset(
-		val_paths,
-		train_dataset.phone_symmap,
-		#train_dataset.spkr_symmap,
-		#extra_paths_by_spkr_name=train_dataset.paths_by_spkr_name,
-	)
+	train_dataset = Dataset( training=True )
+	val_dataset = Dataset( phone_symmap=train_dataset.phone_symmap, training=False )
 
 	val_dataset.head_(cfg.evaluation.size)
 
 	return train_dataset, val_dataset
@@ -628,17 +538,10 @@ def create_train_val_dataloader():
 	_logger.info(str(train_dataset.phone_symmap))
 	_logger.info(str(train_dataset.spkr_symmap))
 
 	_logger.info(f"#samples (train): {len(train_dataset)}.")
 	_logger.info(f"#samples (val): {len(val_dataset)}.")
 	_logger.info(f"#samples (subtrain): {len(subtrain_dataset)}.")
 
-	"""
-	_logger.info(f"#durations (train): {str(train_dataset.durations)}.")
-	_logger.info(f"#durations (val): {str(val_dataset.durations)}.")
-	_logger.info(f"#durations (subtrain): {str(subtrain_dataset.durations)}.")
-	"""
-
 	_logger.info(f"#duration (train): {str(train_dataset.duration)}.")
 	_logger.info(f"#duration (val): {str(val_dataset.duration)}.")
@@ -648,8 +551,52 @@ def create_train_val_dataloader():
 
 	return train_dl, subtrain_dl, val_dl
 
+# parse dataset into better to sample metadata
+def create_dataset_metadata():
+	cfg.dataset.validate = False
+	cfg.dataset.use_hdf5 = False
+
+	paths_by_spkr_name = {}
+
+	paths_by_spkr_name |= _load_paths(cfg.dataset.training, "training")
+	paths_by_spkr_name |= _load_paths(cfg.dataset.validation, "validation")
+	paths_by_spkr_name |= _load_paths(cfg.dataset.noise, "noise")
+
+	paths = list(itertools.chain.from_iterable(paths_by_spkr_name.values()))
+
+	metadata = {}
+	for path in tqdm(paths, desc="Parsing paths"):
+		speaker = cfg.get_spkr(path)
+		if speaker not in metadata:
+			metadata[speaker] = {}
+
+		if cfg.dataset.use_hdf5:
+			phones = cfg.hdf5[_get_hdf5_path(path)].attrs['phonemes']
+			duration = cfg.hdf5[_get_hdf5_path(path)].attrs['duration']
+		else:
+			phns_path = _get_phone_path(path)
+			qnts_path = _get_quant_path(path)
+
+			phones = len(_get_phones(phns_path)) if phns_path.exists() else 0
+			duration = _load_quants(qnts_path).shape[0] / 75 if qnts_path.exists() else 0
+
+		metadata[speaker][path.name.split(".")[0]] = {
+			"phones": phones,
+			"duration": duration
+		}
+
+	for speaker, paths in tqdm(paths_by_spkr_name.items(), desc="Writing metadata"):
+		if len(paths) == 0:
+			continue
+		with open(paths[0].parent / "metadata.json", "w", encoding="utf-8") as f:
+			f.write( json.dumps( metadata[speaker] ) )
+
+	with open(cfg.relpath / "metadata.json", "w", encoding="utf-8") as f:
+		f.write( json.dumps( metadata ) )
+
 # parse yaml to create an hdf5 file
-def create_dataset_hdf5():
+def create_dataset_hdf5( skip_existing=True ):
 	cfg.dataset.use_hdf5 = True
 	cfg.load_hdf5(write=True)
@@ -658,11 +605,11 @@ def create_dataset_hdf5():
 	root = cfg.cfg_path
 	hf = cfg.hdf5
 
 	def add( dir, type="training", audios=True, texts=True ):
-		dir = "./" + str(dir)
-		name = dir.replace(root, "")
-		print( str(dir), name )
+		name = "./" + str(dir)
+		name = name.replace(root, "")
+
+		metadata = {}
 
 		if not os.path.isdir(f'{root}/{name}/'):
 			return
@@ -680,35 +627,53 @@ def create_dataset_hdf5():
 			key = f'{type}/{name}/{id}'
 			if key in hf:
-				# print("Skipping existing entry:", key)
-				continue
+				if skip_existing:
+					continue
+				del hf[key]
 
 			group = hf.create_group(key)
+			group.attrs['id'] = id
+			group.attrs['type'] = type
+			group.attrs['speaker'] = name
+
+			metadata[id] = {}
 
 			# audio
 			if audios:
 				qnt = torch.load(f'{root}/{name}/{id}.qnt.pt')[0].t()
+
+				if "audio" in group:
+					del group["audio"]
 				group.create_dataset('audio', data=qnt.numpy(), compression='lzf')
+				group.attrs['duration'] = qnt.shape[0] / 75
+				metadata[id]["duration"] = qnt.shape[0] / 75
+			else:
+				group.attrs['duration'] = 0
+				metadata[id]["duration"] = 0
 
 			# text
 			if texts:
-				with open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf8") as f:
-					content = f.read()
-					split = content.split(" ")
-					phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
+				with open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf-8") as f:
+					content = f.read().split(" ")
+				phones = [f"<s>"] + [ " " if not p else p for p in content ] + [f"</s>"]
 
 				for s in set(phones):
 					if s not in symmap:
 						symmap[s] = len(symmap.keys())
 
 				phn = [ symmap[s] for s in phones ]
+
+				if "text" in group:
+					del group["text"]
 				group.create_dataset('text', data=phn, compression='lzf', chunks=True)
-
-			# metadata
-			group.attrs['id'] = id
-			group.attrs['type'] = type
-			group.attrs['speaker'] = name
-			group.attrs['duration'] = qnt.shape[0] / 75
 				group.attrs['phonemes'] = len(phn)
+				metadata[id]["phones"] = len(phn)
+			else:
+				group.attrs['phonemes'] = 0
+				metadata[id]["phones"] = 0
+
+		with open(dir / "metadata.json", "w", encoding="utf-8") as f:
+			f.write( json.dumps( metadata ) )
 
 	# training
 	for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
@@ -723,10 +688,9 @@ def create_dataset_hdf5():
 		add( data_dir, type="noise", texts=False )
 
 	# write symmap
-	try:
-		hf.create_dataset('symmap', data=json.dumps(symmap))
-	except Exception as e:
-		pass
+	if "symmap" in hf:
+		del hf['symmap']
+	hf.create_dataset('symmap', data=json.dumps(symmap))
 
 	hf.close()
@@ -742,8 +706,16 @@ if __name__ == "__main__":
 	cfg.dataset.workers = 1
 
+	class LoggerOveride:
+		def info(self, *args):
+			print(*args)
+
+	_logger = LoggerOveride()
+
 	if args.action == "hdf5":
 		create_dataset_hdf5()
+	elif args.action == "metadata":
+		create_dataset_metadata()
 	elif args.action == "sample":
 		train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
@@ -758,6 +730,7 @@ if __name__ == "__main__":
 				del v[i]['proms']
 				del v[i]['resps']
 			print(f'{k}:', v)
+
 	elif args.action == "tasks":
 		index = 0
 		cfg.dataset.tasks_list = args.tasks.split(",")

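For context, the per-speaker metadata.json written by create_dataset_metadata() maps each utterance id to its phone count and duration, and _load_paths_from_metadata() filters paths on those two fields. A minimal sketch of consuming one of those files, with a hypothetical speaker directory and hypothetical range bounds:

import json
from pathlib import Path

# hypothetical speaker directory and thresholds, for illustration only
speaker_dir = Path("./training/speaker-a")
metadata = json.loads((speaker_dir / "metadata.json").read_text(encoding="utf-8"))
# e.g. {"utt-0001": {"phones": 62, "duration": 4.53}, ...}

def keep(entry, min_phones=4, max_phones=256, min_duration=1.0, max_duration=12.0):
	# mirrors the inclusive bounds checked by _validate() in _load_paths_from_metadata()
	return min_duration <= entry["duration"] <= max_duration and min_phones <= entry["phones"] <= max_phones

paths = [ speaker_dir / id for id, entry in metadata.items() if keep(entry) ]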
View File

@@ -33,10 +33,13 @@ def _get_backend( language="en-us", backend="espeak" ):
 	return phonemizer
 
-def encode(text: str, language="en-us", backend="espeak") -> list[str]:
+def encode(text: str, language="en-us", backend="auto") -> list[str]:
 	if language == "en":
 		language = "en-us"
 
+	if not backend or backend == "auto":
+		backend = "espeak" # if language[:2] != "en" else "festival"
+
 	text = [ text ]
 
 	backend = _get_backend(language=language, backend=backend)
@@ -63,16 +66,13 @@ def main():
 	args = parser.parse_args()
 
 	paths = list(args.folder.rglob(f"*{args.suffix}"))
+	random.shuffle(paths)
 
 	for path in tqdm(paths):
 		phone_path = path.with_name(path.stem.split(".")[0] + ".phn.txt")
 		if phone_path.exists():
 			continue
-		graphs = _get_graphs(path)
-		phones = encode(graphs)
-		with open(phone_path, "w") as f:
-			f.write(" ".join(phones))
+		phones = encode(open(path, "r", encoding="utf-8").read())
+		open(phone_path, "w", encoding="utf-8").write(" ".join(phones))
 
 if __name__ == "__main__":
if __name__ == "__main__": if __name__ == "__main__":

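A hedged usage sketch of the revised encode(): with the new default, backend="auto" resolves to espeak, and language "en" is normalized to "en-us". The module path here is assumed from the package layout implied by the imports above, and phonemizer plus espeak are assumed to be installed:

from vall_e.emb.g2p import encode  # assumed module path

phones = encode("Hello there.")    # backend="auto" -> espeak
open("sample.phn.txt", "w", encoding="utf-8").write(" ".join(phones))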
View File

@@ -213,7 +213,7 @@ def repeat_extend_audio( qnt, target ):
 		pieces.append(qnt)
 		length += qnt.shape[0]
 
-	return trim_random(torch.cat(pieces), target)
+	return trim(torch.cat(pieces), target)
 
 # merges two quantized audios together
 # I don't know if this works

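For reference, repeat_extend_audio() loops the source clip until it covers the target length and then cuts back down; this hunk only swaps the final cut from trim_random() to trim(). A rough sketch of that repeat-then-cut shape logic with hypothetical sizes (the real trim()/trim_random() live elsewhere in qnt.py and choose the window themselves):

import torch

qnt = torch.zeros(50, 8, dtype=torch.int16)  # 50 frames of codes across 8 levels (hypothetical)
target = 120

pieces, length = [], 0
while length < target:                 # mirrors the loop shown above this hunk
	pieces.append(qnt)
	length += qnt.shape[0]
extended = torch.cat(pieces)[:target]  # stand-in for the final trim to `target` frames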
View File

@@ -2,6 +2,7 @@ from ..config import cfg
 from .base import Base, list_to_tensor, Categorical
 
 import torch
+from torch.nn.utils.rnn import pad_sequence
 
 from einops import rearrange
 from torch import Tensor
@@ -62,7 +63,7 @@ class AR(Base):
 		max_steps: int = 1000,
 		sampling_temperature: float = 1.0,
 
-		naive: bool = True,
+		naive: bool = False,
 	):
 		if resps_list is not None:
 			resps_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels
@@ -83,30 +84,104 @@ class AR(Base):
 		]
 		stopped = torch.zeros(len(text_list), device=device).bool()
 
-		chunk_size = 1 # don't really know what to do about this desu
+		chunk_size = self.causal_chunk_size # don't really know what to do about this desu
 
 		state = None
 		start = 0
 
-		for n in trange(max_steps // chunk_size):
-			# get next in sequence
-
-			r, state = super().forward(
-				text_list,
-				proms_list,
-				self._unsqueeze_list(resps_list),
-				sampling_temperature=sampling_temperature,
-				state=state if not naive else None,
-			)
-
-			# append outputted token
-			for i, ri in enumerate(r):
-				resps_list[i] = torch.cat([resps_list[i], ri[None]])
-
-			# stop token found
-			stopped |= r == self.stop_token
-			if stopped.all().item():
-				break
+		if naive:
+			for n in trange(max_steps // max(1, chunk_size)):
+				# get next in sequence
+
+				r, state = super().forward(
+					text_list,
+					proms_list,
+					self._unsqueeze_list(resps_list),
+					sampling_temperature=sampling_temperature,
+					state=state # if not naive else None,
+				)
+
+				# append outputted token
+				if self.causal_chunk_size > 0:
+					for i, ri in enumerate(r):
+						resps_list[i] = torch.cat([resps_list[i], ri])
+				else:
+					for i, ri in enumerate(r):
+						resps_list[i] = torch.cat([resps_list[i], ri[None]])
+
+				# stop token found
+				stopped |= r == self.stop_token
+				if stopped.all().item():
+					break
+		# to-do: make it work
+		# it seems anything that isn't a one-at-a-time sequence does not work, despite generating STOP tokens.
+		else:
+			resps_list: list[Tensor] = [
+				torch.zeros(0, device=device).to(torch.int16) for _ in text_list
+			]
+			test_list: list[Tensor] = [
+				torch.zeros(0, device=device).to(torch.int16) for _ in text_list
+			]
+			batch_size = len(text_list)
+
+			x_list = self._samplewise_merge_tensors(
+				self.text_emb(text_list),
+				self.proms_emb(proms_list),
+				self.resps_emb(self._unsqueeze_list(resps_list)),
+				sep=self.sep,
+			)
+
+			x, m = list_to_tensor(x_list)
+			device = x.device
+
+			if state is None:
+				state = {}
+
+			# pre-fill KV cache
+			for n in trange(x.shape[1]):
+				xs = x[:, n:(n + 1), :]
+				r, _ = self.retnet(xs, incremental_state=state, token_embeddings=xs, features_only=True)
+				r = self.classifier(r) * m
+
+				logits = torch.stack([hi[-1] for hi in r])
+				r = Categorical(logits=logits / sampling_temperature).sample()
+
+				for i, ri in enumerate(r):
+					test_list[i] = torch.cat([test_list[i], ri[None]])
+
+			# append outputted token
+			for i, ri in enumerate(r):
+				resps_list[i] = torch.cat([resps_list[i], ri[None]])
+
+			start = x.shape[1]
+			for n in trange(max_steps // max(1, chunk_size)):
+				x_list = self._samplewise_merge_tensors(
+					self.text_emb(text_list),
+					self.proms_emb(proms_list),
+					self.resps_emb(self._unsqueeze_list(resps_list)),
+					sep=self.sep,
+				)
+
+				x, m = list_to_tensor(x_list)
+
+				xs = x[:, start+n:start+(n+1), :]
+				r, _ = self.retnet(xs, incremental_state=state, token_embeddings=xs, features_only=True)
+				r = self.classifier(r) * m
+
+				logits = torch.stack([hi[-1] for hi in r])
+				r = Categorical(logits=logits / sampling_temperature).sample()
+
+				# append outputted token
+				for i, ri in enumerate(r):
+					resps_list[i] = torch.cat([resps_list[i], ri[None]])
+
+				# stop token found
+				stopped |= r == self.stop_token
+				if stopped.all().item():
+					break
 
 		pruned = [self._prune(r) for r in resps_list]
 		return pruned

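One detail worth calling out in the naive branch: with chunkwise decoding the model returns a 1-d chunk of codes per step, which is concatenated as-is, while one-at-a-time decoding returns a 0-d token that has to be lifted to shape [1] first. A small illustration with hypothetical values:

import torch

resp = torch.zeros(0, dtype=torch.int16)

one_token = torch.tensor(42, dtype=torch.int16)     # 0-d output of one-at-a-time decoding
resp = torch.cat([resp, one_token[None]])           # ri[None] lifts it to shape [1]

chunk = torch.tensor([7, 8, 9], dtype=torch.int16)  # 1-d output of a chunkwise step
resp = torch.cat([resp, chunk])                     # ri is appended as-is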
View File

@@ -149,6 +149,8 @@ class Base(nn.Module):
 		self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
 		self.sep = nn.Parameter(torch.randn(d_model))
 
+		self.causal_chunk_size = 0 # 64 if self.causal else 1
+
 		if self.arch_type == "transformer":
 			self.sin_emb = SinusoidalEmbedding(d_model)
@@ -171,8 +173,8 @@ class Base(nn.Module):
 				dropout=p_dropout,
 				checkpoint_activations=True,
 
-				chunkwise_recurrent=False, # self.causal,
-				recurrent_chunkwise_size=64,
+				chunkwise_recurrent=self.causal and self.causal_chunk_size > 0,
+				recurrent_chunkwise_size=self.causal_chunk_size,
 				no_output_layer=True,
 				decoder_normalize_before=True,
 			))
@@ -358,6 +360,10 @@ class Base(nn.Module):
 		elif return_all_resp:
 			logits = [hi[-li:] for hi, li in zip(h_list, map(len, resps_list))]
 			ret = [ Categorical(logits=hi / sampling_temperature).sample() for hi in logits ]
+		# return the last chunkwise piece
+		elif self.causal_chunk_size > 0:
+			logits = [hi[-self.causal_chunk_size:] for hi, li in zip(h_list, map(len, resps_list))]
+			ret = [ Categorical(logits=hi / sampling_temperature).sample() for hi in logits ]
 		# return just the last code
 		else:
 			logits = torch.stack([hi[-1] for hi in h_list])
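The new elif samples one code per position of the last causal_chunk_size hidden states rather than only the final position. A minimal sketch with hypothetical sizes (10 positions, 1025-way logits, chunk of 4):

import torch
from torch.distributions import Categorical

hi = torch.randn(10, 1025)  # one sample's logits, one row per sequence position
causal_chunk_size = 4

logits = hi[-causal_chunk_size:]                 # last chunkwise piece, shape [4, 1025]
ret = Categorical(logits=logits / 1.0).sample()  # shape [4]: one sampled code per position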