overhauled dataloading code to be marginally faster, mostly cleaned up, and able to leverage a metadata JSON to help things out

This commit is contained in:
mrq 2023-08-26 19:53:23 -05:00
parent 7b3be3d7bf
commit 78378ed1ce
6 changed files with 321 additions and 254 deletions

View File vall_e/config.py

@@ -27,6 +27,10 @@ class _Config:
def relpath(self):
return Path(self.cfg_path)
@property
def cache_dir(self):
return self.relpath / ".cache"
@property
def ckpt_dir(self):
return self.relpath / "ckpt"
@@ -119,6 +123,7 @@ class Dataset:
hdf5_name: str = "data.h5"
use_hdf5: bool = False
use_metadata: bool = False
hdf5_flag: str = "a"
validate: bool = True
workers: int = 8
@@ -135,6 +140,19 @@ class Dataset:
tasks_list: list[str] = field(default_factory=lambda: ["tts"])
@property
def min_phones(self):
return self.phones_range[0]
@property
def max_phones(self):
return self.phones_range[1]
@property
def min_duration(self):
return self.duration_range[0]
@property
def max_duration(self):
return self.duration_range[1]
@dataclass()
class Model:
name: str = ""
@@ -393,7 +411,7 @@ class Trainer:
weight_dtype: str = "float16"
backend: str = "deepspeed" if not sys.platform.startswith("win") else "local"
backend: str = "local" # "deepspeed" if not sys.platform.startswith("win") else "local"
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
@@ -453,10 +471,6 @@ class Config(_Config):
def get_spkr(self):
return eval(self.dataset.speaker_name_getter)
@property
def cache_dir(self):
return ".cache" / self.relpath
@cached_property
def diskcache(self):
if self.cfg_path is not None and self.dataset.cache:
@@ -501,11 +515,10 @@ try:
if cfg.dataset.use_hdf5:
cfg.load_hdf5()
if not cfg.dataset.use_hdf5:
cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
cfg.dataset.noise = [ Path(dir) for dir in cfg.dataset.noise ]
except Exception as e:
pass

View File vall_e/data.py

@@ -8,6 +8,7 @@ import numpy as np
import os
import random
import torch
import itertools
from .config import cfg
from .emb.qnt import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file
@@ -31,7 +32,7 @@ def get_phone_symmap():
if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5:
return json.loads( cfg.hdf5['symmap'].asstr()[()] )
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178, '': 179, '': 180, '“ˈ': 181, '“ˌ': 182, ';ˈ': 183, '': 184, ':ˈ': 185}
return symmap
def get_task_symmap():
@@ -51,24 +52,89 @@ def get_task_symmap():
def _replace_file_extension(path, suffix):
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
def _get_hdf5_path(path):
path = str(path)
if path[:2] != "./":
path = f'./{path}'
return path.replace(cfg.cfg_path, "")
def _get_quant_path(path):
return _replace_file_extension(path, ".qnt.pt")
def _get_phone_path(path):
return _replace_file_extension(path, ".phn.txt")
def _load_paths(dataset, type="training"):
return { cfg.get_spkr( data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
"""
def _load_paths_from_hdf5(dataset, type="training"):
return { cfg.get_spkr( data_dir / "dummy" ): _get_hdf5_paths( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
def _load_paths_from_disk(dataset, type="training"):
return { cfg.get_spkr( data_dir / "dummy" ): _get_paths_of_extensions( data_dir, ".qnt.pt", validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
"""
def _load_paths_from_metadata(data_dir, type="training", validate=False):
_fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions
def _validate( entry ):
phones = entry['phones']
duration = entry['duration']
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
metadata_path = data_dir / "metadata.json"
if not cfg.dataset.use_metadata or not metadata_path.exists():
return _fn( data_dir, type if cfg.dataset.use_hdf5 else ".qnt.pt", validate )
speaker = cfg.get_spkr( data_dir / "dummy" )
metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
def key( dir, id ):
if not cfg.dataset.use_hdf5:
return data_dir / id
return f"/{type}{_get_hdf5_path(data_dir)}/{id}"
return [ key(dir, id) for id in metadata.keys() if not validate or _validate(metadata[id]) ]
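For reference, a minimal sketch of the per-speaker metadata.json shape this loader consumes; the utterance ids and numbers below are hypothetical, but the 'phones'/'duration' fields are exactly what _validate checks against the configured ranges:

# hypothetical per-speaker metadata.json contents, shown as a Python literal
metadata = {
    "LJ001-0001": {"phones": 142, "duration": 9.65},
    "LJ001-0002": {"phones": 38, "duration": 1.89},
}
# with use_hdf5, each id resolves to a key like f"/training{_get_hdf5_path(data_dir)}/LJ001-0001";
# otherwise it resolves to data_dir / "LJ001-0001"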
def _get_hdf5_path(path):
path = str(path)
if path[:2] != "./":
path = f'./{path}'
return path.replace(cfg.cfg_path, "")
def _get_hdf5_paths( data_dir, type="training", validate=False ):
data_dir = str(data_dir)
def _validate(child):
phones = child.attrs['phonemes']
duration = child.attrs['duration']
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
key = f"/{type}{_get_hdf5_path(data_dir)}"
return [ Path(f"{key}/{child.attrs['id']}") for child in cfg.hdf5[key].values() if not validate or _validate(child) ] if key in cfg.hdf5 else []
def _get_paths_of_extensions( path, extensions=".qnt.pt", validate=False ):
if isinstance(path, str):
path = Path(path)
def _validate(path):
if "".join(path.suffixes) not in extensions:
return False
if not _get_phone_path(path).exists() or not _get_quant_path(path).exists():
return False
if not validate:
return True
# to-do: find an easy way to determine size from pickled quants without loading
# to-do: find a consistent way to derive phoneme count from filesize (probably can't due to utf-8)
phones = len(_get_phones(_get_phone_path(path))) # _get_phone_path(path).stat().st_size // 2 + 1
return cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
return [ p for p in list(path.iterdir()) if _validate(p) ] if path.exists() and path.is_dir() else []
def _load_quants(path) -> Tensor:
return torch.load(path)[0][:, :].t().to(torch.int16)
return torch.load(_get_quant_path(path))[0][:, :].t().to(torch.int16)
@cache
def _get_phones(path, language="en"):
content = open(_get_phone_path(path), "r", encoding="utf8").read().split(" ")
content = open(_get_phone_path(path), "r", encoding="utf-8").read().split(" ")
return ["<s>"] + [ " " if not p else p for p in content ] + ["</s>"]
def _interleaved_reorder(l, fn):
@@ -81,114 +147,61 @@ def _interleaved_reorder(l, fn):
if value is not None:
yield value
@cache
def _validate(path, min_phones, max_phones, min_duration, max_duration):
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
if key not in cfg.hdf5:
return False
phones = cfg.hdf5[key].attrs['phonemes']
duration = cfg.hdf5[key].attrs['duration']
if phones < min_phones or phones > max_phones:
return False
if duration < min_duration or duration > max_duration:
return False
return True
if not os.path.exists(_get_phone_path(path)) or not os.path.exists(_get_quant_path(path)):
return False
phones = _get_phones(path)
unique_phones = list(set(phones))
if len(unique_phones) == 0:
return False
if len(unique_phones) == 1 and unique_phones[0] == " ":
return False
if len(phones) < min_phones or len(phones) > max_phones:
return False
return True
class Dataset(_Dataset):
def __init__(
self,
paths,
phone_symmap=None,
training=False,
extra_paths_by_spkr_name: dict[str, list] = {},
):
super().__init__()
self._head = None
self.min_phones = cfg.dataset.phones_range[0]
self.max_phones = cfg.dataset.phones_range[1]
self.min_duration = cfg.dataset.duration_range[0]
self.max_duration = cfg.dataset.duration_range[1]
self.sampler = None
if cfg.dataset.validate:
self.paths = [
path for path in paths if _validate(path, self.min_phones, self.max_phones, self.min_duration, self.max_duration)
]
else:
self.paths = paths
self.paths = []
self.training = training
self.dataset_type = "training" if self.training else "validation"
self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type)
self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
if cfg.dataset.sample_type == "path":
self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
self.noise_paths = _load_paths(cfg.dataset.noise, "noise")
self.noise_paths = list(itertools.chain.from_iterable(self.noise_paths.values()))
self.phone_symmap = phone_symmap or self._get_phone_symmap()
self.spkr_symmap = self._get_spkr_symmap()
self.task_symmap = self._get_task_symmap()
self.training = training
# assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
self.text_dtype = torch.uint8 if len(self.phone_symmap) < 256 else torch.int16
self.paths_by_spkr_name = self._get_paths_by_spkr_name(extra_paths_by_spkr_name)
if cfg.dataset.validate:
self.paths = [
p for p in self.paths if len(self.paths_by_spkr_name[cfg.get_spkr(p)]) > 1
]
if cfg.dataset.sample_type == "path":
self.paths = [*_interleaved_reorder(self.paths, cfg.get_spkr)]
if len(self.paths) == 0 and training:
raise ValueError("No valid path is found for training.")
# would be a better cost saving if we could fetch the duration during the validation pass but oh well
self.duration = 0
self.durations = {}
if cfg.dataset.use_hdf5:
for path in self.paths:
key = _get_hdf5_path(path)
spkr_name = cfg.get_spkr(path)
spkr_id = self.spkr_symmap[spkr_name]
duration = cfg.hdf5[key].attrs['duration']
self.duration += duration
if spkr_id not in self.durations:
self.durations[spkr_id] = duration
else:
self.durations[spkr_id] += duration
def _get_paths_by_spkr_name(self, extra_paths_by_spkr_name: dict[str, list]):
ret = defaultdict(list)
for path in self.paths:
ret[cfg.get_spkr(path)].append(path)
for k, v in extra_paths_by_spkr_name.items():
ret[k].extend(v)
return {**ret}
self.duration += cfg.hdf5[_get_hdf5_path(path)].attrs['duration']
@cached_property
def phones(self):
return sorted(set().union(*[_get_phones(path) for path in self.paths]))
def get_speaker(self, path):
if isinstance(path, str):
path = Path(path)
res = cfg.get_spkr(path)
return res
@cached_property
def spkrs(self):
return sorted({cfg.get_spkr(path) for path in self.paths})
return sorted({self.get_speaker(path) for path in self.paths})
@cached_property
def tasks(self):
@@ -209,13 +222,10 @@ class Dataset(_Dataset):
return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(levels) ]]).to(dtype=torch.int16)
def sample_noise(self):
paths = []
for data_dir in cfg.dataset.noise:
paths.extend(data_dir.rglob("*.qnt.pt"))
path = random.choice(paths)
path = random.choice(self.noise_paths)
if False and cfg.dataset.use_hdf5:
key = f'/noise/{_get_hdf5_path(path)}'
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
else:
qnt = _load_quants(path)
@@ -275,7 +285,7 @@ class Dataset(_Dataset):
path = random.choice([*set(self.paths_by_spkr_name[spkr_name])])
else:
path = self.paths[index]
spkr_name = cfg.get_spkr(path)
spkr_name = self.get_speaker(path)
spkr_id = self.spkr_symmap[spkr_name]
if cfg.dataset.use_hdf5:
@@ -507,110 +517,10 @@ def _create_dataloader(dataset, training):
sampler=sampler,
)
def _load_dataset_paths():
hf = cfg.hdf5
paths = {
"training": [],
"validation": [],
}
datasets = {
"training": [],
"validation": [],
}
def get_paths( data_dir, type="training" ):
key = f"/{type}{_get_hdf5_path(data_dir)}"
if key not in cfg.hdf5:
return
paths[type].extend([ f"{key}/{child.attrs['id']}" for child in cfg.hdf5[key].values() ])
for data_dir in cfg.dataset.training:
get_paths( data_dir, "training" )
for data_dir in cfg.dataset.validation:
get_paths( data_dir, "validation" )
for _, type in enumerate(paths):
dirs = paths[type]
if len(dirs) == 0:
continue
dirs = [ Path(p) for p in dirs ]
pairs = sorted([(cfg.get_spkr(p), p) for p in dirs])
for _, group in groupby(pairs, lambda pair: pair[0]):
shuffled = sorted([p for _, p in group])
random.seed(0)
random.shuffle(shuffled)
datasets[type].extend(shuffled)
return datasets["training"], datasets["validation"]
# to-do: mirror the hdf5-based load function
def _load_train_val_paths():
paths = []
train_paths = []
val_paths = []
for data_dir in cfg.dataset.training:
paths.extend(data_dir.rglob("*.qnt.pt"))
if len(paths) > 0:
pairs = sorted([(cfg.get_spkr(p), p) for p in paths])
del paths
for _, group in groupby(pairs, lambda pair: pair[0]):
paths = sorted([p for _, p in group])
random.seed(0)
random.shuffle(paths)
train_paths.extend(paths)
for data_dir in cfg.dataset.validation:
paths.extend(data_dir.rglob("*.qnt.pt"))
if len(paths) > 0:
pairs = sorted([(cfg.get_spkr(p), p) for p in paths])
del paths
for _, group in groupby(pairs, lambda pair: pair[0]):
paths = sorted([p for _, p in group])
random.seed(0)
random.shuffle(paths)
val_paths.extend(paths)
train_paths, val_paths = map(sorted, [train_paths, val_paths])
if len(train_paths) == 0:
raise RuntimeError(f"Failed to find any .qnt.pt file in specified training dataset.")
# to-do: allow setting aside a fixed portion of the training dataset for validation
# something like the last X percent of each speaker is set aside
if len(val_paths) == 0:
val_paths = [ train_paths[0] ]
return train_paths, val_paths
@cfg.diskcache()
def create_datasets():
train_paths, val_paths = _load_dataset_paths() if cfg.dataset.use_hdf5 else _load_train_val_paths()
train_dataset = Dataset(
train_paths,
training=True,
)
val_dataset = Dataset(
val_paths,
train_dataset.phone_symmap,
#train_dataset.spkr_symmap,
#extra_paths_by_spkr_name=train_dataset.paths_by_spkr_name,
)
val_dataset.head_(cfg.evaluation.size)
train_dataset = Dataset( training=True )
val_dataset = Dataset( phone_symmap=train_dataset.phone_symmap, training=False )
return train_dataset, val_dataset
@@ -628,17 +538,10 @@ def create_train_val_dataloader():
_logger.info(str(train_dataset.phone_symmap))
_logger.info(str(train_dataset.spkr_symmap))
_logger.info(f"#samples (train): {len(train_dataset)}.")
_logger.info(f"#samples (val): {len(val_dataset)}.")
_logger.info(f"#samples (subtrain): {len(subtrain_dataset)}.")
"""
_logger.info(f"#durations (train): {str(train_dataset.durations)}.")
_logger.info(f"#durations (val): {str(val_dataset.durations)}.")
_logger.info(f"#durations (subtrain): {str(subtrain_dataset.durations)}.")
"""
_logger.info(f"#duration (train): {str(train_dataset.duration)}.")
_logger.info(f"#duration (val): {str(val_dataset.duration)}.")
@@ -648,8 +551,52 @@ def create_train_val_dataloader():
return train_dl, subtrain_dl, val_dl
# parse the dataset into metadata that's better to sample from
def create_dataset_metadata():
cfg.dataset.validate = False
cfg.dataset.use_hdf5 = False
paths_by_spkr_name = {}
paths_by_spkr_name |= _load_paths(cfg.dataset.training, "training")
paths_by_spkr_name |= _load_paths(cfg.dataset.validation, "validation")
paths_by_spkr_name |= _load_paths(cfg.dataset.noise, "noise")
paths = list(itertools.chain.from_iterable(paths_by_spkr_name.values()))
metadata = {}
for path in tqdm(paths, desc="Parsing paths"):
speaker = cfg.get_spkr(path)
if speaker not in metadata:
metadata[speaker] = {}
if cfg.dataset.use_hdf5:
phones = cfg.hdf5[_get_hdf5_path(path)].attrs['phonemes']
duration = cfg.hdf5[_get_hdf5_path(path)].attrs['duration']
else:
phns_path = _get_phone_path(path)
qnts_path = _get_quant_path(path)
phones = len(_get_phones(phns_path)) if phns_path.exists() else 0
duration = _load_quants(qnts_path).shape[0] / 75 if qnts_path.exists() else 0
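# note: 75 here is presumably the EnCodec frame rate at 24 kHz (320-sample hop),
# so duration in seconds is frames / 75 (e.g. 750 frames -> 10 s)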
metadata[speaker][path.name.split(".")[0]] = {
"phones": phones,
"duration": duration
}
for speaker, paths in tqdm(paths_by_spkr_name.items(), desc="Writing metadata"):
if len(paths) == 0:
continue
with open(paths[0].parent / "metadata.json", "w", encoding="utf-8") as f:
f.write( json.dumps( metadata[speaker] ) )
with open(cfg.relpath / "metadata.json", "w", encoding="utf-8") as f:
f.write( json.dumps( metadata ) )
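A small usage sketch, assuming metadata generation has already been run; the path below is hypothetical, but the aggregate file maps speaker to {id: {phones, duration}}, so per-speaker totals can be tallied without touching any audio:

import json
from pathlib import Path

# hypothetical path; the aggregate file is written to cfg.relpath / "metadata.json"
metadata = json.loads(Path("./training/metadata.json").read_text(encoding="utf-8"))
for speaker, entries in metadata.items():
    total = sum(entry["duration"] for entry in entries.values())
    print(f"{speaker}: {len(entries)} utterances, {total:.1f} seconds")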
# parse yaml to create an hdf5 file
def create_dataset_hdf5():
def create_dataset_hdf5( skip_existing=True ):
cfg.dataset.use_hdf5 = True
cfg.load_hdf5(write=True)
@@ -658,11 +605,11 @@ def create_dataset_hdf5():
root = cfg.cfg_path
hf = cfg.hdf5
def add( dir, type="training", audios=True, texts=True ):
dir = "./" + str(dir)
name = dir.replace(root, "")
print( str(dir), name )
def add( dir, type="training", audios=True, texts=True ):
name = "./" + str(dir)
name = name.replace(root, "")
metadata = {}
if not os.path.isdir(f'{root}/{name}/'):
return
@@ -680,35 +627,53 @@ def create_dataset_hdf5():
key = f'{type}/{name}/{id}'
if key in hf:
# print("Skipping existing entry:", key)
continue
if skip_existing:
continue
del hf[key]
group = hf.create_group(key)
group.attrs['id'] = id
group.attrs['type'] = type
group.attrs['speaker'] = name
metadata[id] = {}
# audio
if audios:
qnt = torch.load(f'{root}/{name}/{id}.qnt.pt')[0].t()
if "audio" in group:
del group["audio"]
group.create_dataset('audio', data=qnt.numpy(), compression='lzf')
group.attrs['duration'] = qnt.shape[0] / 75
metadata[id]["duration"] = qnt.shape[0] / 75
else:
group.attrs['duration'] = 0
metadata[id]["duration"] = 0
# text
if texts:
with open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf8") as f:
content = f.read()
split = content.split(" ")
phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
with open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf-8") as f:
content = f.read().split(" ")
phones = [f"<s>"] + [ " " if not p else p for p in content ] + [f"</s>"]
for s in set(phones):
if s not in symmap:
symmap[s] = len(symmap.keys())
phn = [ symmap[s] for s in phones ]
if "text" in group:
del group["text"]
group.create_dataset('text', data=phn, compression='lzf', chunks=True)
# metadata
group.attrs['id'] = id
group.attrs['type'] = type
group.attrs['speaker'] = name
group.attrs['duration'] = qnt.shape[0] / 75
group.attrs['phonemes'] = len(phn)
metadata[id]["phones"] = len(phn)
else:
group.attrs['phonemes'] = 0
metadata[id]["phones"] = 0
with open(dir / "metadata.json", "w", encoding="utf-8") as f:
f.write( json.dumps( metadata ) )
# training
for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
@@ -723,10 +688,9 @@ def create_dataset_hdf5():
add( data_dir, type="noise", texts=False )
# write symmap
try:
hf.create_dataset('symmap', data=json.dumps(symmap))
except Exception as e:
pass
if "symmap" in hf:
del hf['symmap']
hf.create_dataset('symmap', data=json.dumps(symmap))
hf.close()
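For reference, a hedged sketch of reading back the layout this produces: one group per utterance at '{type}/{name}/{id}' holding 'audio' and 'text' datasets plus duration/phoneme attrs, and a root-level 'symmap'; the group path below is a hypothetical example:

import json
import h5py

with h5py.File("data.h5", "r") as hf:
    symmap = json.loads(hf["symmap"].asstr()[()])  # phoneme -> token id map
    group = hf["training/./LJSpeech/LJ001-0001"]   # hypothetical speaker/id key
    qnt = group["audio"][:, :]                     # quantized audio codes
    phn = group["text"][:]                         # phoneme token ids (via symmap)
    print(group.attrs["duration"], group.attrs["phonemes"])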
@@ -742,8 +706,16 @@ if __name__ == "__main__":
cfg.dataset.workers = 1
class LoggerOverride:
def info(self, *args):
print(*args)
_logger = LoggerOverride()
if args.action == "hdf5":
create_dataset_hdf5()
elif args.action == "metadata":
create_dataset_metadata()
elif args.action == "sample":
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
@@ -758,6 +730,7 @@ if __name__ == "__main__":
del v[i]['proms']
del v[i]['resps']
print(f'{k}:', v)
elif args.action == "tasks":
index = 0
cfg.dataset.tasks_list = args.tasks.split(",")

View File vall_e/emb/g2p.py

@@ -33,10 +33,13 @@ def _get_backend( language="en-us", backend="espeak" ):
return phonemizer
def encode(text: str, language="en-us", backend="espeak") -> list[str]:
def encode(text: str, language="en-us", backend="auto") -> list[str]:
if language == "en":
language = "en-us"
if not backend or backend == "auto":
backend = "espeak" # if language[:2] != "en" else "festival"
text = [ text ]
backend = _get_backend(language=language, backend=backend)
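A minimal usage sketch, assuming an espeak backend is installed; the exact phoneme strings depend on the phonemizer version:

# "en" is normalized to "en-us", and backend="auto" falls through to espeak
phones = encode("Hello world.", language="en")
print(" ".join(phones))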
@@ -63,16 +66,13 @@ def main():
args = parser.parse_args()
paths = list(args.folder.rglob(f"*{args.suffix}"))
random.shuffle(paths)
for path in tqdm(paths):
phone_path = path.with_name(path.stem.split(".")[0] + ".phn.txt")
if phone_path.exists():
continue
graphs = _get_graphs(path)
phones = encode(graphs)
with open(phone_path, "w") as f:
f.write(" ".join(phones))
phones = encode(open(path, "r", encoding="utf-8").read())
open(phone_path, "w", encoding="utf-8").write(" ".join(phones))
if __name__ == "__main__":

View File vall_e/emb/qnt.py

@@ -213,7 +213,7 @@ def repeat_extend_audio( qnt, target ):
pieces.append(qnt)
length += qnt.shape[0]
return trim_random(torch.cat(pieces), target)
return trim(torch.cat(pieces), target)
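A hedged usage sketch of the changed behavior, assuming trim() cuts the concatenated clip down to exactly the target frame count (previously trim_random took a random window):

import torch

qnt = torch.zeros(100, 8, dtype=torch.int16)  # hypothetical (frames, levels) codes
out = repeat_extend_audio(qnt, 375)           # 100 frames repeated out to 375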
# merges two quantized audios together
# I don't know if this works

View File vall_e/models/ar.py

@@ -2,6 +2,7 @@ from ..config import cfg
from .base import Base, list_to_tensor, Categorical
import torch
from torch.nn.utils.rnn import pad_sequence
from einops import rearrange
from torch import Tensor
@@ -62,7 +63,7 @@ class AR(Base):
max_steps: int = 1000,
sampling_temperature: float = 1.0,
naive: bool = True,
naive: bool = False,
):
if resps_list is not None:
resps_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels
@@ -83,30 +84,104 @@ class AR(Base):
]
stopped = torch.zeros(len(text_list), device=device).bool()
chunk_size = 1 # don't really know what to do about this desu
chunk_size = self.causal_chunk_size # don't really know what to do about this desu
state = None
start = 0
for n in trange(max_steps // chunk_size):
# get next in sequence
if naive:
for n in trange(max_steps // max(1, chunk_size)):
# get next in sequence
r, state = super().forward(
text_list,
proms_list,
self._unsqueeze_list(resps_list),
sampling_temperature=sampling_temperature,
state=state if not naive else None,
r, state = super().forward(
text_list,
proms_list,
self._unsqueeze_list(resps_list),
sampling_temperature=sampling_temperature,
state=state # if not naive else None,
)
# append outputted token
if self.causal_chunk_size > 0:
for i, ri in enumerate(r):
resps_list[i] = torch.cat([resps_list[i], ri])
else:
for i, ri in enumerate(r):
resps_list[i] = torch.cat([resps_list[i], ri[None]])
# stop token found
stopped |= r == self.stop_token
if stopped.all().item():
break
# to-do: make it work
# it seems anything that isn't a one-at-a-time sequence does not work, despite generating STOP tokens.
else:
resps_list: list[Tensor] = [
torch.zeros(0, device=device).to(torch.int16) for _ in text_list
]
test_list: list[Tensor] = [
torch.zeros(0, device=device).to(torch.int16) for _ in text_list
]
batch_size = len(text_list)
x_list = self._samplewise_merge_tensors(
self.text_emb(text_list),
self.proms_emb(proms_list),
self.resps_emb(self._unsqueeze_list(resps_list)),
sep=self.sep,
)
x, m = list_to_tensor(x_list)
device = x.device
if state is None:
state = {}
# pre-fill KV cache
for n in trange(x.shape[1]):
xs = x[:, n:(n + 1), :]
r, _ = self.retnet(xs, incremental_state=state, token_embeddings=xs, features_only=True)
r = self.classifier(r) * m
logits = torch.stack([hi[-1] for hi in r])
r = Categorical(logits=logits / sampling_temperature).sample()
for i, ri in enumerate(r):
test_list[i] = torch.cat([test_list[i], ri[None]])
# append outputted token
for i, ri in enumerate(r):
resps_list[i] = torch.cat([resps_list[i], ri[None]])
# stop token found
stopped |= r == self.stop_token
if stopped.all().item():
break
start = x.shape[1]
for n in trange(max_steps // max(1, chunk_size)):
x_list = self._samplewise_merge_tensors(
self.text_emb(text_list),
self.proms_emb(proms_list),
self.resps_emb(self._unsqueeze_list(resps_list)),
sep=self.sep,
)
x, m = list_to_tensor(x_list)
xs = x[:, start+n:start+(n+1), :]
r, _ = self.retnet(xs, incremental_state=state, token_embeddings=xs, features_only=True)
r = self.classifier(r) * m
logits = torch.stack([hi[-1] for hi in r])
r = Categorical(logits=logits / sampling_temperature).sample()
# append outputted token
for i, ri in enumerate(r):
resps_list[i] = torch.cat([resps_list[i], ri[None]])
# stop token found
stopped |= r == self.stop_token
if stopped.all().item():
break
pruned = [self._prune(r) for r in resps_list]
return pruned
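The non-naive branch above pre-fills the recurrent state one token at a time over the known prefix, then decodes against that cached state instead of recomputing the whole sequence. A runnable toy sketch of the same pattern, with a GRU's hidden state standing in (hypothetically) for the retnet's incremental_state:

import torch
import torch.nn as nn

torch.manual_seed(0)
gru = nn.GRU(input_size=8, hidden_size=8, batch_first=True)
prompt = torch.randn(1, 10, 8)  # hypothetical embedded prompt sequence

# pre-fill: feed the known sequence one position at a time to build the state
state = None
for t in range(prompt.shape[1]):
    out, state = gru(prompt[:, t:t+1, :], state)

# the incrementally built state matches a single full pass over the prompt
_, full_state = gru(prompt)
assert torch.allclose(state, full_state, atol=1e-6)

# decode: each new step only touches the cached state, not the whole prefix
for _ in range(4):
    out, state = gru(out, state)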

View File vall_e/models/base.py

@@ -149,6 +149,8 @@ class Base(nn.Module):
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
self.sep = nn.Parameter(torch.randn(d_model))
self.causal_chunk_size = 0 # 64 if self.causal else 1
if self.arch_type == "transformer":
self.sin_emb = SinusoidalEmbedding(d_model)
@@ -171,8 +173,8 @@ class Base(nn.Module):
dropout=p_dropout,
checkpoint_activations=True,
chunkwise_recurrent=False, # self.causal,
recurrent_chunkwise_size=64,
chunkwise_recurrent=self.causal and self.causal_chunk_size > 0,
recurrent_chunkwise_size=self.causal_chunk_size,
no_output_layer=True,
decoder_normalize_before=True,
))
@@ -358,6 +360,10 @@ class Base(nn.Module):
elif return_all_resp:
logits = [hi[-li:] for hi, li in zip(h_list, map(len, resps_list))]
ret = [ Categorical(logits=hi / sampling_temperature).sample() for hi in logits ]
# return the last chunkwise piece
elif self.causal_chunk_size > 0:
logits = [hi[-self.causal_chunk_size:] for hi, li in zip(h_list, map(len, resps_list))]
ret = [ Categorical(logits=hi / sampling_temperature).sample() for hi in logits ]
# return just the last code
else:
logits = torch.stack([hi[-1] for hi in h_list])
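A small sketch of what the new chunkwise branch does, with hypothetical shapes: instead of sampling only the final position, it samples one code from each of the last causal_chunk_size positions:

import torch
from torch.distributions import Categorical

h = torch.randn(37, 256)        # hypothetical per-position logits for one sample
chunk = h[-4:]                  # last causal_chunk_size (= 4 here) positions
tokens = Categorical(logits=chunk / 1.0).sample()  # one code per position
print(tokens.shape)             # torch.Size([4])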