overhauled dataloading code to be marginally faster, mostly cleaned up, and can leverage a metadata json to help things out
This commit is contained in:
parent
7b3be3d7bf
commit
78378ed1ce
|
@ -27,6 +27,10 @@ class _Config:
|
|||
def relpath(self):
|
||||
return Path(self.cfg_path)
|
||||
|
||||
@property
|
||||
def cache_dir(self):
|
||||
return self.relpath / ".cache"
|
||||
|
||||
@property
|
||||
def ckpt_dir(self):
|
||||
return self.relpath / "ckpt"
|
||||
|
@ -119,6 +123,7 @@ class Dataset:
|
|||
|
||||
hdf5_name: str = "data.h5"
|
||||
use_hdf5: bool = False
|
||||
use_metadata: bool = False
|
||||
hdf5_flag: str = "a"
|
||||
validate: bool = True
|
||||
workers: int = 8
|
||||
|
@ -135,6 +140,19 @@ class Dataset:
|
|||
|
||||
tasks_list: list[str] = field(default_factory=lambda: ["tts"])
|
||||
|
||||
@property
|
||||
def min_phones(self):
|
||||
return self.phones_range[0]
|
||||
@property
|
||||
def max_phones(self):
|
||||
return self.phones_range[1]
|
||||
@property
|
||||
def min_duration(self):
|
||||
return self.duration_range[0]
|
||||
@property
|
||||
def max_duration(self):
|
||||
return self.duration_range[1]
|
||||
|
||||
@dataclass()
|
||||
class Model:
|
||||
name: str = ""
|
||||
|
@ -393,7 +411,7 @@ class Trainer:
|
|||
|
||||
weight_dtype: str = "float16"
|
||||
|
||||
backend: str = "deepspeed" if not sys.platform.startswith("win") else "local"
|
||||
backend: str = "local" # "deepspeed" if not sys.platform.startswith("win") else "local"
|
||||
|
||||
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
|
||||
|
||||
|
@ -453,10 +471,6 @@ class Config(_Config):
|
|||
def get_spkr(self):
|
||||
return eval(self.dataset.speaker_name_getter)
|
||||
|
||||
@property
|
||||
def cache_dir(self):
|
||||
return ".cache" / self.relpath
|
||||
|
||||
@cached_property
|
||||
def diskcache(self):
|
||||
if self.cfg_path is not None and self.dataset.cache:
|
||||
|
@ -501,11 +515,10 @@ try:
|
|||
if cfg.dataset.use_hdf5:
|
||||
cfg.load_hdf5()
|
||||
|
||||
if not cfg.dataset.use_hdf5:
|
||||
cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
|
||||
cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
|
||||
|
||||
cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
|
||||
cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
|
||||
cfg.dataset.noise = [ Path(dir) for dir in cfg.dataset.noise ]
|
||||
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
|
417
vall_e/data.py
417
vall_e/data.py
|
@ -8,6 +8,7 @@ import numpy as np
|
|||
import os
|
||||
import random
|
||||
import torch
|
||||
import itertools
|
||||
|
||||
from .config import cfg
|
||||
from .emb.qnt import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file
|
||||
|
@ -31,7 +32,7 @@ def get_phone_symmap():
|
|||
if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5:
|
||||
return json.loads( cfg.hdf5['symmap'].asstr()[()] )
|
||||
|
||||
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
|
||||
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178, '”': 179, '“': 180, '“ˈ': 181, '“ˌ': 182, ';ˈ': 183, ';ˌ': 184, ':ˈ': 185}
|
||||
return symmap
|
||||
|
||||
def get_task_symmap():
|
||||
|
@ -51,24 +52,89 @@ def get_task_symmap():
|
|||
def _replace_file_extension(path, suffix):
|
||||
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
|
||||
|
||||
def _get_hdf5_path(path):
|
||||
path = str(path)
|
||||
if path[:2] != "./":
|
||||
path = f'./{path}'
|
||||
return path.replace(cfg.cfg_path, "")
|
||||
|
||||
def _get_quant_path(path):
|
||||
return _replace_file_extension(path, ".qnt.pt")
|
||||
|
||||
def _get_phone_path(path):
|
||||
return _replace_file_extension(path, ".phn.txt")
|
||||
|
||||
def _load_paths(dataset, type="training"):
|
||||
return { cfg.get_spkr( data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
|
||||
|
||||
"""
|
||||
def _load_paths_from_hdf5(dataset, type="training"):
|
||||
return { cfg.get_spkr( data_dir / "dummy" ): _get_hdf5_paths( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
|
||||
|
||||
def _load_paths_from_disk(dataset, type="training"):
|
||||
return { cfg.get_spkr( data_dir / "dummy" ): _get_paths_of_extensions( data_dir, ".qnt.pt", validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
|
||||
"""
|
||||
|
||||
def _load_paths_from_metadata(data_dir, type="training", validate=False):
|
||||
_fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions
|
||||
|
||||
def _validate( entry ):
|
||||
phones = entry['phones']
|
||||
duration = entry['duration']
|
||||
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
|
||||
|
||||
metadata_path = data_dir / "metadata.json"
|
||||
if not cfg.dataset.use_metadata or not metadata_path.exists():
|
||||
return _fn( data_dir, type if cfg.dataset.use_hdf5 else ".qnt.pt", validate )
|
||||
|
||||
speaker = cfg.get_spkr( data_dir / "dummy" )
|
||||
metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
|
||||
|
||||
def key( dir, id ):
|
||||
if not cfg.dataset.use_hdf5:
|
||||
return data_dir / id
|
||||
|
||||
return f"/{type}{_get_hdf5_path(data_dir)}/{id}"
|
||||
|
||||
return [ key(dir, id) for id in metadata.keys() if not validate or _validate(metadata[id]) ]
|
||||
|
||||
|
||||
def _get_hdf5_path(path):
|
||||
path = str(path)
|
||||
if path[:2] != "./":
|
||||
path = f'./{path}'
|
||||
return path.replace(cfg.cfg_path, "")
|
||||
|
||||
def _get_hdf5_paths( data_dir, type="training", validate=False ):
|
||||
data_dir = str(data_dir)
|
||||
|
||||
def _validate(child):
|
||||
phones = child.attrs['phonemes']
|
||||
duration = child.attrs['duration']
|
||||
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
|
||||
|
||||
key = f"/{type}{_get_hdf5_path(data_dir)}"
|
||||
return [ Path(f"{key}/{child.attrs['id']}") for child in cfg.hdf5[key].values() if not validate or _validate(child) ] if key in cfg.hdf5 else []
|
||||
|
||||
def _get_paths_of_extensions( path, extensions=".qnt.pt", validate=False ):
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
|
||||
def _validate(path):
|
||||
if "".join(path.suffixes) not in extensions:
|
||||
return False
|
||||
if not _get_phone_path(path).exists() or not _get_quant_path(path).exists():
|
||||
return False
|
||||
if not validate:
|
||||
return True
|
||||
# to-do: find an easy way to determine size from pickled quants without loading
|
||||
# to-do: find a consistent way to derive phoneme count from filesize (probably can't due to utf-8)
|
||||
phones = len(_get_phones(_get_phone_path(path))) # _get_phone_path(path).stat().st_size // 2 + 1
|
||||
return cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
|
||||
|
||||
|
||||
return [ p for p in list(path.iterdir()) if _validate(p) ] if path.exists() and path.is_dir() else []
|
||||
|
||||
def _load_quants(path) -> Tensor:
|
||||
return torch.load(path)[0][:, :].t().to(torch.int16)
|
||||
return torch.load(_get_quant_path(path))[0][:, :].t().to(torch.int16)
|
||||
|
||||
@cache
|
||||
def _get_phones(path, language="en"):
|
||||
content = open(_get_phone_path(path), "r", encoding="utf8").read().split(" ")
|
||||
content = open(_get_phone_path(path), "r", encoding="utf-8").read().split(" ")
|
||||
return ["<s>"] + [ " " if not p else p for p in content ] + ["</s>"]
|
||||
|
||||
def _interleaved_reorder(l, fn):
|
||||
|
@ -81,114 +147,61 @@ def _interleaved_reorder(l, fn):
|
|||
if value is not None:
|
||||
yield value
|
||||
|
||||
|
||||
@cache
|
||||
def _validate(path, min_phones, max_phones, min_duration, max_duration):
|
||||
if cfg.dataset.use_hdf5:
|
||||
key = _get_hdf5_path(path)
|
||||
if key not in cfg.hdf5:
|
||||
return False
|
||||
|
||||
phones = cfg.hdf5[key].attrs['phonemes']
|
||||
duration = cfg.hdf5[key].attrs['duration']
|
||||
|
||||
if phones < min_phones or phones > max_phones:
|
||||
return False
|
||||
|
||||
if duration < min_duration or duration > max_duration:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
if not os.path.exists(_get_phone_path(path)) or not os.path.exists(_get_quant_path(path)):
|
||||
return False
|
||||
|
||||
phones = _get_phones(path)
|
||||
unique_phones = list(set(phones))
|
||||
|
||||
if len(unique_phones) == 0:
|
||||
return False
|
||||
if len(unique_phones) == 1 and unique_phones[0] == " ":
|
||||
return False
|
||||
if len(phones) < min_phones or len(phones) > max_phones:
|
||||
return False
|
||||
return True
|
||||
|
||||
class Dataset(_Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
paths,
|
||||
phone_symmap=None,
|
||||
training=False,
|
||||
extra_paths_by_spkr_name: dict[str, list] = {},
|
||||
):
|
||||
super().__init__()
|
||||
self._head = None
|
||||
self.min_phones = cfg.dataset.phones_range[0]
|
||||
self.max_phones = cfg.dataset.phones_range[1]
|
||||
self.min_duration = cfg.dataset.duration_range[0]
|
||||
self.max_duration = cfg.dataset.duration_range[1]
|
||||
self.sampler = None
|
||||
|
||||
if cfg.dataset.validate:
|
||||
self.paths = [
|
||||
path for path in paths if _validate(path, self.min_phones, self.max_phones, self.min_duration, self.max_duration)
|
||||
]
|
||||
else:
|
||||
self.paths = paths
|
||||
self.paths = []
|
||||
|
||||
self.training = training
|
||||
self.dataset_type = "training" if self.training else "validation"
|
||||
self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
|
||||
|
||||
self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type)
|
||||
self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
|
||||
|
||||
if cfg.dataset.sample_type == "path":
|
||||
self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
|
||||
|
||||
self.noise_paths = _load_paths(cfg.dataset.noise, "noise")
|
||||
self.noise_paths = list(itertools.chain.from_iterable(self.noise_paths.values()))
|
||||
|
||||
self.phone_symmap = phone_symmap or self._get_phone_symmap()
|
||||
self.spkr_symmap = self._get_spkr_symmap()
|
||||
self.task_symmap = self._get_task_symmap()
|
||||
self.training = training
|
||||
|
||||
# assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
|
||||
self.text_dtype = torch.uint8 if len(self.phone_symmap) < 256 else torch.int16
|
||||
|
||||
self.paths_by_spkr_name = self._get_paths_by_spkr_name(extra_paths_by_spkr_name)
|
||||
|
||||
if cfg.dataset.validate:
|
||||
self.paths = [
|
||||
p for p in self.paths if len(self.paths_by_spkr_name[cfg.get_spkr(p)]) > 1
|
||||
]
|
||||
|
||||
if cfg.dataset.sample_type == "path":
|
||||
self.paths = [*_interleaved_reorder(self.paths, cfg.get_spkr)]
|
||||
|
||||
if len(self.paths) == 0 and training:
|
||||
raise ValueError("No valid path is found for training.")
|
||||
|
||||
# would be a better cost saving if we could fetch the duration during the validation pass but oh well
|
||||
self.duration = 0
|
||||
self.durations = {}
|
||||
if cfg.dataset.use_hdf5:
|
||||
for path in self.paths:
|
||||
key = _get_hdf5_path(path)
|
||||
spkr_name = cfg.get_spkr(path)
|
||||
spkr_id = self.spkr_symmap[spkr_name]
|
||||
duration = cfg.hdf5[key].attrs['duration']
|
||||
|
||||
self.duration += duration
|
||||
|
||||
if spkr_id not in self.durations:
|
||||
self.durations[spkr_id] = duration
|
||||
else:
|
||||
self.durations[spkr_id] += duration
|
||||
|
||||
def _get_paths_by_spkr_name(self, extra_paths_by_spkr_name: dict[str, list]):
|
||||
ret = defaultdict(list)
|
||||
for path in self.paths:
|
||||
ret[cfg.get_spkr(path)].append(path)
|
||||
for k, v in extra_paths_by_spkr_name.items():
|
||||
ret[k].extend(v)
|
||||
return {**ret}
|
||||
self.duration += cfg.hdf5[_get_hdf5_path(path)].attrs['duration']
|
||||
|
||||
@cached_property
|
||||
def phones(self):
|
||||
return sorted(set().union(*[_get_phones(path) for path in self.paths]))
|
||||
|
||||
def get_speaker(self, path):
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
res = cfg.get_spkr(path)
|
||||
return res
|
||||
|
||||
@cached_property
|
||||
def spkrs(self):
|
||||
return sorted({cfg.get_spkr(path) for path in self.paths})
|
||||
return sorted({self.get_speaker(path) for path in self.paths})
|
||||
|
||||
@cached_property
|
||||
def tasks(self):
|
||||
|
@ -209,13 +222,10 @@ class Dataset(_Dataset):
|
|||
return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(levels) ]]).to(dtype=torch.int16)
|
||||
|
||||
def sample_noise(self):
|
||||
paths = []
|
||||
for data_dir in cfg.dataset.noise:
|
||||
paths.extend(data_dir.rglob("*.qnt.pt"))
|
||||
path = random.choice(paths)
|
||||
path = random.choice(self.noise_paths)
|
||||
|
||||
if False and cfg.dataset.use_hdf5:
|
||||
key = f'/noise/{_get_hdf5_path(path)}'
|
||||
if cfg.dataset.use_hdf5:
|
||||
key = _get_hdf5_path(path)
|
||||
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
|
||||
else:
|
||||
qnt = _load_quants(path)
|
||||
|
@ -275,7 +285,7 @@ class Dataset(_Dataset):
|
|||
path = random.choice([*set(self.paths_by_spkr_name[spkr_name])])
|
||||
else:
|
||||
path = self.paths[index]
|
||||
spkr_name = cfg.get_spkr(path)
|
||||
spkr_name = self.get_speaker(path)
|
||||
spkr_id = self.spkr_symmap[spkr_name]
|
||||
|
||||
if cfg.dataset.use_hdf5:
|
||||
|
@ -507,110 +517,10 @@ def _create_dataloader(dataset, training):
|
|||
sampler=sampler,
|
||||
)
|
||||
|
||||
def _load_dataset_paths():
|
||||
hf = cfg.hdf5
|
||||
paths = {
|
||||
"training": [],
|
||||
"validation": [],
|
||||
}
|
||||
|
||||
datasets = {
|
||||
"training": [],
|
||||
"validation": [],
|
||||
}
|
||||
|
||||
def get_paths( data_dir, type="training" ):
|
||||
key = f"/{type}{_get_hdf5_path(data_dir)}"
|
||||
if key not in cfg.hdf5:
|
||||
return
|
||||
|
||||
paths[type].extend([ f"{key}/{child.attrs['id']}" for child in cfg.hdf5[key].values() ])
|
||||
|
||||
for data_dir in cfg.dataset.training:
|
||||
get_paths( data_dir, "training" )
|
||||
|
||||
for data_dir in cfg.dataset.validation:
|
||||
get_paths( data_dir, "validation" )
|
||||
|
||||
for _, type in enumerate(paths):
|
||||
dirs = paths[type]
|
||||
|
||||
if len(dirs) == 0:
|
||||
continue
|
||||
|
||||
dirs = [ Path(p) for p in dirs ]
|
||||
|
||||
pairs = sorted([(cfg.get_spkr(p), p) for p in dirs])
|
||||
for _, group in groupby(pairs, lambda pair: pair[0]):
|
||||
shuffled = sorted([p for _, p in group])
|
||||
random.seed(0)
|
||||
random.shuffle(shuffled)
|
||||
|
||||
datasets[type].extend(shuffled)
|
||||
|
||||
return datasets["training"], datasets["validation"]
|
||||
|
||||
# to-do: mirror the hdf5-based load function
|
||||
def _load_train_val_paths():
|
||||
paths = []
|
||||
train_paths = []
|
||||
val_paths = []
|
||||
|
||||
for data_dir in cfg.dataset.training:
|
||||
paths.extend(data_dir.rglob("*.qnt.pt"))
|
||||
|
||||
if len(paths) > 0:
|
||||
pairs = sorted([(cfg.get_spkr(p), p) for p in paths])
|
||||
del paths
|
||||
|
||||
for _, group in groupby(pairs, lambda pair: pair[0]):
|
||||
paths = sorted([p for _, p in group])
|
||||
random.seed(0)
|
||||
random.shuffle(paths)
|
||||
train_paths.extend(paths)
|
||||
|
||||
for data_dir in cfg.dataset.validation:
|
||||
paths.extend(data_dir.rglob("*.qnt.pt"))
|
||||
|
||||
if len(paths) > 0:
|
||||
pairs = sorted([(cfg.get_spkr(p), p) for p in paths])
|
||||
del paths
|
||||
|
||||
for _, group in groupby(pairs, lambda pair: pair[0]):
|
||||
paths = sorted([p for _, p in group])
|
||||
random.seed(0)
|
||||
random.shuffle(paths)
|
||||
val_paths.extend(paths)
|
||||
|
||||
train_paths, val_paths = map(sorted, [train_paths, val_paths])
|
||||
|
||||
if len(train_paths) == 0:
|
||||
raise RuntimeError(f"Failed to find any .qnt.pt file in specified training dataset.")
|
||||
|
||||
# to-do: allow setting aside a fixed portion of the training dataset for validation
|
||||
# something like the last X percent of each speaker is set aside
|
||||
if len(val_paths) == 0:
|
||||
val_paths = [ train_paths[0] ]
|
||||
|
||||
return train_paths, val_paths
|
||||
|
||||
@cfg.diskcache()
|
||||
def create_datasets():
|
||||
train_paths, val_paths = _load_dataset_paths() if cfg.dataset.use_hdf5 else _load_train_val_paths()
|
||||
|
||||
train_dataset = Dataset(
|
||||
train_paths,
|
||||
training=True,
|
||||
)
|
||||
|
||||
val_dataset = Dataset(
|
||||
val_paths,
|
||||
train_dataset.phone_symmap,
|
||||
#train_dataset.spkr_symmap,
|
||||
#extra_paths_by_spkr_name=train_dataset.paths_by_spkr_name,
|
||||
)
|
||||
|
||||
val_dataset.head_(cfg.evaluation.size)
|
||||
train_dataset = Dataset( training=True )
|
||||
val_dataset = Dataset( phone_symmap=train_dataset.phone_symmap, training=False )
|
||||
|
||||
return train_dataset, val_dataset
|
||||
|
||||
|
@ -628,17 +538,10 @@ def create_train_val_dataloader():
|
|||
|
||||
_logger.info(str(train_dataset.phone_symmap))
|
||||
_logger.info(str(train_dataset.spkr_symmap))
|
||||
|
||||
|
||||
_logger.info(f"#samples (train): {len(train_dataset)}.")
|
||||
_logger.info(f"#samples (val): {len(val_dataset)}.")
|
||||
_logger.info(f"#samples (subtrain): {len(subtrain_dataset)}.")
|
||||
|
||||
"""
|
||||
_logger.info(f"#durations (train): {str(train_dataset.durations)}.")
|
||||
_logger.info(f"#durations (val): {str(val_dataset.durations)}.")
|
||||
_logger.info(f"#durations (subtrain): {str(subtrain_dataset.durations)}.")
|
||||
"""
|
||||
|
||||
_logger.info(f"#duration (train): {str(train_dataset.duration)}.")
|
||||
_logger.info(f"#duration (val): {str(val_dataset.duration)}.")
|
||||
|
@ -648,8 +551,52 @@ def create_train_val_dataloader():
|
|||
|
||||
return train_dl, subtrain_dl, val_dl
|
||||
|
||||
# parse dataset into better to sample metadata
|
||||
def create_dataset_metadata():
|
||||
cfg.dataset.validate = False
|
||||
cfg.dataset.use_hdf5 = False
|
||||
|
||||
paths_by_spkr_name = {}
|
||||
|
||||
paths_by_spkr_name |= _load_paths(cfg.dataset.training, "training")
|
||||
paths_by_spkr_name |= _load_paths(cfg.dataset.validation, "validation")
|
||||
paths_by_spkr_name |= _load_paths(cfg.dataset.noise, "noise")
|
||||
|
||||
paths = list(itertools.chain.from_iterable(paths_by_spkr_name.values()))
|
||||
|
||||
metadata = {}
|
||||
for path in tqdm(paths, desc="Parsing paths"):
|
||||
speaker = cfg.get_spkr(path)
|
||||
if speaker not in metadata:
|
||||
metadata[speaker] = {}
|
||||
|
||||
if cfg.dataset.use_hdf5:
|
||||
phones = cfg.hdf5[_get_hdf5_path(path)].attrs['phonemes']
|
||||
duration = cfg.hdf5[_get_hdf5_path(path)].attrs['duration']
|
||||
else:
|
||||
phns_path = _get_phone_path(path)
|
||||
qnts_path = _get_quant_path(path)
|
||||
|
||||
phones = len(_get_phones(phns_path)) if phns_path.exists() else 0
|
||||
duration = _load_quants(qnts_path).shape[0] / 75 if qnts_path.exists() else 0
|
||||
|
||||
|
||||
metadata[speaker][path.name.split(".")[0]] = {
|
||||
"phones": phones,
|
||||
"duration": duration
|
||||
}
|
||||
|
||||
for speaker, paths in tqdm(paths_by_spkr_name.items(), desc="Writing metadata"):
|
||||
if len(paths) == 0:
|
||||
continue
|
||||
with open(paths[0].parent / "metadata.json", "w", encoding="utf-8") as f:
|
||||
f.write( json.dumps( metadata[speaker] ) )
|
||||
|
||||
with open(cfg.relpath / "metadata.json", "w", encoding="utf-8") as f:
|
||||
f.write( json.dumps( metadata ) )
|
||||
|
||||
# parse yaml to create an hdf5 file
|
||||
def create_dataset_hdf5():
|
||||
def create_dataset_hdf5( skip_existing=True ):
|
||||
cfg.dataset.use_hdf5 = True
|
||||
cfg.load_hdf5(write=True)
|
||||
|
||||
|
@ -658,11 +605,11 @@ def create_dataset_hdf5():
|
|||
root = cfg.cfg_path
|
||||
hf = cfg.hdf5
|
||||
|
||||
def add( dir, type="training", audios=True, texts=True ):
|
||||
dir = "./" + str(dir)
|
||||
name = dir.replace(root, "")
|
||||
|
||||
print( str(dir), name )
|
||||
def add( dir, type="training", audios=True, texts=True ):
|
||||
name = "./" + str(dir)
|
||||
name = name .replace(root, "")
|
||||
metadata = {}
|
||||
|
||||
if not os.path.isdir(f'{root}/{name}/'):
|
||||
return
|
||||
|
@ -680,35 +627,53 @@ def create_dataset_hdf5():
|
|||
|
||||
key = f'{type}/{name}/{id}'
|
||||
if key in hf:
|
||||
# print("Skipping existing entry:", key)
|
||||
continue
|
||||
if skip_existing:
|
||||
continue
|
||||
del hf[key]
|
||||
|
||||
group = hf.create_group(key)
|
||||
group.attrs['id'] = id
|
||||
group.attrs['type'] = type
|
||||
group.attrs['speaker'] = name
|
||||
|
||||
metadata[id] = {}
|
||||
|
||||
# audio
|
||||
if audios:
|
||||
qnt = torch.load(f'{root}/{name}/{id}.qnt.pt')[0].t()
|
||||
|
||||
if "audio" in group:
|
||||
del group["audio"]
|
||||
group.create_dataset('audio', data=qnt.numpy(), compression='lzf')
|
||||
group.attrs['duration'] = qnt.shape[0] / 75
|
||||
metadata[id]["duration"] = qnt.shape[0] / 75
|
||||
else:
|
||||
group.attrs['duration'] = 0
|
||||
metadata[id]["duration"] = 0
|
||||
|
||||
# text
|
||||
if texts:
|
||||
with open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf8") as f:
|
||||
content = f.read()
|
||||
split = content.split(" ")
|
||||
phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
|
||||
with open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf-8") as f:
|
||||
content = f.read().split(" ")
|
||||
phones = [f"<s>"] + [ " " if not p else p for p in content ] + [f"</s>"]
|
||||
for s in set(phones):
|
||||
if s not in symmap:
|
||||
symmap[s] = len(symmap.keys())
|
||||
|
||||
phn = [ symmap[s] for s in phones ]
|
||||
|
||||
if "text" in group:
|
||||
del group["text"]
|
||||
group.create_dataset('text', data=phn, compression='lzf', chunks=True)
|
||||
|
||||
# metadata
|
||||
group.attrs['id'] = id
|
||||
group.attrs['type'] = type
|
||||
group.attrs['speaker'] = name
|
||||
group.attrs['duration'] = qnt.shape[0] / 75
|
||||
group.attrs['phonemes'] = len(phn)
|
||||
metadata[id]["phones"] = len(phn)
|
||||
else:
|
||||
group.attrs['phonemes'] = 0
|
||||
metadata[id]["phones"] = 0
|
||||
|
||||
with open(dir / "metadata.json", "w", encoding="utf-8") as f:
|
||||
f.write( json.dumps( metadata ) )
|
||||
|
||||
|
||||
# training
|
||||
for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
|
||||
|
@ -723,10 +688,9 @@ def create_dataset_hdf5():
|
|||
add( data_dir, type="noise", texts=False )
|
||||
|
||||
# write symmap
|
||||
try:
|
||||
hf.create_dataset('symmap', data=json.dumps(symmap))
|
||||
except Exception as e:
|
||||
pass
|
||||
if "symmap" in hf:
|
||||
del hf['symmap']
|
||||
hf.create_dataset('symmap', data=json.dumps(symmap))
|
||||
|
||||
hf.close()
|
||||
|
||||
|
@ -742,8 +706,16 @@ if __name__ == "__main__":
|
|||
|
||||
cfg.dataset.workers = 1
|
||||
|
||||
class LoggerOveride:
|
||||
def info(self, *args):
|
||||
print(*args)
|
||||
|
||||
_logger = LoggerOveride()
|
||||
|
||||
if args.action == "hdf5":
|
||||
create_dataset_hdf5()
|
||||
elif args.action == "metadata":
|
||||
create_dataset_metadata()
|
||||
elif args.action == "sample":
|
||||
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
||||
|
||||
|
@ -758,6 +730,7 @@ if __name__ == "__main__":
|
|||
del v[i]['proms']
|
||||
del v[i]['resps']
|
||||
print(f'{k}:', v)
|
||||
|
||||
elif args.action == "tasks":
|
||||
index = 0
|
||||
cfg.dataset.tasks_list = args.tasks.split(",")
|
||||
|
|
|
@ -33,10 +33,13 @@ def _get_backend( language="en-us", backend="espeak" ):
|
|||
return phonemizer
|
||||
|
||||
|
||||
def encode(text: str, language="en-us", backend="espeak") -> list[str]:
|
||||
def encode(text: str, language="en-us", backend="auto") -> list[str]:
|
||||
if language == "en":
|
||||
language = "en-us"
|
||||
|
||||
if not backend or backend == "auto":
|
||||
backend = "espeak" # if language[:2] != "en" else "festival"
|
||||
|
||||
text = [ text ]
|
||||
|
||||
backend = _get_backend(language=language, backend=backend)
|
||||
|
@ -63,16 +66,13 @@ def main():
|
|||
args = parser.parse_args()
|
||||
|
||||
paths = list(args.folder.rglob(f"*{args.suffix}"))
|
||||
random.shuffle(paths)
|
||||
|
||||
for path in tqdm(paths):
|
||||
phone_path = path.with_name(path.stem.split(".")[0] + ".phn.txt")
|
||||
if phone_path.exists():
|
||||
continue
|
||||
graphs = _get_graphs(path)
|
||||
phones = encode(graphs)
|
||||
with open(phone_path, "w") as f:
|
||||
f.write(" ".join(phones))
|
||||
phones = encode(open(path, "r", encoding="utf-8").read())
|
||||
open(phone_path, "w", encoding="utf-8").write(" ".join(phones))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -213,7 +213,7 @@ def repeat_extend_audio( qnt, target ):
|
|||
pieces.append(qnt)
|
||||
length += qnt.shape[0]
|
||||
|
||||
return trim_random(torch.cat(pieces), target)
|
||||
return trim(torch.cat(pieces), target)
|
||||
|
||||
# merges two quantized audios together
|
||||
# I don't know if this works
|
||||
|
|
|
@ -2,6 +2,7 @@ from ..config import cfg
|
|||
from .base import Base, list_to_tensor, Categorical
|
||||
|
||||
import torch
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
|
@ -62,7 +63,7 @@ class AR(Base):
|
|||
max_steps: int = 1000,
|
||||
sampling_temperature: float = 1.0,
|
||||
|
||||
naive: bool = True,
|
||||
naive: bool = False,
|
||||
):
|
||||
if resps_list is not None:
|
||||
resps_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels
|
||||
|
@ -83,30 +84,104 @@ class AR(Base):
|
|||
]
|
||||
stopped = torch.zeros(len(text_list), device=device).bool()
|
||||
|
||||
chunk_size = 1 # don't really know what to do about this desu
|
||||
chunk_size = self.causal_chunk_size # don't really know what to do about this desu
|
||||
|
||||
state = None
|
||||
start = 0
|
||||
|
||||
for n in trange(max_steps // chunk_size):
|
||||
# get next in sequence
|
||||
if naive:
|
||||
for n in trange(max_steps // max(1, chunk_size)):
|
||||
# get next in sequence
|
||||
|
||||
r, state = super().forward(
|
||||
text_list,
|
||||
proms_list,
|
||||
self._unsqueeze_list(resps_list),
|
||||
sampling_temperature=sampling_temperature,
|
||||
state=state if not naive else None,
|
||||
r, state = super().forward(
|
||||
text_list,
|
||||
proms_list,
|
||||
self._unsqueeze_list(resps_list),
|
||||
sampling_temperature=sampling_temperature,
|
||||
state=state # if not naive else None,
|
||||
)
|
||||
|
||||
# append outputted token
|
||||
if self.causal_chunk_size > 0:
|
||||
for i, ri in enumerate(r):
|
||||
resps_list[i] = torch.cat([resps_list[i], ri])
|
||||
else:
|
||||
for i, ri in enumerate(r):
|
||||
resps_list[i] = torch.cat([resps_list[i], ri[None]])
|
||||
|
||||
|
||||
# stop token found
|
||||
stopped |= r == self.stop_token
|
||||
if stopped.all().item():
|
||||
break
|
||||
# to-do: make it work
|
||||
# it seems anything that isn't a one-at-a-time sequence does not work, despite generating STOP tokens.
|
||||
else:
|
||||
resps_list: list[Tensor] = [
|
||||
torch.zeros(0, device=device).to(torch.int16) for _ in text_list
|
||||
]
|
||||
|
||||
test_list: list[Tensor] = [
|
||||
torch.zeros(0, device=device).to(torch.int16) for _ in text_list
|
||||
]
|
||||
|
||||
batch_size = len(text_list)
|
||||
|
||||
x_list = self._samplewise_merge_tensors(
|
||||
self.text_emb(text_list),
|
||||
self.proms_emb(proms_list),
|
||||
self.resps_emb(self._unsqueeze_list(resps_list)),
|
||||
sep=self.sep,
|
||||
)
|
||||
|
||||
x, m = list_to_tensor(x_list)
|
||||
device = x.device
|
||||
|
||||
if state is None:
|
||||
state = {}
|
||||
|
||||
# pre-fill KV cache
|
||||
for n in trange(x.shape[1]):
|
||||
xs = x[:, n:(n + 1), :]
|
||||
r, _ = self.retnet(xs, incremental_state=state, token_embeddings=xs, features_only=True)
|
||||
r = self.classifier(r) * m
|
||||
|
||||
logits = torch.stack([hi[-1] for hi in r])
|
||||
r = Categorical(logits=logits / sampling_temperature).sample()
|
||||
|
||||
for i, ri in enumerate(r):
|
||||
test_list[i] = torch.cat([test_list[i], ri[None]])
|
||||
|
||||
# append outputted token
|
||||
for i, ri in enumerate(r):
|
||||
resps_list[i] = torch.cat([resps_list[i], ri[None]])
|
||||
|
||||
# stop token found
|
||||
stopped |= r == self.stop_token
|
||||
if stopped.all().item():
|
||||
break
|
||||
start = x.shape[1]
|
||||
for n in trange(max_steps // max(1, chunk_size)):
|
||||
x_list = self._samplewise_merge_tensors(
|
||||
self.text_emb(text_list),
|
||||
self.proms_emb(proms_list),
|
||||
self.resps_emb(self._unsqueeze_list(resps_list)),
|
||||
sep=self.sep,
|
||||
)
|
||||
|
||||
x, m = list_to_tensor(x_list)
|
||||
|
||||
xs = x[:, start+n:start+(n+1), :]
|
||||
r, _ = self.retnet(xs, incremental_state=state, token_embeddings=xs, features_only=True)
|
||||
r = self.classifier(r) * m
|
||||
|
||||
logits = torch.stack([hi[-1] for hi in r])
|
||||
r = Categorical(logits=logits / sampling_temperature).sample()
|
||||
|
||||
# append outputted token
|
||||
for i, ri in enumerate(r):
|
||||
resps_list[i] = torch.cat([resps_list[i], ri[None]])
|
||||
|
||||
# stop token found
|
||||
stopped |= r == self.stop_token
|
||||
if stopped.all().item():
|
||||
break
|
||||
|
||||
pruned = [self._prune(r) for r in resps_list]
|
||||
return pruned
|
||||
|
|
|
@ -149,6 +149,8 @@ class Base(nn.Module):
|
|||
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
|
||||
|
||||
self.sep = nn.Parameter(torch.randn(d_model))
|
||||
|
||||
self.causal_chunk_size = 0 # 64 if self.causal else 1
|
||||
|
||||
if self.arch_type == "transformer":
|
||||
self.sin_emb = SinusoidalEmbedding(d_model)
|
||||
|
@ -171,8 +173,8 @@ class Base(nn.Module):
|
|||
dropout=p_dropout,
|
||||
checkpoint_activations=True,
|
||||
|
||||
chunkwise_recurrent=False, # self.causal,
|
||||
recurrent_chunkwise_size=64,
|
||||
chunkwise_recurrent=self.causal and self.causal_chunk_size > 0,
|
||||
recurrent_chunkwise_size=self.causal_chunk_size,
|
||||
no_output_layer=True,
|
||||
decoder_normalize_before=True,
|
||||
))
|
||||
|
@ -358,6 +360,10 @@ class Base(nn.Module):
|
|||
elif return_all_resp:
|
||||
logits = [hi[-li:] for hi, li in zip(h_list, map(len, resps_list))]
|
||||
ret = [ Categorical(logits=hi / sampling_temperature).sample() for hi in logits ]
|
||||
# return the last chunkwise piece
|
||||
elif self.causal_chunk_size > 0:
|
||||
logits = [hi[-self.causal_chunk_size:] for hi, li in zip(h_list, map(len, resps_list))]
|
||||
ret = [ Categorical(logits=hi / sampling_temperature).sample() for hi in logits ]
|
||||
# return just the last code
|
||||
else:
|
||||
logits = torch.stack([hi[-1] for hi in h_list])
|
||||
|
|
Loading…
Reference in New Issue
Block a user