diff --git a/vall_e/config.py b/vall_e/config.py
index 0bfb0e9..16febda 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -27,6 +27,10 @@ class _Config:
def relpath(self):
return Path(self.cfg_path)
+ @property
+ def cache_dir(self):
+ return self.relpath / ".cache"
+
@property
def ckpt_dir(self):
return self.relpath / "ckpt"
@@ -119,6 +123,7 @@ class Dataset:
hdf5_name: str = "data.h5"
use_hdf5: bool = False
+ use_metadata: bool = False
hdf5_flag: str = "a"
validate: bool = True
workers: int = 8
@@ -135,6 +140,19 @@ class Dataset:
tasks_list: list[str] = field(default_factory=lambda: ["tts"])
+ @property
+ def min_phones(self):
+ return self.phones_range[0]
+ @property
+ def max_phones(self):
+ return self.phones_range[1]
+ @property
+ def min_duration(self):
+ return self.duration_range[0]
+ @property
+ def max_duration(self):
+ return self.duration_range[1]
+
@dataclass()
class Model:
name: str = ""
@@ -393,7 +411,7 @@ class Trainer:
weight_dtype: str = "float16"
- backend: str = "deepspeed" if not sys.platform.startswith("win") else "local"
+ backend: str = "local" # "deepspeed" if not sys.platform.startswith("win") else "local"
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
@@ -453,10 +471,6 @@ class Config(_Config):
def get_spkr(self):
return eval(self.dataset.speaker_name_getter)
- @property
- def cache_dir(self):
- return ".cache" / self.relpath
-
@cached_property
def diskcache(self):
if self.cfg_path is not None and self.dataset.cache:
@@ -501,11 +515,10 @@ try:
if cfg.dataset.use_hdf5:
cfg.load_hdf5()
- if not cfg.dataset.use_hdf5:
- cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
- cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
-
+ cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
+ cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
cfg.dataset.noise = [ Path(dir) for dir in cfg.dataset.noise ]
+
except Exception as e:
pass
diff --git a/vall_e/data.py b/vall_e/data.py
index 6b85872..b062d60 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -8,6 +8,7 @@ import numpy as np
import os
import random
import torch
+import itertools
from .config import cfg
from .emb.qnt import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file
@@ -31,7 +32,7 @@ def get_phone_symmap():
if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5:
return json.loads( cfg.hdf5['symmap'].asstr()[()] )
- symmap = {'': 1, '': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
+ symmap = {'': 1, '': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178, '”': 179, '“': 180, '“ˈ': 181, '“ˌ': 182, ';ˈ': 183, ';ˌ': 184, ':ˈ': 185}
return symmap
def get_task_symmap():
@@ -51,24 +52,89 @@ def get_task_symmap():
def _replace_file_extension(path, suffix):
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
-def _get_hdf5_path(path):
- path = str(path)
- if path[:2] != "./":
- path = f'./{path}'
- return path.replace(cfg.cfg_path, "")
-
def _get_quant_path(path):
return _replace_file_extension(path, ".qnt.pt")
def _get_phone_path(path):
return _replace_file_extension(path, ".phn.txt")
+def _load_paths(dataset, type="training"):
+ return { cfg.get_spkr( data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
+
+"""
+def _load_paths_from_hdf5(dataset, type="training"):
+ return { cfg.get_spkr( data_dir / "dummy" ): _get_hdf5_paths( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
+
+def _load_paths_from_disk(dataset, type="training"):
+ return { cfg.get_spkr( data_dir / "dummy" ): _get_paths_of_extensions( data_dir, ".qnt.pt", validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
+"""
+
+def _load_paths_from_metadata(data_dir, type="training", validate=False):
+ _fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions
+
+ def _validate( entry ):
+ phones = entry['phones']
+ duration = entry['duration']
+ return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
+
+ metadata_path = data_dir / "metadata.json"
+ if not cfg.dataset.use_metadata or not metadata_path.exists():
+ return _fn( data_dir, type if cfg.dataset.use_hdf5 else ".qnt.pt", validate )
+
+ speaker = cfg.get_spkr( data_dir / "dummy" )
+ metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
+
+ def key( dir, id ):
+ if not cfg.dataset.use_hdf5:
+ return data_dir / id
+
+ return f"/{type}{_get_hdf5_path(data_dir)}/{id}"
+
+ return [ key(dir, id) for id in metadata.keys() if not validate or _validate(metadata[id]) ]
+
+
+def _get_hdf5_path(path):
+ path = str(path)
+ if path[:2] != "./":
+ path = f'./{path}'
+ return path.replace(cfg.cfg_path, "")
+
+def _get_hdf5_paths( data_dir, type="training", validate=False ):
+ data_dir = str(data_dir)
+
+ def _validate(child):
+ phones = child.attrs['phonemes']
+ duration = child.attrs['duration']
+ return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
+
+ key = f"/{type}{_get_hdf5_path(data_dir)}"
+ return [ Path(f"{key}/{child.attrs['id']}") for child in cfg.hdf5[key].values() if not validate or _validate(child) ] if key in cfg.hdf5 else []
+
+def _get_paths_of_extensions( path, extensions=".qnt.pt", validate=False ):
+ if isinstance(path, str):
+ path = Path(path)
+
+ def _validate(path):
+ if "".join(path.suffixes) not in extensions:
+ return False
+ if not _get_phone_path(path).exists() or not _get_quant_path(path).exists():
+ return False
+ if not validate:
+ return True
+ # to-do: find an easy way to determine size from pickled quants without loading
+ # to-do: find a consistent way to derive phoneme count from filesize (probably can't due to utf-8)
+ phones = len(_get_phones(_get_phone_path(path))) # _get_phone_path(path).stat().st_size // 2 + 1
+ return cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
+
+
+ return [ p for p in list(path.iterdir()) if _validate(p) ] if path.exists() and path.is_dir() else []
+
def _load_quants(path) -> Tensor:
- return torch.load(path)[0][:, :].t().to(torch.int16)
+ return torch.load(_get_quant_path(path))[0][:, :].t().to(torch.int16)
@cache
def _get_phones(path, language="en"):
- content = open(_get_phone_path(path), "r", encoding="utf8").read().split(" ")
+ content = open(_get_phone_path(path), "r", encoding="utf-8").read().split(" ")
return [""] + [ " " if not p else p for p in content ] + [""]
def _interleaved_reorder(l, fn):
@@ -81,114 +147,61 @@ def _interleaved_reorder(l, fn):
if value is not None:
yield value
-
-@cache
-def _validate(path, min_phones, max_phones, min_duration, max_duration):
- if cfg.dataset.use_hdf5:
- key = _get_hdf5_path(path)
- if key not in cfg.hdf5:
- return False
-
- phones = cfg.hdf5[key].attrs['phonemes']
- duration = cfg.hdf5[key].attrs['duration']
-
- if phones < min_phones or phones > max_phones:
- return False
-
- if duration < min_duration or duration > max_duration:
- return False
-
- return True
-
- if not os.path.exists(_get_phone_path(path)) or not os.path.exists(_get_quant_path(path)):
- return False
-
- phones = _get_phones(path)
- unique_phones = list(set(phones))
-
- if len(unique_phones) == 0:
- return False
- if len(unique_phones) == 1 and unique_phones[0] == " ":
- return False
- if len(phones) < min_phones or len(phones) > max_phones:
- return False
- return True
-
class Dataset(_Dataset):
def __init__(
self,
- paths,
phone_symmap=None,
training=False,
extra_paths_by_spkr_name: dict[str, list] = {},
):
super().__init__()
self._head = None
- self.min_phones = cfg.dataset.phones_range[0]
- self.max_phones = cfg.dataset.phones_range[1]
- self.min_duration = cfg.dataset.duration_range[0]
- self.max_duration = cfg.dataset.duration_range[1]
self.sampler = None
- if cfg.dataset.validate:
- self.paths = [
- path for path in paths if _validate(path, self.min_phones, self.max_phones, self.min_duration, self.max_duration)
- ]
- else:
- self.paths = paths
+ self.paths = []
+
+ self.training = training
+ self.dataset_type = "training" if self.training else "validation"
+ self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
+
+ self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type)
+ self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
+
+ if cfg.dataset.sample_type == "path":
+ self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
+
+ self.noise_paths = _load_paths(cfg.dataset.noise, "noise")
+ self.noise_paths = list(itertools.chain.from_iterable(self.noise_paths.values()))
self.phone_symmap = phone_symmap or self._get_phone_symmap()
self.spkr_symmap = self._get_spkr_symmap()
self.task_symmap = self._get_task_symmap()
- self.training = training
# assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
self.text_dtype = torch.uint8 if len(self.phone_symmap) < 256 else torch.int16
- self.paths_by_spkr_name = self._get_paths_by_spkr_name(extra_paths_by_spkr_name)
-
- if cfg.dataset.validate:
- self.paths = [
- p for p in self.paths if len(self.paths_by_spkr_name[cfg.get_spkr(p)]) > 1
- ]
-
- if cfg.dataset.sample_type == "path":
- self.paths = [*_interleaved_reorder(self.paths, cfg.get_spkr)]
-
if len(self.paths) == 0 and training:
raise ValueError("No valid path is found for training.")
+ # would be a better cost saving if we could fetch the duration during the validation pass but oh well
self.duration = 0
- self.durations = {}
if cfg.dataset.use_hdf5:
for path in self.paths:
- key = _get_hdf5_path(path)
- spkr_name = cfg.get_spkr(path)
- spkr_id = self.spkr_symmap[spkr_name]
- duration = cfg.hdf5[key].attrs['duration']
-
- self.duration += duration
-
- if spkr_id not in self.durations:
- self.durations[spkr_id] = duration
- else:
- self.durations[spkr_id] += duration
-
- def _get_paths_by_spkr_name(self, extra_paths_by_spkr_name: dict[str, list]):
- ret = defaultdict(list)
- for path in self.paths:
- ret[cfg.get_spkr(path)].append(path)
- for k, v in extra_paths_by_spkr_name.items():
- ret[k].extend(v)
- return {**ret}
+ self.duration += cfg.hdf5[_get_hdf5_path(path)].attrs['duration']
@cached_property
def phones(self):
return sorted(set().union(*[_get_phones(path) for path in self.paths]))
+ def get_speaker(self, path):
+ if isinstance(path, str):
+ path = Path(path)
+ res = cfg.get_spkr(path)
+ return res
+
@cached_property
def spkrs(self):
- return sorted({cfg.get_spkr(path) for path in self.paths})
+ return sorted({self.get_speaker(path) for path in self.paths})
@cached_property
def tasks(self):
@@ -209,13 +222,10 @@ class Dataset(_Dataset):
return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(levels) ]]).to(dtype=torch.int16)
def sample_noise(self):
- paths = []
- for data_dir in cfg.dataset.noise:
- paths.extend(data_dir.rglob("*.qnt.pt"))
- path = random.choice(paths)
+ path = random.choice(self.noise_paths)
- if False and cfg.dataset.use_hdf5:
- key = f'/noise/{_get_hdf5_path(path)}'
+ if cfg.dataset.use_hdf5:
+ key = _get_hdf5_path(path)
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
else:
qnt = _load_quants(path)
@@ -275,7 +285,7 @@ class Dataset(_Dataset):
path = random.choice([*set(self.paths_by_spkr_name[spkr_name])])
else:
path = self.paths[index]
- spkr_name = cfg.get_spkr(path)
+ spkr_name = self.get_speaker(path)
spkr_id = self.spkr_symmap[spkr_name]
if cfg.dataset.use_hdf5:
@@ -507,110 +517,10 @@ def _create_dataloader(dataset, training):
sampler=sampler,
)
-def _load_dataset_paths():
- hf = cfg.hdf5
- paths = {
- "training": [],
- "validation": [],
- }
-
- datasets = {
- "training": [],
- "validation": [],
- }
-
- def get_paths( data_dir, type="training" ):
- key = f"/{type}{_get_hdf5_path(data_dir)}"
- if key not in cfg.hdf5:
- return
-
- paths[type].extend([ f"{key}/{child.attrs['id']}" for child in cfg.hdf5[key].values() ])
-
- for data_dir in cfg.dataset.training:
- get_paths( data_dir, "training" )
-
- for data_dir in cfg.dataset.validation:
- get_paths( data_dir, "validation" )
-
- for _, type in enumerate(paths):
- dirs = paths[type]
-
- if len(dirs) == 0:
- continue
-
- dirs = [ Path(p) for p in dirs ]
-
- pairs = sorted([(cfg.get_spkr(p), p) for p in dirs])
- for _, group in groupby(pairs, lambda pair: pair[0]):
- shuffled = sorted([p for _, p in group])
- random.seed(0)
- random.shuffle(shuffled)
-
- datasets[type].extend(shuffled)
-
- return datasets["training"], datasets["validation"]
-
-# to-do: mirror the hdf5-based load function
-def _load_train_val_paths():
- paths = []
- train_paths = []
- val_paths = []
-
- for data_dir in cfg.dataset.training:
- paths.extend(data_dir.rglob("*.qnt.pt"))
-
- if len(paths) > 0:
- pairs = sorted([(cfg.get_spkr(p), p) for p in paths])
- del paths
-
- for _, group in groupby(pairs, lambda pair: pair[0]):
- paths = sorted([p for _, p in group])
- random.seed(0)
- random.shuffle(paths)
- train_paths.extend(paths)
-
- for data_dir in cfg.dataset.validation:
- paths.extend(data_dir.rglob("*.qnt.pt"))
-
- if len(paths) > 0:
- pairs = sorted([(cfg.get_spkr(p), p) for p in paths])
- del paths
-
- for _, group in groupby(pairs, lambda pair: pair[0]):
- paths = sorted([p for _, p in group])
- random.seed(0)
- random.shuffle(paths)
- val_paths.extend(paths)
-
- train_paths, val_paths = map(sorted, [train_paths, val_paths])
-
- if len(train_paths) == 0:
- raise RuntimeError(f"Failed to find any .qnt.pt file in specified training dataset.")
-
- # to-do: allow setting aside a fixed portion of the training dataset for validation
- # something like the last X percent of each speaker is set aside
- if len(val_paths) == 0:
- val_paths = [ train_paths[0] ]
-
- return train_paths, val_paths
-
@cfg.diskcache()
def create_datasets():
- train_paths, val_paths = _load_dataset_paths() if cfg.dataset.use_hdf5 else _load_train_val_paths()
-
- train_dataset = Dataset(
- train_paths,
- training=True,
- )
-
- val_dataset = Dataset(
- val_paths,
- train_dataset.phone_symmap,
- #train_dataset.spkr_symmap,
- #extra_paths_by_spkr_name=train_dataset.paths_by_spkr_name,
- )
-
- val_dataset.head_(cfg.evaluation.size)
+ train_dataset = Dataset( training=True )
+ val_dataset = Dataset( phone_symmap=train_dataset.phone_symmap, training=False )
return train_dataset, val_dataset
@@ -628,17 +538,10 @@ def create_train_val_dataloader():
_logger.info(str(train_dataset.phone_symmap))
_logger.info(str(train_dataset.spkr_symmap))
-
_logger.info(f"#samples (train): {len(train_dataset)}.")
_logger.info(f"#samples (val): {len(val_dataset)}.")
_logger.info(f"#samples (subtrain): {len(subtrain_dataset)}.")
-
- """
- _logger.info(f"#durations (train): {str(train_dataset.durations)}.")
- _logger.info(f"#durations (val): {str(val_dataset.durations)}.")
- _logger.info(f"#durations (subtrain): {str(subtrain_dataset.durations)}.")
- """
_logger.info(f"#duration (train): {str(train_dataset.duration)}.")
_logger.info(f"#duration (val): {str(val_dataset.duration)}.")
@@ -648,8 +551,52 @@ def create_train_val_dataloader():
return train_dl, subtrain_dl, val_dl
+# parse dataset into better to sample metadata
+def create_dataset_metadata():
+ cfg.dataset.validate = False
+ cfg.dataset.use_hdf5 = False
+
+ paths_by_spkr_name = {}
+
+ paths_by_spkr_name |= _load_paths(cfg.dataset.training, "training")
+ paths_by_spkr_name |= _load_paths(cfg.dataset.validation, "validation")
+ paths_by_spkr_name |= _load_paths(cfg.dataset.noise, "noise")
+
+ paths = list(itertools.chain.from_iterable(paths_by_spkr_name.values()))
+
+ metadata = {}
+ for path in tqdm(paths, desc="Parsing paths"):
+ speaker = cfg.get_spkr(path)
+ if speaker not in metadata:
+ metadata[speaker] = {}
+
+ if cfg.dataset.use_hdf5:
+ phones = cfg.hdf5[_get_hdf5_path(path)].attrs['phonemes']
+ duration = cfg.hdf5[_get_hdf5_path(path)].attrs['duration']
+ else:
+ phns_path = _get_phone_path(path)
+ qnts_path = _get_quant_path(path)
+
+ phones = len(_get_phones(phns_path)) if phns_path.exists() else 0
+ duration = _load_quants(qnts_path).shape[0] / 75 if qnts_path.exists() else 0
+
+
+ metadata[speaker][path.name.split(".")[0]] = {
+ "phones": phones,
+ "duration": duration
+ }
+
+ for speaker, paths in tqdm(paths_by_spkr_name.items(), desc="Writing metadata"):
+ if len(paths) == 0:
+ continue
+ with open(paths[0].parent / "metadata.json", "w", encoding="utf-8") as f:
+ f.write( json.dumps( metadata[speaker] ) )
+
+ with open(cfg.relpath / "metadata.json", "w", encoding="utf-8") as f:
+ f.write( json.dumps( metadata ) )
+
# parse yaml to create an hdf5 file
-def create_dataset_hdf5():
+def create_dataset_hdf5( skip_existing=True ):
cfg.dataset.use_hdf5 = True
cfg.load_hdf5(write=True)
@@ -658,11 +605,11 @@ def create_dataset_hdf5():
root = cfg.cfg_path
hf = cfg.hdf5
- def add( dir, type="training", audios=True, texts=True ):
- dir = "./" + str(dir)
- name = dir.replace(root, "")
- print( str(dir), name )
+ def add( dir, type="training", audios=True, texts=True ):
+ name = "./" + str(dir)
+ name = name .replace(root, "")
+ metadata = {}
if not os.path.isdir(f'{root}/{name}/'):
return
@@ -680,35 +627,53 @@ def create_dataset_hdf5():
key = f'{type}/{name}/{id}'
if key in hf:
- # print("Skipping existing entry:", key)
- continue
+ if skip_existing:
+ continue
+ del hf[key]
group = hf.create_group(key)
+ group.attrs['id'] = id
+ group.attrs['type'] = type
+ group.attrs['speaker'] = name
+
+ metadata[id] = {}
# audio
if audios:
qnt = torch.load(f'{root}/{name}/{id}.qnt.pt')[0].t()
+
+ if "audio" in group:
+ del group["audio"]
group.create_dataset('audio', data=qnt.numpy(), compression='lzf')
+ group.attrs['duration'] = qnt.shape[0] / 75
+ metadata[id]["duration"] = qnt.shape[0] / 75
+ else:
+ group.attrs['duration'] = 0
+ metadata[id]["duration"] = 0
# text
if texts:
- with open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf8") as f:
- content = f.read()
- split = content.split(" ")
- phones = [f""] + [ " " if not p else p for p in split ] + [f""]
+ with open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf-8") as f:
+ content = f.read().split(" ")
+ phones = [f""] + [ " " if not p else p for p in content ] + [f""]
for s in set(phones):
if s not in symmap:
symmap[s] = len(symmap.keys())
+
phn = [ symmap[s] for s in phones ]
+ if "text" in group:
+ del group["text"]
group.create_dataset('text', data=phn, compression='lzf', chunks=True)
-
- # metadata
- group.attrs['id'] = id
- group.attrs['type'] = type
- group.attrs['speaker'] = name
- group.attrs['duration'] = qnt.shape[0] / 75
group.attrs['phonemes'] = len(phn)
+ metadata[id]["phones"] = len(phn)
+ else:
+ group.attrs['phonemes'] = 0
+ metadata[id]["phones"] = 0
+
+ with open(dir / "metadata.json", "w", encoding="utf-8") as f:
+ f.write( json.dumps( metadata ) )
+
# training
for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
@@ -723,10 +688,9 @@ def create_dataset_hdf5():
add( data_dir, type="noise", texts=False )
# write symmap
- try:
- hf.create_dataset('symmap', data=json.dumps(symmap))
- except Exception as e:
- pass
+ if "symmap" in hf:
+ del hf['symmap']
+ hf.create_dataset('symmap', data=json.dumps(symmap))
hf.close()
@@ -742,8 +706,16 @@ if __name__ == "__main__":
cfg.dataset.workers = 1
+ class LoggerOveride:
+ def info(self, *args):
+ print(*args)
+
+ _logger = LoggerOveride()
+
if args.action == "hdf5":
create_dataset_hdf5()
+ elif args.action == "metadata":
+ create_dataset_metadata()
elif args.action == "sample":
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
@@ -758,6 +730,7 @@ if __name__ == "__main__":
del v[i]['proms']
del v[i]['resps']
print(f'{k}:', v)
+
elif args.action == "tasks":
index = 0
cfg.dataset.tasks_list = args.tasks.split(",")
diff --git a/vall_e/emb/g2p.py b/vall_e/emb/g2p.py
index a5ea3b6..3c64536 100755
--- a/vall_e/emb/g2p.py
+++ b/vall_e/emb/g2p.py
@@ -33,10 +33,13 @@ def _get_backend( language="en-us", backend="espeak" ):
return phonemizer
-def encode(text: str, language="en-us", backend="espeak") -> list[str]:
+def encode(text: str, language="en-us", backend="auto") -> list[str]:
if language == "en":
language = "en-us"
+ if not backend or backend == "auto":
+ backend = "espeak" # if language[:2] != "en" else "festival"
+
text = [ text ]
backend = _get_backend(language=language, backend=backend)
@@ -63,16 +66,13 @@ def main():
args = parser.parse_args()
paths = list(args.folder.rglob(f"*{args.suffix}"))
- random.shuffle(paths)
for path in tqdm(paths):
phone_path = path.with_name(path.stem.split(".")[0] + ".phn.txt")
if phone_path.exists():
continue
- graphs = _get_graphs(path)
- phones = encode(graphs)
- with open(phone_path, "w") as f:
- f.write(" ".join(phones))
+ phones = encode(open(path, "r", encoding="utf-8").read())
+ open(phone_path, "w", encoding="utf-8").write(" ".join(phones))
if __name__ == "__main__":
diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py
index 25bf618..03f9983 100755
--- a/vall_e/emb/qnt.py
+++ b/vall_e/emb/qnt.py
@@ -213,7 +213,7 @@ def repeat_extend_audio( qnt, target ):
pieces.append(qnt)
length += qnt.shape[0]
- return trim_random(torch.cat(pieces), target)
+ return trim(torch.cat(pieces), target)
# merges two quantized audios together
# I don't know if this works
diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py
index 46975a1..0d370db 100755
--- a/vall_e/models/ar.py
+++ b/vall_e/models/ar.py
@@ -2,6 +2,7 @@ from ..config import cfg
from .base import Base, list_to_tensor, Categorical
import torch
+from torch.nn.utils.rnn import pad_sequence
from einops import rearrange
from torch import Tensor
@@ -62,7 +63,7 @@ class AR(Base):
max_steps: int = 1000,
sampling_temperature: float = 1.0,
- naive: bool = True,
+ naive: bool = False,
):
if resps_list is not None:
resps_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels
@@ -83,30 +84,104 @@ class AR(Base):
]
stopped = torch.zeros(len(text_list), device=device).bool()
- chunk_size = 1 # don't really know what to do about this desu
+ chunk_size = self.causal_chunk_size # don't really know what to do about this desu
state = None
start = 0
- for n in trange(max_steps // chunk_size):
- # get next in sequence
+ if naive:
+ for n in trange(max_steps // max(1, chunk_size)):
+ # get next in sequence
- r, state = super().forward(
- text_list,
- proms_list,
- self._unsqueeze_list(resps_list),
- sampling_temperature=sampling_temperature,
- state=state if not naive else None,
+ r, state = super().forward(
+ text_list,
+ proms_list,
+ self._unsqueeze_list(resps_list),
+ sampling_temperature=sampling_temperature,
+ state=state # if not naive else None,
+ )
+
+ # append outputted token
+ if self.causal_chunk_size > 0:
+ for i, ri in enumerate(r):
+ resps_list[i] = torch.cat([resps_list[i], ri])
+ else:
+ for i, ri in enumerate(r):
+ resps_list[i] = torch.cat([resps_list[i], ri[None]])
+
+
+ # stop token found
+ stopped |= r == self.stop_token
+ if stopped.all().item():
+ break
+ # to-do: make it work
+ # it seems anything that isn't a one-at-a-time sequence does not work, despite generating STOP tokens.
+ else:
+ resps_list: list[Tensor] = [
+ torch.zeros(0, device=device).to(torch.int16) for _ in text_list
+ ]
+
+ test_list: list[Tensor] = [
+ torch.zeros(0, device=device).to(torch.int16) for _ in text_list
+ ]
+
+ batch_size = len(text_list)
+
+ x_list = self._samplewise_merge_tensors(
+ self.text_emb(text_list),
+ self.proms_emb(proms_list),
+ self.resps_emb(self._unsqueeze_list(resps_list)),
+ sep=self.sep,
)
+ x, m = list_to_tensor(x_list)
+ device = x.device
+
+ if state is None:
+ state = {}
+
+ # pre-fill KV cache
+ for n in trange(x.shape[1]):
+ xs = x[:, n:(n + 1), :]
+ r, _ = self.retnet(xs, incremental_state=state, token_embeddings=xs, features_only=True)
+ r = self.classifier(r) * m
+
+ logits = torch.stack([hi[-1] for hi in r])
+ r = Categorical(logits=logits / sampling_temperature).sample()
+
+ for i, ri in enumerate(r):
+ test_list[i] = torch.cat([test_list[i], ri[None]])
+
# append outputted token
for i, ri in enumerate(r):
resps_list[i] = torch.cat([resps_list[i], ri[None]])
- # stop token found
- stopped |= r == self.stop_token
- if stopped.all().item():
- break
+ start = x.shape[1]
+ for n in trange(max_steps // max(1, chunk_size)):
+ x_list = self._samplewise_merge_tensors(
+ self.text_emb(text_list),
+ self.proms_emb(proms_list),
+ self.resps_emb(self._unsqueeze_list(resps_list)),
+ sep=self.sep,
+ )
+
+ x, m = list_to_tensor(x_list)
+
+ xs = x[:, start+n:start+(n+1), :]
+ r, _ = self.retnet(xs, incremental_state=state, token_embeddings=xs, features_only=True)
+ r = self.classifier(r) * m
+
+ logits = torch.stack([hi[-1] for hi in r])
+ r = Categorical(logits=logits / sampling_temperature).sample()
+
+ # append outputted token
+ for i, ri in enumerate(r):
+ resps_list[i] = torch.cat([resps_list[i], ri[None]])
+
+ # stop token found
+ stopped |= r == self.stop_token
+ if stopped.all().item():
+ break
pruned = [self._prune(r) for r in resps_list]
return pruned
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 03c6c18..f149065 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -149,6 +149,8 @@ class Base(nn.Module):
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
self.sep = nn.Parameter(torch.randn(d_model))
+
+ self.causal_chunk_size = 0 # 64 if self.causal else 1
if self.arch_type == "transformer":
self.sin_emb = SinusoidalEmbedding(d_model)
@@ -171,8 +173,8 @@ class Base(nn.Module):
dropout=p_dropout,
checkpoint_activations=True,
- chunkwise_recurrent=False, # self.causal,
- recurrent_chunkwise_size=64,
+ chunkwise_recurrent=self.causal and self.causal_chunk_size > 0,
+ recurrent_chunkwise_size=self.causal_chunk_size,
no_output_layer=True,
decoder_normalize_before=True,
))
@@ -358,6 +360,10 @@ class Base(nn.Module):
elif return_all_resp:
logits = [hi[-li:] for hi, li in zip(h_list, map(len, resps_list))]
ret = [ Categorical(logits=hi / sampling_temperature).sample() for hi in logits ]
+ # return the last chunkwise piece
+ elif self.causal_chunk_size > 0:
+ logits = [hi[-self.causal_chunk_size:] for hi, li in zip(h_list, map(len, resps_list))]
+ ret = [ Categorical(logits=hi / sampling_temperature).sample() for hi in logits ]
# return just the last code
else:
logits = torch.stack([hi[-1] for hi in h_list])