overhauled dataloading code to be marginally faster, mostly cleaned up, and can leverage a metadata json to help things out
This commit is contained in:
parent
7b3be3d7bf
commit
78378ed1ce
|
@ -27,6 +27,10 @@ class _Config:
|
||||||
def relpath(self):
|
def relpath(self):
|
||||||
return Path(self.cfg_path)
|
return Path(self.cfg_path)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cache_dir(self):
|
||||||
|
return self.relpath / ".cache"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ckpt_dir(self):
|
def ckpt_dir(self):
|
||||||
return self.relpath / "ckpt"
|
return self.relpath / "ckpt"
|
||||||
|
@ -119,6 +123,7 @@ class Dataset:
|
||||||
|
|
||||||
hdf5_name: str = "data.h5"
|
hdf5_name: str = "data.h5"
|
||||||
use_hdf5: bool = False
|
use_hdf5: bool = False
|
||||||
|
use_metadata: bool = False
|
||||||
hdf5_flag: str = "a"
|
hdf5_flag: str = "a"
|
||||||
validate: bool = True
|
validate: bool = True
|
||||||
workers: int = 8
|
workers: int = 8
|
||||||
|
@ -135,6 +140,19 @@ class Dataset:
|
||||||
|
|
||||||
tasks_list: list[str] = field(default_factory=lambda: ["tts"])
|
tasks_list: list[str] = field(default_factory=lambda: ["tts"])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def min_phones(self):
|
||||||
|
return self.phones_range[0]
|
||||||
|
@property
|
||||||
|
def max_phones(self):
|
||||||
|
return self.phones_range[1]
|
||||||
|
@property
|
||||||
|
def min_duration(self):
|
||||||
|
return self.duration_range[0]
|
||||||
|
@property
|
||||||
|
def max_duration(self):
|
||||||
|
return self.duration_range[1]
|
||||||
|
|
||||||
@dataclass()
|
@dataclass()
|
||||||
class Model:
|
class Model:
|
||||||
name: str = ""
|
name: str = ""
|
||||||
|
@ -393,7 +411,7 @@ class Trainer:
|
||||||
|
|
||||||
weight_dtype: str = "float16"
|
weight_dtype: str = "float16"
|
||||||
|
|
||||||
backend: str = "deepspeed" if not sys.platform.startswith("win") else "local"
|
backend: str = "local" # "deepspeed" if not sys.platform.startswith("win") else "local"
|
||||||
|
|
||||||
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
|
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
|
||||||
|
|
||||||
|
@ -453,10 +471,6 @@ class Config(_Config):
|
||||||
def get_spkr(self):
|
def get_spkr(self):
|
||||||
return eval(self.dataset.speaker_name_getter)
|
return eval(self.dataset.speaker_name_getter)
|
||||||
|
|
||||||
@property
|
|
||||||
def cache_dir(self):
|
|
||||||
return ".cache" / self.relpath
|
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def diskcache(self):
|
def diskcache(self):
|
||||||
if self.cfg_path is not None and self.dataset.cache:
|
if self.cfg_path is not None and self.dataset.cache:
|
||||||
|
@ -501,11 +515,10 @@ try:
|
||||||
if cfg.dataset.use_hdf5:
|
if cfg.dataset.use_hdf5:
|
||||||
cfg.load_hdf5()
|
cfg.load_hdf5()
|
||||||
|
|
||||||
if not cfg.dataset.use_hdf5:
|
cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
|
||||||
cfg.dataset.training = [ Path(dir) for dir in cfg.dataset.training ]
|
cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
|
||||||
cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
|
|
||||||
|
|
||||||
cfg.dataset.noise = [ Path(dir) for dir in cfg.dataset.noise ]
|
cfg.dataset.noise = [ Path(dir) for dir in cfg.dataset.noise ]
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
417
vall_e/data.py
417
vall_e/data.py
|
@ -8,6 +8,7 @@ import numpy as np
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import torch
|
import torch
|
||||||
|
import itertools
|
||||||
|
|
||||||
from .config import cfg
|
from .config import cfg
|
||||||
from .emb.qnt import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file
|
from .emb.qnt import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file
|
||||||
|
@ -31,7 +32,7 @@ def get_phone_symmap():
|
||||||
if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5:
|
if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5:
|
||||||
return json.loads( cfg.hdf5['symmap'].asstr()[()] )
|
return json.loads( cfg.hdf5['symmap'].asstr()[()] )
|
||||||
|
|
||||||
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
|
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178, '”': 179, '“': 180, '“ˈ': 181, '“ˌ': 182, ';ˈ': 183, ';ˌ': 184, ':ˈ': 185}
|
||||||
return symmap
|
return symmap
|
||||||
|
|
||||||
def get_task_symmap():
|
def get_task_symmap():
|
||||||
|
@ -51,24 +52,89 @@ def get_task_symmap():
|
||||||
def _replace_file_extension(path, suffix):
|
def _replace_file_extension(path, suffix):
|
||||||
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
|
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
|
||||||
|
|
||||||
def _get_hdf5_path(path):
|
|
||||||
path = str(path)
|
|
||||||
if path[:2] != "./":
|
|
||||||
path = f'./{path}'
|
|
||||||
return path.replace(cfg.cfg_path, "")
|
|
||||||
|
|
||||||
def _get_quant_path(path):
|
def _get_quant_path(path):
|
||||||
return _replace_file_extension(path, ".qnt.pt")
|
return _replace_file_extension(path, ".qnt.pt")
|
||||||
|
|
||||||
def _get_phone_path(path):
|
def _get_phone_path(path):
|
||||||
return _replace_file_extension(path, ".phn.txt")
|
return _replace_file_extension(path, ".phn.txt")
|
||||||
|
|
||||||
|
def _load_paths(dataset, type="training"):
|
||||||
|
return { cfg.get_spkr( data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
|
||||||
|
|
||||||
|
"""
|
||||||
|
def _load_paths_from_hdf5(dataset, type="training"):
|
||||||
|
return { cfg.get_spkr( data_dir / "dummy" ): _get_hdf5_paths( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
|
||||||
|
|
||||||
|
def _load_paths_from_disk(dataset, type="training"):
|
||||||
|
return { cfg.get_spkr( data_dir / "dummy" ): _get_paths_of_extensions( data_dir, ".qnt.pt", validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _load_paths_from_metadata(data_dir, type="training", validate=False):
|
||||||
|
_fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions
|
||||||
|
|
||||||
|
def _validate( entry ):
|
||||||
|
phones = entry['phones']
|
||||||
|
duration = entry['duration']
|
||||||
|
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
|
||||||
|
|
||||||
|
metadata_path = data_dir / "metadata.json"
|
||||||
|
if not cfg.dataset.use_metadata or not metadata_path.exists():
|
||||||
|
return _fn( data_dir, type if cfg.dataset.use_hdf5 else ".qnt.pt", validate )
|
||||||
|
|
||||||
|
speaker = cfg.get_spkr( data_dir / "dummy" )
|
||||||
|
metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
|
||||||
|
|
||||||
|
def key( dir, id ):
|
||||||
|
if not cfg.dataset.use_hdf5:
|
||||||
|
return data_dir / id
|
||||||
|
|
||||||
|
return f"/{type}{_get_hdf5_path(data_dir)}/{id}"
|
||||||
|
|
||||||
|
return [ key(dir, id) for id in metadata.keys() if not validate or _validate(metadata[id]) ]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_hdf5_path(path):
|
||||||
|
path = str(path)
|
||||||
|
if path[:2] != "./":
|
||||||
|
path = f'./{path}'
|
||||||
|
return path.replace(cfg.cfg_path, "")
|
||||||
|
|
||||||
|
def _get_hdf5_paths( data_dir, type="training", validate=False ):
|
||||||
|
data_dir = str(data_dir)
|
||||||
|
|
||||||
|
def _validate(child):
|
||||||
|
phones = child.attrs['phonemes']
|
||||||
|
duration = child.attrs['duration']
|
||||||
|
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
|
||||||
|
|
||||||
|
key = f"/{type}{_get_hdf5_path(data_dir)}"
|
||||||
|
return [ Path(f"{key}/{child.attrs['id']}") for child in cfg.hdf5[key].values() if not validate or _validate(child) ] if key in cfg.hdf5 else []
|
||||||
|
|
||||||
|
def _get_paths_of_extensions( path, extensions=".qnt.pt", validate=False ):
|
||||||
|
if isinstance(path, str):
|
||||||
|
path = Path(path)
|
||||||
|
|
||||||
|
def _validate(path):
|
||||||
|
if "".join(path.suffixes) not in extensions:
|
||||||
|
return False
|
||||||
|
if not _get_phone_path(path).exists() or not _get_quant_path(path).exists():
|
||||||
|
return False
|
||||||
|
if not validate:
|
||||||
|
return True
|
||||||
|
# to-do: find an easy way to determine size from pickled quants without loading
|
||||||
|
# to-do: find a consistent way to derive phoneme count from filesize (probably can't due to utf-8)
|
||||||
|
phones = len(_get_phones(_get_phone_path(path))) # _get_phone_path(path).stat().st_size // 2 + 1
|
||||||
|
return cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
|
||||||
|
|
||||||
|
|
||||||
|
return [ p for p in list(path.iterdir()) if _validate(p) ] if path.exists() and path.is_dir() else []
|
||||||
|
|
||||||
def _load_quants(path) -> Tensor:
|
def _load_quants(path) -> Tensor:
|
||||||
return torch.load(path)[0][:, :].t().to(torch.int16)
|
return torch.load(_get_quant_path(path))[0][:, :].t().to(torch.int16)
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def _get_phones(path, language="en"):
|
def _get_phones(path, language="en"):
|
||||||
content = open(_get_phone_path(path), "r", encoding="utf8").read().split(" ")
|
content = open(_get_phone_path(path), "r", encoding="utf-8").read().split(" ")
|
||||||
return ["<s>"] + [ " " if not p else p for p in content ] + ["</s>"]
|
return ["<s>"] + [ " " if not p else p for p in content ] + ["</s>"]
|
||||||
|
|
||||||
def _interleaved_reorder(l, fn):
|
def _interleaved_reorder(l, fn):
|
||||||
|
@ -81,114 +147,61 @@ def _interleaved_reorder(l, fn):
|
||||||
if value is not None:
|
if value is not None:
|
||||||
yield value
|
yield value
|
||||||
|
|
||||||
|
|
||||||
@cache
|
|
||||||
def _validate(path, min_phones, max_phones, min_duration, max_duration):
|
|
||||||
if cfg.dataset.use_hdf5:
|
|
||||||
key = _get_hdf5_path(path)
|
|
||||||
if key not in cfg.hdf5:
|
|
||||||
return False
|
|
||||||
|
|
||||||
phones = cfg.hdf5[key].attrs['phonemes']
|
|
||||||
duration = cfg.hdf5[key].attrs['duration']
|
|
||||||
|
|
||||||
if phones < min_phones or phones > max_phones:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if duration < min_duration or duration > max_duration:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
if not os.path.exists(_get_phone_path(path)) or not os.path.exists(_get_quant_path(path)):
|
|
||||||
return False
|
|
||||||
|
|
||||||
phones = _get_phones(path)
|
|
||||||
unique_phones = list(set(phones))
|
|
||||||
|
|
||||||
if len(unique_phones) == 0:
|
|
||||||
return False
|
|
||||||
if len(unique_phones) == 1 and unique_phones[0] == " ":
|
|
||||||
return False
|
|
||||||
if len(phones) < min_phones or len(phones) > max_phones:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
class Dataset(_Dataset):
|
class Dataset(_Dataset):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
paths,
|
|
||||||
phone_symmap=None,
|
phone_symmap=None,
|
||||||
training=False,
|
training=False,
|
||||||
extra_paths_by_spkr_name: dict[str, list] = {},
|
extra_paths_by_spkr_name: dict[str, list] = {},
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._head = None
|
self._head = None
|
||||||
self.min_phones = cfg.dataset.phones_range[0]
|
|
||||||
self.max_phones = cfg.dataset.phones_range[1]
|
|
||||||
self.min_duration = cfg.dataset.duration_range[0]
|
|
||||||
self.max_duration = cfg.dataset.duration_range[1]
|
|
||||||
self.sampler = None
|
self.sampler = None
|
||||||
|
|
||||||
if cfg.dataset.validate:
|
self.paths = []
|
||||||
self.paths = [
|
|
||||||
path for path in paths if _validate(path, self.min_phones, self.max_phones, self.min_duration, self.max_duration)
|
self.training = training
|
||||||
]
|
self.dataset_type = "training" if self.training else "validation"
|
||||||
else:
|
self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
|
||||||
self.paths = paths
|
|
||||||
|
self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type)
|
||||||
|
self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
|
||||||
|
|
||||||
|
if cfg.dataset.sample_type == "path":
|
||||||
|
self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
|
||||||
|
|
||||||
|
self.noise_paths = _load_paths(cfg.dataset.noise, "noise")
|
||||||
|
self.noise_paths = list(itertools.chain.from_iterable(self.noise_paths.values()))
|
||||||
|
|
||||||
self.phone_symmap = phone_symmap or self._get_phone_symmap()
|
self.phone_symmap = phone_symmap or self._get_phone_symmap()
|
||||||
self.spkr_symmap = self._get_spkr_symmap()
|
self.spkr_symmap = self._get_spkr_symmap()
|
||||||
self.task_symmap = self._get_task_symmap()
|
self.task_symmap = self._get_task_symmap()
|
||||||
self.training = training
|
|
||||||
|
|
||||||
# assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
|
# assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
|
||||||
self.text_dtype = torch.uint8 if len(self.phone_symmap) < 256 else torch.int16
|
self.text_dtype = torch.uint8 if len(self.phone_symmap) < 256 else torch.int16
|
||||||
|
|
||||||
self.paths_by_spkr_name = self._get_paths_by_spkr_name(extra_paths_by_spkr_name)
|
|
||||||
|
|
||||||
if cfg.dataset.validate:
|
|
||||||
self.paths = [
|
|
||||||
p for p in self.paths if len(self.paths_by_spkr_name[cfg.get_spkr(p)]) > 1
|
|
||||||
]
|
|
||||||
|
|
||||||
if cfg.dataset.sample_type == "path":
|
|
||||||
self.paths = [*_interleaved_reorder(self.paths, cfg.get_spkr)]
|
|
||||||
|
|
||||||
if len(self.paths) == 0 and training:
|
if len(self.paths) == 0 and training:
|
||||||
raise ValueError("No valid path is found for training.")
|
raise ValueError("No valid path is found for training.")
|
||||||
|
|
||||||
|
# would be a better cost saving if we could fetch the duration during the validation pass but oh well
|
||||||
self.duration = 0
|
self.duration = 0
|
||||||
self.durations = {}
|
|
||||||
if cfg.dataset.use_hdf5:
|
if cfg.dataset.use_hdf5:
|
||||||
for path in self.paths:
|
for path in self.paths:
|
||||||
key = _get_hdf5_path(path)
|
self.duration += cfg.hdf5[_get_hdf5_path(path)].attrs['duration']
|
||||||
spkr_name = cfg.get_spkr(path)
|
|
||||||
spkr_id = self.spkr_symmap[spkr_name]
|
|
||||||
duration = cfg.hdf5[key].attrs['duration']
|
|
||||||
|
|
||||||
self.duration += duration
|
|
||||||
|
|
||||||
if spkr_id not in self.durations:
|
|
||||||
self.durations[spkr_id] = duration
|
|
||||||
else:
|
|
||||||
self.durations[spkr_id] += duration
|
|
||||||
|
|
||||||
def _get_paths_by_spkr_name(self, extra_paths_by_spkr_name: dict[str, list]):
|
|
||||||
ret = defaultdict(list)
|
|
||||||
for path in self.paths:
|
|
||||||
ret[cfg.get_spkr(path)].append(path)
|
|
||||||
for k, v in extra_paths_by_spkr_name.items():
|
|
||||||
ret[k].extend(v)
|
|
||||||
return {**ret}
|
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def phones(self):
|
def phones(self):
|
||||||
return sorted(set().union(*[_get_phones(path) for path in self.paths]))
|
return sorted(set().union(*[_get_phones(path) for path in self.paths]))
|
||||||
|
|
||||||
|
def get_speaker(self, path):
|
||||||
|
if isinstance(path, str):
|
||||||
|
path = Path(path)
|
||||||
|
res = cfg.get_spkr(path)
|
||||||
|
return res
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def spkrs(self):
|
def spkrs(self):
|
||||||
return sorted({cfg.get_spkr(path) for path in self.paths})
|
return sorted({self.get_speaker(path) for path in self.paths})
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def tasks(self):
|
def tasks(self):
|
||||||
|
@ -209,13 +222,10 @@ class Dataset(_Dataset):
|
||||||
return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(levels) ]]).to(dtype=torch.int16)
|
return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(levels) ]]).to(dtype=torch.int16)
|
||||||
|
|
||||||
def sample_noise(self):
|
def sample_noise(self):
|
||||||
paths = []
|
path = random.choice(self.noise_paths)
|
||||||
for data_dir in cfg.dataset.noise:
|
|
||||||
paths.extend(data_dir.rglob("*.qnt.pt"))
|
|
||||||
path = random.choice(paths)
|
|
||||||
|
|
||||||
if False and cfg.dataset.use_hdf5:
|
if cfg.dataset.use_hdf5:
|
||||||
key = f'/noise/{_get_hdf5_path(path)}'
|
key = _get_hdf5_path(path)
|
||||||
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
|
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
|
||||||
else:
|
else:
|
||||||
qnt = _load_quants(path)
|
qnt = _load_quants(path)
|
||||||
|
@ -275,7 +285,7 @@ class Dataset(_Dataset):
|
||||||
path = random.choice([*set(self.paths_by_spkr_name[spkr_name])])
|
path = random.choice([*set(self.paths_by_spkr_name[spkr_name])])
|
||||||
else:
|
else:
|
||||||
path = self.paths[index]
|
path = self.paths[index]
|
||||||
spkr_name = cfg.get_spkr(path)
|
spkr_name = self.get_speaker(path)
|
||||||
spkr_id = self.spkr_symmap[spkr_name]
|
spkr_id = self.spkr_symmap[spkr_name]
|
||||||
|
|
||||||
if cfg.dataset.use_hdf5:
|
if cfg.dataset.use_hdf5:
|
||||||
|
@ -507,110 +517,10 @@ def _create_dataloader(dataset, training):
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _load_dataset_paths():
|
|
||||||
hf = cfg.hdf5
|
|
||||||
paths = {
|
|
||||||
"training": [],
|
|
||||||
"validation": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
datasets = {
|
|
||||||
"training": [],
|
|
||||||
"validation": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_paths( data_dir, type="training" ):
|
|
||||||
key = f"/{type}{_get_hdf5_path(data_dir)}"
|
|
||||||
if key not in cfg.hdf5:
|
|
||||||
return
|
|
||||||
|
|
||||||
paths[type].extend([ f"{key}/{child.attrs['id']}" for child in cfg.hdf5[key].values() ])
|
|
||||||
|
|
||||||
for data_dir in cfg.dataset.training:
|
|
||||||
get_paths( data_dir, "training" )
|
|
||||||
|
|
||||||
for data_dir in cfg.dataset.validation:
|
|
||||||
get_paths( data_dir, "validation" )
|
|
||||||
|
|
||||||
for _, type in enumerate(paths):
|
|
||||||
dirs = paths[type]
|
|
||||||
|
|
||||||
if len(dirs) == 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
dirs = [ Path(p) for p in dirs ]
|
|
||||||
|
|
||||||
pairs = sorted([(cfg.get_spkr(p), p) for p in dirs])
|
|
||||||
for _, group in groupby(pairs, lambda pair: pair[0]):
|
|
||||||
shuffled = sorted([p for _, p in group])
|
|
||||||
random.seed(0)
|
|
||||||
random.shuffle(shuffled)
|
|
||||||
|
|
||||||
datasets[type].extend(shuffled)
|
|
||||||
|
|
||||||
return datasets["training"], datasets["validation"]
|
|
||||||
|
|
||||||
# to-do: mirror the hdf5-based load function
|
|
||||||
def _load_train_val_paths():
|
|
||||||
paths = []
|
|
||||||
train_paths = []
|
|
||||||
val_paths = []
|
|
||||||
|
|
||||||
for data_dir in cfg.dataset.training:
|
|
||||||
paths.extend(data_dir.rglob("*.qnt.pt"))
|
|
||||||
|
|
||||||
if len(paths) > 0:
|
|
||||||
pairs = sorted([(cfg.get_spkr(p), p) for p in paths])
|
|
||||||
del paths
|
|
||||||
|
|
||||||
for _, group in groupby(pairs, lambda pair: pair[0]):
|
|
||||||
paths = sorted([p for _, p in group])
|
|
||||||
random.seed(0)
|
|
||||||
random.shuffle(paths)
|
|
||||||
train_paths.extend(paths)
|
|
||||||
|
|
||||||
for data_dir in cfg.dataset.validation:
|
|
||||||
paths.extend(data_dir.rglob("*.qnt.pt"))
|
|
||||||
|
|
||||||
if len(paths) > 0:
|
|
||||||
pairs = sorted([(cfg.get_spkr(p), p) for p in paths])
|
|
||||||
del paths
|
|
||||||
|
|
||||||
for _, group in groupby(pairs, lambda pair: pair[0]):
|
|
||||||
paths = sorted([p for _, p in group])
|
|
||||||
random.seed(0)
|
|
||||||
random.shuffle(paths)
|
|
||||||
val_paths.extend(paths)
|
|
||||||
|
|
||||||
train_paths, val_paths = map(sorted, [train_paths, val_paths])
|
|
||||||
|
|
||||||
if len(train_paths) == 0:
|
|
||||||
raise RuntimeError(f"Failed to find any .qnt.pt file in specified training dataset.")
|
|
||||||
|
|
||||||
# to-do: allow setting aside a fixed portion of the training dataset for validation
|
|
||||||
# something like the last X percent of each speaker is set aside
|
|
||||||
if len(val_paths) == 0:
|
|
||||||
val_paths = [ train_paths[0] ]
|
|
||||||
|
|
||||||
return train_paths, val_paths
|
|
||||||
|
|
||||||
@cfg.diskcache()
|
@cfg.diskcache()
|
||||||
def create_datasets():
|
def create_datasets():
|
||||||
train_paths, val_paths = _load_dataset_paths() if cfg.dataset.use_hdf5 else _load_train_val_paths()
|
train_dataset = Dataset( training=True )
|
||||||
|
val_dataset = Dataset( phone_symmap=train_dataset.phone_symmap, training=False )
|
||||||
train_dataset = Dataset(
|
|
||||||
train_paths,
|
|
||||||
training=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
val_dataset = Dataset(
|
|
||||||
val_paths,
|
|
||||||
train_dataset.phone_symmap,
|
|
||||||
#train_dataset.spkr_symmap,
|
|
||||||
#extra_paths_by_spkr_name=train_dataset.paths_by_spkr_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
val_dataset.head_(cfg.evaluation.size)
|
|
||||||
|
|
||||||
return train_dataset, val_dataset
|
return train_dataset, val_dataset
|
||||||
|
|
||||||
|
@ -628,17 +538,10 @@ def create_train_val_dataloader():
|
||||||
|
|
||||||
_logger.info(str(train_dataset.phone_symmap))
|
_logger.info(str(train_dataset.phone_symmap))
|
||||||
_logger.info(str(train_dataset.spkr_symmap))
|
_logger.info(str(train_dataset.spkr_symmap))
|
||||||
|
|
||||||
|
|
||||||
_logger.info(f"#samples (train): {len(train_dataset)}.")
|
_logger.info(f"#samples (train): {len(train_dataset)}.")
|
||||||
_logger.info(f"#samples (val): {len(val_dataset)}.")
|
_logger.info(f"#samples (val): {len(val_dataset)}.")
|
||||||
_logger.info(f"#samples (subtrain): {len(subtrain_dataset)}.")
|
_logger.info(f"#samples (subtrain): {len(subtrain_dataset)}.")
|
||||||
|
|
||||||
"""
|
|
||||||
_logger.info(f"#durations (train): {str(train_dataset.durations)}.")
|
|
||||||
_logger.info(f"#durations (val): {str(val_dataset.durations)}.")
|
|
||||||
_logger.info(f"#durations (subtrain): {str(subtrain_dataset.durations)}.")
|
|
||||||
"""
|
|
||||||
|
|
||||||
_logger.info(f"#duration (train): {str(train_dataset.duration)}.")
|
_logger.info(f"#duration (train): {str(train_dataset.duration)}.")
|
||||||
_logger.info(f"#duration (val): {str(val_dataset.duration)}.")
|
_logger.info(f"#duration (val): {str(val_dataset.duration)}.")
|
||||||
|
@ -648,8 +551,52 @@ def create_train_val_dataloader():
|
||||||
|
|
||||||
return train_dl, subtrain_dl, val_dl
|
return train_dl, subtrain_dl, val_dl
|
||||||
|
|
||||||
|
# parse dataset into better to sample metadata
|
||||||
|
def create_dataset_metadata():
|
||||||
|
cfg.dataset.validate = False
|
||||||
|
cfg.dataset.use_hdf5 = False
|
||||||
|
|
||||||
|
paths_by_spkr_name = {}
|
||||||
|
|
||||||
|
paths_by_spkr_name |= _load_paths(cfg.dataset.training, "training")
|
||||||
|
paths_by_spkr_name |= _load_paths(cfg.dataset.validation, "validation")
|
||||||
|
paths_by_spkr_name |= _load_paths(cfg.dataset.noise, "noise")
|
||||||
|
|
||||||
|
paths = list(itertools.chain.from_iterable(paths_by_spkr_name.values()))
|
||||||
|
|
||||||
|
metadata = {}
|
||||||
|
for path in tqdm(paths, desc="Parsing paths"):
|
||||||
|
speaker = cfg.get_spkr(path)
|
||||||
|
if speaker not in metadata:
|
||||||
|
metadata[speaker] = {}
|
||||||
|
|
||||||
|
if cfg.dataset.use_hdf5:
|
||||||
|
phones = cfg.hdf5[_get_hdf5_path(path)].attrs['phonemes']
|
||||||
|
duration = cfg.hdf5[_get_hdf5_path(path)].attrs['duration']
|
||||||
|
else:
|
||||||
|
phns_path = _get_phone_path(path)
|
||||||
|
qnts_path = _get_quant_path(path)
|
||||||
|
|
||||||
|
phones = len(_get_phones(phns_path)) if phns_path.exists() else 0
|
||||||
|
duration = _load_quants(qnts_path).shape[0] / 75 if qnts_path.exists() else 0
|
||||||
|
|
||||||
|
|
||||||
|
metadata[speaker][path.name.split(".")[0]] = {
|
||||||
|
"phones": phones,
|
||||||
|
"duration": duration
|
||||||
|
}
|
||||||
|
|
||||||
|
for speaker, paths in tqdm(paths_by_spkr_name.items(), desc="Writing metadata"):
|
||||||
|
if len(paths) == 0:
|
||||||
|
continue
|
||||||
|
with open(paths[0].parent / "metadata.json", "w", encoding="utf-8") as f:
|
||||||
|
f.write( json.dumps( metadata[speaker] ) )
|
||||||
|
|
||||||
|
with open(cfg.relpath / "metadata.json", "w", encoding="utf-8") as f:
|
||||||
|
f.write( json.dumps( metadata ) )
|
||||||
|
|
||||||
# parse yaml to create an hdf5 file
|
# parse yaml to create an hdf5 file
|
||||||
def create_dataset_hdf5():
|
def create_dataset_hdf5( skip_existing=True ):
|
||||||
cfg.dataset.use_hdf5 = True
|
cfg.dataset.use_hdf5 = True
|
||||||
cfg.load_hdf5(write=True)
|
cfg.load_hdf5(write=True)
|
||||||
|
|
||||||
|
@ -658,11 +605,11 @@ def create_dataset_hdf5():
|
||||||
root = cfg.cfg_path
|
root = cfg.cfg_path
|
||||||
hf = cfg.hdf5
|
hf = cfg.hdf5
|
||||||
|
|
||||||
def add( dir, type="training", audios=True, texts=True ):
|
|
||||||
dir = "./" + str(dir)
|
|
||||||
name = dir.replace(root, "")
|
|
||||||
|
|
||||||
print( str(dir), name )
|
def add( dir, type="training", audios=True, texts=True ):
|
||||||
|
name = "./" + str(dir)
|
||||||
|
name = name .replace(root, "")
|
||||||
|
metadata = {}
|
||||||
|
|
||||||
if not os.path.isdir(f'{root}/{name}/'):
|
if not os.path.isdir(f'{root}/{name}/'):
|
||||||
return
|
return
|
||||||
|
@ -680,35 +627,53 @@ def create_dataset_hdf5():
|
||||||
|
|
||||||
key = f'{type}/{name}/{id}'
|
key = f'{type}/{name}/{id}'
|
||||||
if key in hf:
|
if key in hf:
|
||||||
# print("Skipping existing entry:", key)
|
if skip_existing:
|
||||||
continue
|
continue
|
||||||
|
del hf[key]
|
||||||
|
|
||||||
group = hf.create_group(key)
|
group = hf.create_group(key)
|
||||||
|
group.attrs['id'] = id
|
||||||
|
group.attrs['type'] = type
|
||||||
|
group.attrs['speaker'] = name
|
||||||
|
|
||||||
|
metadata[id] = {}
|
||||||
|
|
||||||
# audio
|
# audio
|
||||||
if audios:
|
if audios:
|
||||||
qnt = torch.load(f'{root}/{name}/{id}.qnt.pt')[0].t()
|
qnt = torch.load(f'{root}/{name}/{id}.qnt.pt')[0].t()
|
||||||
|
|
||||||
|
if "audio" in group:
|
||||||
|
del group["audio"]
|
||||||
group.create_dataset('audio', data=qnt.numpy(), compression='lzf')
|
group.create_dataset('audio', data=qnt.numpy(), compression='lzf')
|
||||||
|
group.attrs['duration'] = qnt.shape[0] / 75
|
||||||
|
metadata[id]["duration"] = qnt.shape[0] / 75
|
||||||
|
else:
|
||||||
|
group.attrs['duration'] = 0
|
||||||
|
metadata[id]["duration"] = 0
|
||||||
|
|
||||||
# text
|
# text
|
||||||
if texts:
|
if texts:
|
||||||
with open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf8") as f:
|
with open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf-8") as f:
|
||||||
content = f.read()
|
content = f.read().split(" ")
|
||||||
split = content.split(" ")
|
phones = [f"<s>"] + [ " " if not p else p for p in content ] + [f"</s>"]
|
||||||
phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
|
|
||||||
for s in set(phones):
|
for s in set(phones):
|
||||||
if s not in symmap:
|
if s not in symmap:
|
||||||
symmap[s] = len(symmap.keys())
|
symmap[s] = len(symmap.keys())
|
||||||
|
|
||||||
phn = [ symmap[s] for s in phones ]
|
phn = [ symmap[s] for s in phones ]
|
||||||
|
|
||||||
|
if "text" in group:
|
||||||
|
del group["text"]
|
||||||
group.create_dataset('text', data=phn, compression='lzf', chunks=True)
|
group.create_dataset('text', data=phn, compression='lzf', chunks=True)
|
||||||
|
|
||||||
# metadata
|
|
||||||
group.attrs['id'] = id
|
|
||||||
group.attrs['type'] = type
|
|
||||||
group.attrs['speaker'] = name
|
|
||||||
group.attrs['duration'] = qnt.shape[0] / 75
|
|
||||||
group.attrs['phonemes'] = len(phn)
|
group.attrs['phonemes'] = len(phn)
|
||||||
|
metadata[id]["phones"] = len(phn)
|
||||||
|
else:
|
||||||
|
group.attrs['phonemes'] = 0
|
||||||
|
metadata[id]["phones"] = 0
|
||||||
|
|
||||||
|
with open(dir / "metadata.json", "w", encoding="utf-8") as f:
|
||||||
|
f.write( json.dumps( metadata ) )
|
||||||
|
|
||||||
|
|
||||||
# training
|
# training
|
||||||
for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
|
for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
|
||||||
|
@ -723,10 +688,9 @@ def create_dataset_hdf5():
|
||||||
add( data_dir, type="noise", texts=False )
|
add( data_dir, type="noise", texts=False )
|
||||||
|
|
||||||
# write symmap
|
# write symmap
|
||||||
try:
|
if "symmap" in hf:
|
||||||
hf.create_dataset('symmap', data=json.dumps(symmap))
|
del hf['symmap']
|
||||||
except Exception as e:
|
hf.create_dataset('symmap', data=json.dumps(symmap))
|
||||||
pass
|
|
||||||
|
|
||||||
hf.close()
|
hf.close()
|
||||||
|
|
||||||
|
@ -742,8 +706,16 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
cfg.dataset.workers = 1
|
cfg.dataset.workers = 1
|
||||||
|
|
||||||
|
class LoggerOveride:
|
||||||
|
def info(self, *args):
|
||||||
|
print(*args)
|
||||||
|
|
||||||
|
_logger = LoggerOveride()
|
||||||
|
|
||||||
if args.action == "hdf5":
|
if args.action == "hdf5":
|
||||||
create_dataset_hdf5()
|
create_dataset_hdf5()
|
||||||
|
elif args.action == "metadata":
|
||||||
|
create_dataset_metadata()
|
||||||
elif args.action == "sample":
|
elif args.action == "sample":
|
||||||
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
||||||
|
|
||||||
|
@ -758,6 +730,7 @@ if __name__ == "__main__":
|
||||||
del v[i]['proms']
|
del v[i]['proms']
|
||||||
del v[i]['resps']
|
del v[i]['resps']
|
||||||
print(f'{k}:', v)
|
print(f'{k}:', v)
|
||||||
|
|
||||||
elif args.action == "tasks":
|
elif args.action == "tasks":
|
||||||
index = 0
|
index = 0
|
||||||
cfg.dataset.tasks_list = args.tasks.split(",")
|
cfg.dataset.tasks_list = args.tasks.split(",")
|
||||||
|
|
|
@ -33,10 +33,13 @@ def _get_backend( language="en-us", backend="espeak" ):
|
||||||
return phonemizer
|
return phonemizer
|
||||||
|
|
||||||
|
|
||||||
def encode(text: str, language="en-us", backend="espeak") -> list[str]:
|
def encode(text: str, language="en-us", backend="auto") -> list[str]:
|
||||||
if language == "en":
|
if language == "en":
|
||||||
language = "en-us"
|
language = "en-us"
|
||||||
|
|
||||||
|
if not backend or backend == "auto":
|
||||||
|
backend = "espeak" # if language[:2] != "en" else "festival"
|
||||||
|
|
||||||
text = [ text ]
|
text = [ text ]
|
||||||
|
|
||||||
backend = _get_backend(language=language, backend=backend)
|
backend = _get_backend(language=language, backend=backend)
|
||||||
|
@ -63,16 +66,13 @@ def main():
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
paths = list(args.folder.rglob(f"*{args.suffix}"))
|
paths = list(args.folder.rglob(f"*{args.suffix}"))
|
||||||
random.shuffle(paths)
|
|
||||||
|
|
||||||
for path in tqdm(paths):
|
for path in tqdm(paths):
|
||||||
phone_path = path.with_name(path.stem.split(".")[0] + ".phn.txt")
|
phone_path = path.with_name(path.stem.split(".")[0] + ".phn.txt")
|
||||||
if phone_path.exists():
|
if phone_path.exists():
|
||||||
continue
|
continue
|
||||||
graphs = _get_graphs(path)
|
phones = encode(open(path, "r", encoding="utf-8").read())
|
||||||
phones = encode(graphs)
|
open(phone_path, "w", encoding="utf-8").write(" ".join(phones))
|
||||||
with open(phone_path, "w") as f:
|
|
||||||
f.write(" ".join(phones))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -213,7 +213,7 @@ def repeat_extend_audio( qnt, target ):
|
||||||
pieces.append(qnt)
|
pieces.append(qnt)
|
||||||
length += qnt.shape[0]
|
length += qnt.shape[0]
|
||||||
|
|
||||||
return trim_random(torch.cat(pieces), target)
|
return trim(torch.cat(pieces), target)
|
||||||
|
|
||||||
# merges two quantized audios together
|
# merges two quantized audios together
|
||||||
# I don't know if this works
|
# I don't know if this works
|
||||||
|
|
|
@ -2,6 +2,7 @@ from ..config import cfg
|
||||||
from .base import Base, list_to_tensor, Categorical
|
from .base import Base, list_to_tensor, Categorical
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
@ -62,7 +63,7 @@ class AR(Base):
|
||||||
max_steps: int = 1000,
|
max_steps: int = 1000,
|
||||||
sampling_temperature: float = 1.0,
|
sampling_temperature: float = 1.0,
|
||||||
|
|
||||||
naive: bool = True,
|
naive: bool = False,
|
||||||
):
|
):
|
||||||
if resps_list is not None:
|
if resps_list is not None:
|
||||||
resps_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels
|
resps_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels
|
||||||
|
@ -83,30 +84,104 @@ class AR(Base):
|
||||||
]
|
]
|
||||||
stopped = torch.zeros(len(text_list), device=device).bool()
|
stopped = torch.zeros(len(text_list), device=device).bool()
|
||||||
|
|
||||||
chunk_size = 1 # don't really know what to do about this desu
|
chunk_size = self.causal_chunk_size # don't really know what to do about this desu
|
||||||
|
|
||||||
state = None
|
state = None
|
||||||
start = 0
|
start = 0
|
||||||
|
|
||||||
for n in trange(max_steps // chunk_size):
|
if naive:
|
||||||
# get next in sequence
|
for n in trange(max_steps // max(1, chunk_size)):
|
||||||
|
# get next in sequence
|
||||||
|
|
||||||
r, state = super().forward(
|
r, state = super().forward(
|
||||||
text_list,
|
text_list,
|
||||||
proms_list,
|
proms_list,
|
||||||
self._unsqueeze_list(resps_list),
|
self._unsqueeze_list(resps_list),
|
||||||
sampling_temperature=sampling_temperature,
|
sampling_temperature=sampling_temperature,
|
||||||
state=state if not naive else None,
|
state=state # if not naive else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# append outputted token
|
||||||
|
if self.causal_chunk_size > 0:
|
||||||
|
for i, ri in enumerate(r):
|
||||||
|
resps_list[i] = torch.cat([resps_list[i], ri])
|
||||||
|
else:
|
||||||
|
for i, ri in enumerate(r):
|
||||||
|
resps_list[i] = torch.cat([resps_list[i], ri[None]])
|
||||||
|
|
||||||
|
|
||||||
|
# stop token found
|
||||||
|
stopped |= r == self.stop_token
|
||||||
|
if stopped.all().item():
|
||||||
|
break
|
||||||
|
# to-do: make it work
|
||||||
|
# it seems anything that isn't a one-at-a-time sequence does not work, despite generating STOP tokens.
|
||||||
|
else:
|
||||||
|
resps_list: list[Tensor] = [
|
||||||
|
torch.zeros(0, device=device).to(torch.int16) for _ in text_list
|
||||||
|
]
|
||||||
|
|
||||||
|
test_list: list[Tensor] = [
|
||||||
|
torch.zeros(0, device=device).to(torch.int16) for _ in text_list
|
||||||
|
]
|
||||||
|
|
||||||
|
batch_size = len(text_list)
|
||||||
|
|
||||||
|
x_list = self._samplewise_merge_tensors(
|
||||||
|
self.text_emb(text_list),
|
||||||
|
self.proms_emb(proms_list),
|
||||||
|
self.resps_emb(self._unsqueeze_list(resps_list)),
|
||||||
|
sep=self.sep,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
x, m = list_to_tensor(x_list)
|
||||||
|
device = x.device
|
||||||
|
|
||||||
|
if state is None:
|
||||||
|
state = {}
|
||||||
|
|
||||||
|
# pre-fill KV cache
|
||||||
|
for n in trange(x.shape[1]):
|
||||||
|
xs = x[:, n:(n + 1), :]
|
||||||
|
r, _ = self.retnet(xs, incremental_state=state, token_embeddings=xs, features_only=True)
|
||||||
|
r = self.classifier(r) * m
|
||||||
|
|
||||||
|
logits = torch.stack([hi[-1] for hi in r])
|
||||||
|
r = Categorical(logits=logits / sampling_temperature).sample()
|
||||||
|
|
||||||
|
for i, ri in enumerate(r):
|
||||||
|
test_list[i] = torch.cat([test_list[i], ri[None]])
|
||||||
|
|
||||||
# append outputted token
|
# append outputted token
|
||||||
for i, ri in enumerate(r):
|
for i, ri in enumerate(r):
|
||||||
resps_list[i] = torch.cat([resps_list[i], ri[None]])
|
resps_list[i] = torch.cat([resps_list[i], ri[None]])
|
||||||
|
|
||||||
# stop token found
|
start = x.shape[1]
|
||||||
stopped |= r == self.stop_token
|
for n in trange(max_steps // max(1, chunk_size)):
|
||||||
if stopped.all().item():
|
x_list = self._samplewise_merge_tensors(
|
||||||
break
|
self.text_emb(text_list),
|
||||||
|
self.proms_emb(proms_list),
|
||||||
|
self.resps_emb(self._unsqueeze_list(resps_list)),
|
||||||
|
sep=self.sep,
|
||||||
|
)
|
||||||
|
|
||||||
|
x, m = list_to_tensor(x_list)
|
||||||
|
|
||||||
|
xs = x[:, start+n:start+(n+1), :]
|
||||||
|
r, _ = self.retnet(xs, incremental_state=state, token_embeddings=xs, features_only=True)
|
||||||
|
r = self.classifier(r) * m
|
||||||
|
|
||||||
|
logits = torch.stack([hi[-1] for hi in r])
|
||||||
|
r = Categorical(logits=logits / sampling_temperature).sample()
|
||||||
|
|
||||||
|
# append outputted token
|
||||||
|
for i, ri in enumerate(r):
|
||||||
|
resps_list[i] = torch.cat([resps_list[i], ri[None]])
|
||||||
|
|
||||||
|
# stop token found
|
||||||
|
stopped |= r == self.stop_token
|
||||||
|
if stopped.all().item():
|
||||||
|
break
|
||||||
|
|
||||||
pruned = [self._prune(r) for r in resps_list]
|
pruned = [self._prune(r) for r in resps_list]
|
||||||
return pruned
|
return pruned
|
||||||
|
|
|
@ -149,6 +149,8 @@ class Base(nn.Module):
|
||||||
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
|
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
|
||||||
|
|
||||||
self.sep = nn.Parameter(torch.randn(d_model))
|
self.sep = nn.Parameter(torch.randn(d_model))
|
||||||
|
|
||||||
|
self.causal_chunk_size = 0 # 64 if self.causal else 1
|
||||||
|
|
||||||
if self.arch_type == "transformer":
|
if self.arch_type == "transformer":
|
||||||
self.sin_emb = SinusoidalEmbedding(d_model)
|
self.sin_emb = SinusoidalEmbedding(d_model)
|
||||||
|
@ -171,8 +173,8 @@ class Base(nn.Module):
|
||||||
dropout=p_dropout,
|
dropout=p_dropout,
|
||||||
checkpoint_activations=True,
|
checkpoint_activations=True,
|
||||||
|
|
||||||
chunkwise_recurrent=False, # self.causal,
|
chunkwise_recurrent=self.causal and self.causal_chunk_size > 0,
|
||||||
recurrent_chunkwise_size=64,
|
recurrent_chunkwise_size=self.causal_chunk_size,
|
||||||
no_output_layer=True,
|
no_output_layer=True,
|
||||||
decoder_normalize_before=True,
|
decoder_normalize_before=True,
|
||||||
))
|
))
|
||||||
|
@ -358,6 +360,10 @@ class Base(nn.Module):
|
||||||
elif return_all_resp:
|
elif return_all_resp:
|
||||||
logits = [hi[-li:] for hi, li in zip(h_list, map(len, resps_list))]
|
logits = [hi[-li:] for hi, li in zip(h_list, map(len, resps_list))]
|
||||||
ret = [ Categorical(logits=hi / sampling_temperature).sample() for hi in logits ]
|
ret = [ Categorical(logits=hi / sampling_temperature).sample() for hi in logits ]
|
||||||
|
# return the last chunkwise piece
|
||||||
|
elif self.causal_chunk_size > 0:
|
||||||
|
logits = [hi[-self.causal_chunk_size:] for hi, li in zip(h_list, map(len, resps_list))]
|
||||||
|
ret = [ Categorical(logits=hi / sampling_temperature).sample() for hi in logits ]
|
||||||
# return just the last code
|
# return just the last code
|
||||||
else:
|
else:
|
||||||
logits = torch.stack([hi[-1] for hi in h_list])
|
logits = torch.stack([hi[-1] for hi in h_list])
|
||||||
|
|
Loading…
Reference in New Issue
Block a user