diff --git a/vall_e/config.py b/vall_e/config.py
index 0bfb0e9..16febda 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -27,6 +27,10 @@ class _Config:
     def relpath(self):
         return Path(self.cfg_path)
 
+    @property
+    def cache_dir(self):
+        return self.relpath / ".cache"
+
     @property
     def ckpt_dir(self):
         return self.relpath / "ckpt"
@@ -119,6 +123,7 @@ class Dataset:
     hdf5_name: str = "data.h5"
     use_hdf5: bool = False
+    use_metadata: bool = False
     hdf5_flag: str = "a"
     validate: bool = True
     workers: int = 8
@@ -135,6 +140,19 @@ class Dataset:
 
     tasks_list: list[str] = field(default_factory=lambda: ["tts"])
 
+    @property
+    def min_phones(self):
+        return self.phones_range[0]
+    @property
+    def max_phones(self):
+        return self.phones_range[1]
+    @property
+    def min_duration(self):
+        return self.duration_range[0]
+    @property
+    def max_duration(self):
+        return self.duration_range[1]
+
 @dataclass()
 class Model:
     name: str = ""
@@ -393,7 +411,7 @@ class Trainer:
 
     weight_dtype: str = "float16"
 
-    backend: str = "deepspeed" if not sys.platform.startswith("win") else "local"
+    backend: str = "local" # "deepspeed" if not sys.platform.startswith("win") else "local"
 
     deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
@@ -453,10 +471,6 @@ class Config(_Config):
     def get_spkr(self):
         return eval(self.dataset.speaker_name_getter)
 
-    @property
-    def cache_dir(self):
-        return ".cache" / self.relpath
-
     @cached_property
     def diskcache(self):
         if self.cfg_path is not None and self.dataset.cache:
@@ -501,11 +515,10 @@ try:
     if cfg.dataset.use_hdf5:
         cfg.load_hdf5()
 
-    if not cfg.dataset.use_hdf5:
-        cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
-        cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
-
+    cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
+    cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
     cfg.dataset.noise = [ Path(dir) for dir in cfg.dataset.noise ]
+
 except Exception as e:
     pass
diff --git a/vall_e/data.py b/vall_e/data.py
index 6b85872..b062d60 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -8,6 +8,7 @@ import numpy as np
 import os
 import random
 import torch
+import itertools
 
 from .config import cfg
 from .emb.qnt import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file
@@ -31,7 +32,7 @@ def get_phone_symmap():
     if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5:
         return json.loads( cfg.hdf5['symmap'].asstr()[()] )
 
-    symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
+    symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178, '”': 179, '“': 180, '“ˈ': 181, '“ˌ': 182, ';ˈ': 183, ';ˌ': 184, ':ˈ': 185}
     return symmap
 
 def get_task_symmap():
@@ -51,24 +52,89 @@ def get_task_symmap():
 def _replace_file_extension(path, suffix):
     return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
 
-def _get_hdf5_path(path):
-    path = str(path)
-    if path[:2] != "./":
-        path = f'./{path}'
-    return path.replace(cfg.cfg_path, "")
-
 def _get_quant_path(path):
     return _replace_file_extension(path, ".qnt.pt")
 
 def _get_phone_path(path):
     return _replace_file_extension(path, ".phn.txt")
 
+def _load_paths(dataset, type="training"):
+    return { cfg.get_spkr( data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
+
+"""
+def _load_paths_from_hdf5(dataset, type="training"):
+    return { cfg.get_spkr( data_dir / "dummy" ): _get_hdf5_paths( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
+
+def _load_paths_from_disk(dataset, type="training"):
+    return { cfg.get_spkr( data_dir / "dummy" ): _get_paths_of_extensions( data_dir, ".qnt.pt", validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
+"""
+
+def _load_paths_from_metadata(data_dir, type="training", validate=False):
+    _fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions
+
+    def _validate( entry ):
+        phones = entry['phones']
+        duration = entry['duration']
+        return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
+
+    metadata_path = data_dir / "metadata.json"
+    if not cfg.dataset.use_metadata or not metadata_path.exists():
+        return _fn( data_dir, type if cfg.dataset.use_hdf5 else ".qnt.pt", validate )
+
+    speaker = cfg.get_spkr( data_dir / "dummy" )
+    metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
+
+    def key( id ):
+        if not cfg.dataset.use_hdf5:
+            return data_dir / id
+
+        return f"/{type}{_get_hdf5_path(data_dir)}/{id}"
+
+    return [ key(id) for id in metadata.keys() if not validate or _validate(metadata[id]) ]
+
+def _get_hdf5_path(path):
+    path = str(path)
+    if path[:2] != "./":
+        path = f'./{path}'
+    return path.replace(cfg.cfg_path, "")
+
+def _get_hdf5_paths( data_dir, type="training", validate=False ):
+    data_dir = str(data_dir)
+
+    def _validate(child):
+        phones = child.attrs['phonemes']
+        duration = child.attrs['duration']
+        return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
+
+    key = f"/{type}{_get_hdf5_path(data_dir)}"
+    return [ Path(f"{key}/{child.attrs['id']}") for child in cfg.hdf5[key].values() if not validate or _validate(child) ] if key in cfg.hdf5 else []
+
+def _get_paths_of_extensions( path, extensions=".qnt.pt", validate=False ):
+    if isinstance(path, str):
+        path = Path(path)
+
+    def _validate(path):
+        if "".join(path.suffixes) not in extensions:
+            return False
+        if not _get_phone_path(path).exists() or not _get_quant_path(path).exists():
+            return False
+        if not validate:
+            return True
+        # to-do: find an easy way to determine size from pickled quants without loading
+        # to-do: find a consistent way to derive phoneme count from filesize (probably can't due to utf-8)
+        phones = len(_get_phones(_get_phone_path(path))) # _get_phone_path(path).stat().st_size // 2 + 1
+        return cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
+
+    return [ p for p in list(path.iterdir()) if _validate(p) ] if path.exists() and path.is_dir() else []
+
 def _load_quants(path) -> Tensor:
-    return torch.load(path)[0][:, :].t().to(torch.int16)
+    return torch.load(_get_quant_path(path))[0][:, :].t().to(torch.int16)
 
 @cache
 def _get_phones(path, language="en"):
-    content = open(_get_phone_path(path), "r", encoding="utf8").read().split(" ")
+    content = open(_get_phone_path(path), "r", encoding="utf-8").read().split(" ")
     return ["<s>"] + [ " " if not p else p for p in content ] + ["</s>"]
 
 def _interleaved_reorder(l, fn):
@@ -81,114 +147,61 @@ def _interleaved_reorder(l, fn):
         if value is not None:
             yield value
 
-
-@cache
-def _validate(path, min_phones, max_phones, min_duration, max_duration):
-    if cfg.dataset.use_hdf5:
-        key = _get_hdf5_path(path)
-        if key not in cfg.hdf5:
-            return False
-
-        phones = cfg.hdf5[key].attrs['phonemes']
-        duration = cfg.hdf5[key].attrs['duration']
-
-        if phones < min_phones or phones > max_phones:
-            return False
-
-        if duration < min_duration or duration > max_duration:
-            return False
-
-        return True
-
-    if not os.path.exists(_get_phone_path(path)) or not os.path.exists(_get_quant_path(path)):
-        return False
-
-    phones = _get_phones(path)
-    unique_phones = list(set(phones))
-
-    if len(unique_phones) == 0:
-        return False
-    if len(unique_phones) == 1 and unique_phones[0] == " ":
-        return False
-    if len(phones) < min_phones or len(phones) > max_phones:
-        return False
-    return True
-
 class Dataset(_Dataset):
     def __init__(
         self,
-        paths,
         phone_symmap=None,
         training=False,
         extra_paths_by_spkr_name: dict[str, list] = {},
     ):
         super().__init__()
         self._head = None
-        self.min_phones = cfg.dataset.phones_range[0]
-        self.max_phones = cfg.dataset.phones_range[1]
-        self.min_duration = cfg.dataset.duration_range[0]
-        self.max_duration = cfg.dataset.duration_range[1]
         self.sampler = None
 
-        if cfg.dataset.validate:
-            self.paths = [
-                path for path in paths if _validate(path, self.min_phones, self.max_phones, self.min_duration, self.max_duration)
-            ]
-        else:
-            self.paths = paths
+        self.paths = []
+
+        self.training = training
+        self.dataset_type = "training" if self.training else "validation"
+        self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
+
+        self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type)
+        self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
+
+        if cfg.dataset.sample_type == "path":
+            self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
+
+        self.noise_paths = _load_paths(cfg.dataset.noise, "noise")
+        self.noise_paths = list(itertools.chain.from_iterable(self.noise_paths.values()))
 
         self.phone_symmap = phone_symmap or self._get_phone_symmap()
         self.spkr_symmap = self._get_spkr_symmap()
         self.task_symmap = self._get_task_symmap()
-        self.training = training
 
         # assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
         self.text_dtype = torch.uint8 if len(self.phone_symmap) < 256 else torch.int16
 
-        self.paths_by_spkr_name = self._get_paths_by_spkr_name(extra_paths_by_spkr_name)
-
-        if cfg.dataset.validate:
-            self.paths = [
-                p for p in self.paths if len(self.paths_by_spkr_name[cfg.get_spkr(p)]) > 1
-            ]
-
-        if cfg.dataset.sample_type == "path":
-            self.paths = [*_interleaved_reorder(self.paths, cfg.get_spkr)]
-
         if len(self.paths) == 0 and training:
             raise ValueError("No valid path is found for training.")
 
+        # it would be better to fetch durations during the validation pass, but oh well
         self.duration = 0
-        self.durations = {}
         if cfg.dataset.use_hdf5:
             for path in self.paths:
-                key = _get_hdf5_path(path)
-                spkr_name = cfg.get_spkr(path)
-                spkr_id = self.spkr_symmap[spkr_name]
-                duration = cfg.hdf5[key].attrs['duration']
-
-                self.duration += duration
-
-                if spkr_id not in self.durations:
-                    self.durations[spkr_id] = duration
-                else:
-                    self.durations[spkr_id] += duration
-
-    def _get_paths_by_spkr_name(self, extra_paths_by_spkr_name: dict[str, list]):
-        ret = defaultdict(list)
-        for path in self.paths:
-            ret[cfg.get_spkr(path)].append(path)
-        for k, v in extra_paths_by_spkr_name.items():
-            ret[k].extend(v)
-        return {**ret}
+                self.duration += cfg.hdf5[_get_hdf5_path(path)].attrs['duration']
 
     @cached_property
     def phones(self):
         return sorted(set().union(*[_get_phones(path) for path in self.paths]))
 
+    def get_speaker(self, path):
+        if isinstance(path, str):
+            path = Path(path)
+        res = cfg.get_spkr(path)
+        return res
+
     @cached_property
     def spkrs(self):
-        return sorted({cfg.get_spkr(path) for path in self.paths})
+        return sorted({self.get_speaker(path) for path in self.paths})
 
     @cached_property
     def tasks(self):
@@ -209,13 +222,10 @@ class Dataset(_Dataset):
         return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(levels) ]]).to(dtype=torch.int16)
 
     def sample_noise(self):
-        paths = []
-        for data_dir in cfg.dataset.noise:
-            paths.extend(data_dir.rglob("*.qnt.pt"))
-        path = random.choice(paths)
+        path = random.choice(self.noise_paths)
 
-        if False and cfg.dataset.use_hdf5:
-            key = f'/noise/{_get_hdf5_path(path)}'
+        if cfg.dataset.use_hdf5:
+            key = _get_hdf5_path(path)
             qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
         else:
             qnt = _load_quants(path)
@@ -275,7 +285,7 @@ class Dataset(_Dataset):
             path = random.choice([*set(self.paths_by_spkr_name[spkr_name])])
         else:
             path = self.paths[index]
-            spkr_name = cfg.get_spkr(path)
+            spkr_name = self.get_speaker(path)
             spkr_id = self.spkr_symmap[spkr_name]
 
         if cfg.dataset.use_hdf5:
@@ -507,110 +517,10 @@ def _create_dataloader(dataset, training):
         sampler=sampler,
     )
 
-def _load_dataset_paths():
-    hf = cfg.hdf5
-    paths = {
-        "training": [],
-        "validation": [],
-    }
-
-    datasets = {
-        "training": [],
-        "validation": [],
-    }
-
-    def get_paths( data_dir, type="training" ):
-        key = f"/{type}{_get_hdf5_path(data_dir)}"
-        if key not in cfg.hdf5:
-            return
-
-        paths[type].extend([ f"{key}/{child.attrs['id']}" for child in cfg.hdf5[key].values() ])
-
-    for data_dir in cfg.dataset.training:
-        get_paths( data_dir, "training" )
-
-    for data_dir in cfg.dataset.validation:
-        get_paths( data_dir, "validation" )
-
-    for _, type in enumerate(paths):
-        dirs = paths[type]
-
-        if len(dirs) == 0:
-            continue
-
-        dirs = [ Path(p) for p in dirs ]
-
-        pairs = sorted([(cfg.get_spkr(p), p) for p in dirs])
-        for _, group in groupby(pairs, lambda pair: pair[0]):
-            shuffled = sorted([p for _, p in group])
-            random.seed(0)
-            random.shuffle(shuffled)
-
-            datasets[type].extend(shuffled)
-
-    return datasets["training"], datasets["validation"]
-
-# to-do: mirror the hdf5-based load function
-def _load_train_val_paths():
-    paths = []
-    train_paths = []
-    val_paths = []
-
-    for data_dir in cfg.dataset.training:
-        paths.extend(data_dir.rglob("*.qnt.pt"))
-
-    if len(paths) > 0:
-        pairs = sorted([(cfg.get_spkr(p), p) for p in paths])
-        del paths
-
-        for _, group in groupby(pairs, lambda pair: pair[0]):
-            paths = sorted([p for _, p in group])
-            random.seed(0)
-            random.shuffle(paths)
-            train_paths.extend(paths)
-
-    for data_dir in cfg.dataset.validation:
-        paths.extend(data_dir.rglob("*.qnt.pt"))
-
-    if len(paths) > 0:
-        pairs = sorted([(cfg.get_spkr(p), p) for p in paths])
-        del paths
-
-        for _, group in groupby(pairs, lambda pair: pair[0]):
-            paths = sorted([p for _, p in group])
-            random.seed(0)
-            random.shuffle(paths)
-            val_paths.extend(paths)
-
-    train_paths, val_paths = map(sorted, [train_paths, val_paths])
-
-    if len(train_paths) == 0:
-        raise RuntimeError(f"Failed to find any .qnt.pt file in specified training dataset.")
-
-    # to-do: allow setting aside a fixed portion of the training dataset for validation
-    # something like the last X percent of each speaker is set aside
-    if len(val_paths) == 0:
-        val_paths = [ train_paths[0] ]
-
-    return train_paths, val_paths
-
 @cfg.diskcache()
 def create_datasets():
-    train_paths, val_paths = _load_dataset_paths() if cfg.dataset.use_hdf5 else _load_train_val_paths()
-
-    train_dataset = Dataset(
-        train_paths,
-        training=True,
-    )
-
-    val_dataset = Dataset(
-        val_paths,
-        train_dataset.phone_symmap,
-        #train_dataset.spkr_symmap,
-        #extra_paths_by_spkr_name=train_dataset.paths_by_spkr_name,
-    )
-
-    val_dataset.head_(cfg.evaluation.size)
+    train_dataset = Dataset( training=True )
+    val_dataset = Dataset( phone_symmap=train_dataset.phone_symmap, training=False )
 
     return train_dataset, val_dataset
@@ -628,17 +538,10 @@ def create_train_val_dataloader():
 
     _logger.info(str(train_dataset.phone_symmap))
     _logger.info(str(train_dataset.spkr_symmap))
-
     _logger.info(f"#samples (train): {len(train_dataset)}.")
     _logger.info(f"#samples (val): {len(val_dataset)}.")
     _logger.info(f"#samples (subtrain): {len(subtrain_dataset)}.")
-
-    """
-    _logger.info(f"#durations (train): {str(train_dataset.durations)}.")
-    _logger.info(f"#durations (val): {str(val_dataset.durations)}.")
-    _logger.info(f"#durations (subtrain): {str(subtrain_dataset.durations)}.")
-    """
 
     _logger.info(f"#duration (train): {str(train_dataset.duration)}.")
     _logger.info(f"#duration (val): {str(val_dataset.duration)}.")
@@ -648,8 +551,52 @@ def create_train_val_dataloader():
 
     return train_dl, subtrain_dl, val_dl
 
+# parse the dataset into metadata that is easier to sample from
+def create_dataset_metadata():
+    cfg.dataset.validate = False
+    cfg.dataset.use_hdf5 = False
+
+    paths_by_spkr_name = {}
+
+    paths_by_spkr_name |= _load_paths(cfg.dataset.training, "training")
+    paths_by_spkr_name |= _load_paths(cfg.dataset.validation, "validation")
+    paths_by_spkr_name |= _load_paths(cfg.dataset.noise, "noise")
+
+    paths = list(itertools.chain.from_iterable(paths_by_spkr_name.values()))
+
+    metadata = {}
+    for path in tqdm(paths, desc="Parsing paths"):
+        speaker = cfg.get_spkr(path)
+        if speaker not in metadata:
+            metadata[speaker] = {}
+
+        if cfg.dataset.use_hdf5:
+            phones = cfg.hdf5[_get_hdf5_path(path)].attrs['phonemes']
+            duration = cfg.hdf5[_get_hdf5_path(path)].attrs['duration']
+        else:
+            phns_path = _get_phone_path(path)
+            qnts_path = _get_quant_path(path)
+
+            phones = len(_get_phones(phns_path)) if phns_path.exists() else 0
+            duration = _load_quants(qnts_path).shape[0] / 75 if qnts_path.exists() else 0
+
+        metadata[speaker][path.name.split(".")[0]] = {
+            "phones": phones,
+            "duration": duration
+        }
+
+    for speaker, paths in tqdm(paths_by_spkr_name.items(), desc="Writing metadata"):
+        if len(paths) == 0:
+            continue
+        with open(paths[0].parent / "metadata.json", "w", encoding="utf-8") as f:
+            f.write( json.dumps( metadata[speaker] ) )
+
+    with open(cfg.relpath / "metadata.json", "w", encoding="utf-8") as f:
+        f.write( json.dumps( metadata ) )
+
 # parse yaml to create an hdf5 file
-def create_dataset_hdf5():
+def create_dataset_hdf5( skip_existing=True ):
     cfg.dataset.use_hdf5 = True
     cfg.load_hdf5(write=True)
@@ -658,11 +605,11 @@ def create_dataset_hdf5():
     root = cfg.cfg_path
     hf = cfg.hdf5
 
-    def add( dir, type="training", audios=True, texts=True ):
-        dir = "./" + str(dir)
-        name = dir.replace(root, "")
-        print( str(dir), name )
+    def add( dir, type="training", audios=True, texts=True ):
+        name = "./" + str(dir)
+        name = name.replace(root, "")
 
+        metadata = {}
+
         if not os.path.isdir(f'{root}/{name}/'):
             return
@@ -680,35 +627,53 @@ def create_dataset_hdf5():
             key = f'{type}/{name}/{id}'
 
             if key in hf:
-                # print("Skipping existing entry:", key)
-                continue
+                if skip_existing:
+                    continue
+                del hf[key]
 
             group = hf.create_group(key)
+            group.attrs['id'] = id
+            group.attrs['type'] = type
+            group.attrs['speaker'] = name
+
+            metadata[id] = {}
 
             # audio
             if audios:
                 qnt = torch.load(f'{root}/{name}/{id}.qnt.pt')[0].t()
+
+                if "audio" in group:
+                    del group["audio"]
                 group.create_dataset('audio', data=qnt.numpy(), compression='lzf')
+                group.attrs['duration'] = qnt.shape[0] / 75
+                metadata[id]["duration"] = qnt.shape[0] / 75
+            else:
+                group.attrs['duration'] = 0
+                metadata[id]["duration"] = 0
 
             # text
             if texts:
-                with open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf8") as f:
-                    content = f.read()
-                split = content.split(" ")
-                phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
+                with open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf-8") as f:
+                    content = f.read().split(" ")
+                phones = [f"<s>"] + [ " " if not p else p for p in content ] + [f"</s>"]
                 for s in set(phones):
                     if s not in symmap:
                         symmap[s] = len(symmap.keys())
+
                 phn = [ symmap[s] for s in phones ]
+
+                if "text" in group:
+                    del group["text"]
                 group.create_dataset('text', data=phn, compression='lzf', chunks=True)
-
-            # metadata
-            group.attrs['id'] = id
-            group.attrs['type'] = type
-            group.attrs['speaker'] = name
-            group.attrs['duration'] = qnt.shape[0] / 75
                 group.attrs['phonemes'] = len(phn)
+                metadata[id]["phones"] = len(phn)
+            else:
+                group.attrs['phonemes'] = 0
+                metadata[id]["phones"] = 0
+
+        with open(dir / "metadata.json", "w", encoding="utf-8") as f:
+            f.write( json.dumps( metadata ) )
+
 
     # training
     for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
@@ -723,10 +688,9 @@ def create_dataset_hdf5():
         add( data_dir, type="noise", texts=False )
 
     # write symmap
-    try:
-        hf.create_dataset('symmap', data=json.dumps(symmap))
-    except Exception as e:
-        pass
+    if "symmap" in hf:
+        del hf['symmap']
+    hf.create_dataset('symmap', data=json.dumps(symmap))
 
     hf.close()
@@ -742,8 +706,16 @@ if __name__ == "__main__":
 
     cfg.dataset.workers = 1
 
+    class LoggerOverride:
+        def info(self, *args):
+            print(*args)
+
+    _logger = LoggerOverride()
+
     if args.action == "hdf5":
         create_dataset_hdf5()
+    elif args.action == "metadata":
+        create_dataset_metadata()
     elif args.action == "sample":
         train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
@@ -758,6 +730,7 @@ if __name__ == "__main__":
                 del v[i]['proms']
                 del v[i]['resps']
             print(f'{k}:', v)
+
     elif args.action == "tasks":
         index = 0
         cfg.dataset.tasks_list = args.tasks.split(",")
diff --git a/vall_e/emb/g2p.py b/vall_e/emb/g2p.py
index a5ea3b6..3c64536 100755
--- a/vall_e/emb/g2p.py
+++ b/vall_e/emb/g2p.py
@@ -33,10 +33,13 @@ def _get_backend( language="en-us", backend="espeak" ):
 
     return phonemizer
 
-def encode(text: str, language="en-us", backend="espeak") -> list[str]:
+def encode(text: str, language="en-us", backend="auto") -> list[str]:
     if language == "en":
         language = "en-us"
 
+    if not backend or backend == "auto":
+        backend = "espeak" # if language[:2] != "en" else "festival"
+
     text = [ text ]
 
     backend = _get_backend(language=language, backend=backend)
@@ -63,16 +66,13 @@ def main():
     args = parser.parse_args()
 
     paths = list(args.folder.rglob(f"*{args.suffix}"))
-    random.shuffle(paths)
 
     for path in tqdm(paths):
         phone_path = path.with_name(path.stem.split(".")[0] + ".phn.txt")
         if phone_path.exists():
             continue
-        graphs = _get_graphs(path)
-        phones = encode(graphs)
-        with open(phone_path, "w") as f:
-            f.write(" ".join(phones))
+        phones = encode(open(path, "r", encoding="utf-8").read())
+        open(phone_path, "w", encoding="utf-8").write(" ".join(phones))
 
 
 if __name__ == "__main__":
diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py
index 25bf618..03f9983 100755
--- a/vall_e/emb/qnt.py
+++ b/vall_e/emb/qnt.py
@@ -213,7 +213,7 @@ def repeat_extend_audio( qnt, target ):
         pieces.append(qnt)
         length += qnt.shape[0]
 
-    return trim_random(torch.cat(pieces), target)
+    return trim(torch.cat(pieces), target)
 
 # merges two quantized audios together
 # I don't know if this works
diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py
index 46975a1..0d370db 100755
--- a/vall_e/models/ar.py
+++ b/vall_e/models/ar.py
@@ -2,6 +2,7 @@ from ..config import cfg
 from .base import Base, list_to_tensor, Categorical
 
 import torch
+from torch.nn.utils.rnn import pad_sequence
 
 from einops import rearrange
 from torch import Tensor
@@ -62,7 +63,7 @@ class AR(Base):
         max_steps: int = 1000,
         sampling_temperature: float = 1.0,
 
-        naive: bool = True,
+        naive: bool = False,
     ):
         if resps_list is not None:
             resps_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels
@@ -83,30 +84,104 @@ class AR(Base):
         ]
 
         stopped = torch.zeros(len(text_list), device=device).bool()
 
-        chunk_size = 1 # don't really know what to do about this desu
+        chunk_size = self.causal_chunk_size # don't really know what to do about this desu
 
         state = None
         start = 0
 
-        for n in trange(max_steps // chunk_size):
-            # get next in sequence
+        if naive:
+            for n in trange(max_steps // max(1, chunk_size)):
+                # get next in sequence
 
-            r, state = super().forward(
-                text_list,
-                proms_list,
-                self._unsqueeze_list(resps_list),
-                sampling_temperature=sampling_temperature,
-                state=state if not naive else None,
+                r, state = super().forward(
+                    text_list,
+                    proms_list,
+                    self._unsqueeze_list(resps_list),
+                    sampling_temperature=sampling_temperature,
+                    state=state # if not naive else None,
+                )
+
+                # append outputted token
+                if self.causal_chunk_size > 0:
+                    for i, ri in enumerate(r):
+                        resps_list[i] = torch.cat([resps_list[i], ri])
+                else:
+                    for i, ri in enumerate(r):
+                        resps_list[i] = torch.cat([resps_list[i], ri[None]])
+
+                # stop token found
+                stopped |= r == self.stop_token
+                if stopped.all().item():
+                    break
+        # to-do: make it work
+        # it seems anything that isn't a one-at-a-time sequence does not work, despite generating STOP tokens.
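+        # (descriptive note on the branch below: the merged text/prompt embeddings are fed
+        #  through the RetNet one token at a time to pre-fill its incremental/recurrent state,
+        #  then new tokens are sampled and appended step-by-step while reusing that state.)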
+        else:
+            resps_list: list[Tensor] = [
+                torch.zeros(0, device=device).to(torch.int16) for _ in text_list
+            ]
+
+            test_list: list[Tensor] = [
+                torch.zeros(0, device=device).to(torch.int16) for _ in text_list
+            ]
+
+            batch_size = len(text_list)
+
+            x_list = self._samplewise_merge_tensors(
+                self.text_emb(text_list),
+                self.proms_emb(proms_list),
+                self.resps_emb(self._unsqueeze_list(resps_list)),
+                sep=self.sep,
             )
+
+            x, m = list_to_tensor(x_list)
+            device = x.device
+
+            if state is None:
+                state = {}
+
+            # pre-fill KV cache
+            for n in trange(x.shape[1]):
+                xs = x[:, n:(n + 1), :]
+                r, _ = self.retnet(xs, incremental_state=state, token_embeddings=xs, features_only=True)
+                r = self.classifier(r) * m
+
+                logits = torch.stack([hi[-1] for hi in r])
+                r = Categorical(logits=logits / sampling_temperature).sample()
+
+                for i, ri in enumerate(r):
+                    test_list[i] = torch.cat([test_list[i], ri[None]])
+
             # append outputted token
             for i, ri in enumerate(r):
                 resps_list[i] = torch.cat([resps_list[i], ri[None]])
 
-            # stop token found
-            stopped |= r == self.stop_token
-            if stopped.all().item():
-                break
+            start = x.shape[1]
+            for n in trange(max_steps // max(1, chunk_size)):
+                x_list = self._samplewise_merge_tensors(
+                    self.text_emb(text_list),
+                    self.proms_emb(proms_list),
+                    self.resps_emb(self._unsqueeze_list(resps_list)),
+                    sep=self.sep,
+                )
+
+                x, m = list_to_tensor(x_list)
+
+                xs = x[:, start+n:start+(n+1), :]
+                r, _ = self.retnet(xs, incremental_state=state, token_embeddings=xs, features_only=True)
+                r = self.classifier(r) * m
+
+                logits = torch.stack([hi[-1] for hi in r])
+                r = Categorical(logits=logits / sampling_temperature).sample()
+
+                # append outputted token
+                for i, ri in enumerate(r):
+                    resps_list[i] = torch.cat([resps_list[i], ri[None]])
+
+                # stop token found
+                stopped |= r == self.stop_token
+                if stopped.all().item():
+                    break
 
         pruned = [self._prune(r) for r in resps_list]
         return pruned
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 03c6c18..f149065 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -149,6 +149,8 @@ class Base(nn.Module):
         self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
 
         self.sep = nn.Parameter(torch.randn(d_model))
+
+        self.causal_chunk_size = 0 # 64 if self.causal else 1
 
         if self.arch_type == "transformer":
             self.sin_emb = SinusoidalEmbedding(d_model)
@@ -171,8 +173,8 @@ class Base(nn.Module):
                 dropout=p_dropout,
                 checkpoint_activations=True,
 
-                chunkwise_recurrent=False, # self.causal,
-                recurrent_chunkwise_size=64,
+                chunkwise_recurrent=self.causal and self.causal_chunk_size > 0,
+                recurrent_chunkwise_size=self.causal_chunk_size,
 
                 no_output_layer=True,
                 decoder_normalize_before=True,
             ))
@@ -358,6 +360,10 @@ class Base(nn.Module):
         elif return_all_resp:
             logits = [hi[-li:] for hi, li in zip(h_list, map(len, resps_list))]
             ret = [ Categorical(logits=hi / sampling_temperature).sample() for hi in logits ]
+        # return the last chunkwise piece
+        elif self.causal_chunk_size > 0:
+            logits = [hi[-self.causal_chunk_size:] for hi, li in zip(h_list, map(len, resps_list))]
+            ret = [ Categorical(logits=hi / sampling_temperature).sample() for hi in logits ]
         # return just the last code
         else:
             logits = torch.stack([hi[-1] for hi in h_list])