diff --git a/.gitignore b/.gitignore
index 6de8fd2..bbb955e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,3 +2,4 @@ __pycache__
 /data
 /logs
 /ckpts
+/.cache
diff --git a/README.md b/README.md
index eb7d707..2a4f723 100644
--- a/README.md
+++ b/README.md
@@ -2,9 +2,12 @@
 
 An unofficial (toy) implementation of VALL-E, based on the [encodec](https://github.com/facebookresearch/encodec) tokenizer.
 
+[!["Buy Me A Coffee"](https://www.buymeacoffee.com/assets/img/custom_images/orange_img.png)](https://www.buymeacoffee.com/enhuiz)
+
 ## TODO
 
 - [x] AR model for the first quantizer.
 - [x] Audio decoding from tokens.
 - [x] NAR model for the rest quantizers.
-- [ ] Trainers for both models.
+- [x] Trainers for both models.
+- [ ] Pre-trained checkpoint.
diff --git a/config/ar.yml b/config/ar.yml
new file mode 100644
index 0000000..ad0266e
--- /dev/null
+++ b/config/ar.yml
@@ -0,0 +1,4 @@
+data_dirs: [data/test]
+
+model: ar
+batch_size: 1
diff --git a/config/nar.yml b/config/nar.yml
new file mode 100644
index 0000000..f5a78e7
--- /dev/null
+++ b/config/nar.yml
@@ -0,0 +1,4 @@
+data_dirs: [data/test]
+
+model: nar
+batch_size: 1
diff --git a/data/test/test.ar.recon.wav b/data/test/test.ar.recon.wav
index 3fd7f77..5ee5fe2 100644
Binary files a/data/test/test.ar.recon.wav and b/data/test/test.ar.recon.wav differ
diff --git a/data/test/test.nar.init.wav b/data/test/test.nar.init.wav
index 7709797..eb71667 100644
Binary files a/data/test/test.nar.init.wav and b/data/test/test.nar.init.wav differ
diff --git a/data/test/test.nar.recon.wav b/data/test/test.nar.recon.wav
index 7de0c4b..bbc09f0 100644
Binary files a/data/test/test.nar.recon.wav and b/data/test/test.nar.recon.wav differ
diff --git a/vall_e/config.py b/vall_e/config.py
new file mode 100644
index 0000000..bf3d6c0
--- /dev/null
+++ b/vall_e/config.py
@@ -0,0 +1,75 @@
+from dataclasses import dataclass, field
+from functools import cached_property
+from pathlib import Path
+
+import diskcache
+
+from .utils import Config as ConfigBase
+
+
+@dataclass(frozen=True)
+class Config(ConfigBase):
+    data_root: Path = Path("data")
+    data_dirs: list[Path] = field(default_factory=lambda: [])
+    test_data_dirs: list[Path] = field(default_factory=lambda: [])
+
+    nj: int = 8
+
+    @property
+    def sample_rate(self):
+        return 24_000
+
+    p_additional_prompt: float = 0.5
+
+    token_dim: int = 256
+    num_tokens: int = 1024
+
+    batch_size: int = 128
+    eval_batch_size: int = 512
+    warmup_min_lr: float = 1e-6
+    warmup_max_lr: float = 2e-4
+    dis_warmup_max_lr: float = 4e-4
+    warmup_num_steps: int = 1_000
+    max_iter: int = 10_000
+    gradient_clipping: float = 100
+    eval_every: int = 2_000
+    save_ckpt_every: int = 10_000
+
+    model: str = "ar"
+    d_model: int = 512
+    n_heads: int = 8
+    n_layers: int = 12
+    p_dropout: float = 0.1
+
+    @property
+    def ds_cfg(self):
+        return {
+            "train_micro_batch_size_per_gpu": self.batch_size,
+            "gradient_accumulation_steps": 1,
+            "optimizer": {"type": "Adam"},
+            "scheduler": {
+                "type": "WarmupDecayLR",
+                "params": {
+                    "warmup_min_lr": self.warmup_min_lr,
+                    "warmup_max_lr": self.warmup_max_lr,
+                    "warmup_num_steps": self.warmup_num_steps,
+                    "total_num_steps": self.max_iter,
+                    "warmup_type": "linear",
+                },
+            },
+            "gradient_clipping": self.gradient_clipping,
+        }
+
+    @property
+    def cache_dir(self):
+        return ".cache" / self.relpath
+
+    @cached_property
+    def diskcache(self):
+        return diskcache.Cache(self.cache_dir).memoize
+
+
+cfg = Config.from_cli()
+
+if __name__ == "__main__":
+    print(cfg)
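
The two YAML files above only override the handful of `Config` defaults that do not fit the test data. A minimal sketch of how the frozen dataclass is meant to be driven, assuming `ConfigBase.from_cli` (defined in `vall_e/utils`, which is not part of this diff) maps CLI or YAML keys onto dataclass fields; `eval_cfg` is purely illustrative:

import dataclasses
from pathlib import Path

from vall_e.config import Config

# Roughly equivalent to loading config/nar.yml.
cfg = Config(data_dirs=[Path("data/test")], model="nar", batch_size=1)
assert cfg.ds_cfg["train_micro_batch_size_per_gpu"] == 1
assert cfg.ds_cfg["scheduler"]["params"]["total_num_steps"] == cfg.max_iter

# The dataclass is frozen, so variants are derived rather than mutated in place.
eval_cfg = dataclasses.replace(cfg, batch_size=cfg.eval_batch_size)
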
diff --git a/vall_e/data.py b/vall_e/data.py
new file mode 100644
index 0000000..8dde4a9
--- /dev/null
+++ b/vall_e/data.py
@@ -0,0 +1,294 @@
+import copy
+import logging
+import random
+from collections import defaultdict
+from functools import cache, cached_property
+from itertools import groupby, zip_longest
+from typing import Any
+
+import numpy as np
+import torch
+from torch import Tensor
+from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
+
+from .config import cfg
+from .sampler import Sampler
+
+torch.multiprocessing.set_sharing_strategy("file_system")
+
+_logger = logging.getLogger(__name__)
+
+
+def _replace_file_extension(path, suffix):
+    return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
+
+
+def _get_quant_path(path):
+    return _replace_file_extension(path, ".qnt.pt")
+
+
+def _load_quants(path) -> Tensor:
+    """
+    Returns:
+        quants: (t q)
+    """
+    path = _get_quant_path(path)
+    return torch.load(path)[0].t()
+
+
+@cache
+def _get_phones(path):
+    path = _replace_file_extension(path, ".phn.txt")
+    with open(path, "r", encoding="utf8") as f:
+        content = f.read()
+    return ["<s>"] + content.split() + ["</s>"]
+
+
+def _interleaved_reorder(l, fn):
+    groups = defaultdict(list)
+    for e in l:
+        groups[fn(e)].append(e)
+    groups = {k: groups[k] for k in sorted(groups)}
+    for interleaved in zip_longest(*groups.values()):
+        for value in interleaved:
+            if value is not None:
+                yield value
+
+
+@cache
+def _validate(path, min_phones, max_phones):
+    phones = _get_phones(path)
+    unique_phones = list(set(phones))
+    if len(unique_phones) == 0:
+        return False
+    if len(unique_phones) == 1 and unique_phones[0] == "_":
+        return False
+    if len(phones) < min_phones:
+        return False
+    if len(phones) > max_phones:
+        return False
+    return True
+
+
+def _get_spkr_name(path) -> str:
+    return path.parts[-2]  # spkr/*.wav
+
+
+class VALLEDataset(Dataset):
+    def __init__(
+        self,
+        paths,
+        phone_symmap=None,
+        spkr_symmap=None,
+        min_phones=10,
+        max_phones=100,
+        training=False,
+        extra_paths_by_spkr_name: dict[str, list] = {},
+    ):
+        super().__init__()
+        self._head = None
+        self.min_phones = min_phones
+        self.max_phones = max_phones
+        self.paths = [
+            path for path in paths if _validate(path, self.min_phones, self.max_phones)
+        ]
+        self.spkr_symmap = spkr_symmap or self._get_spkr_symmap()
+        self.phone_symmap = phone_symmap or self._get_phone_symmap()
+        self.training = training
+
+        self.paths_by_spkr_name = self._get_paths_by_spkr_name(extra_paths_by_spkr_name)
+
+        if training:
+            self.sampler = Sampler(self.paths, [_get_spkr_name])
+        else:
+            self.sampler = None
+
+    def _get_paths_by_spkr_name(self, extra_paths_by_spkr_name: dict[str, list]):
+        ret = defaultdict(list)
+        for path in self.paths:
+            if _get_quant_path(path).exists():
+                ret[_get_spkr_name(path)].append(path)
+        for k, v in extra_paths_by_spkr_name.items():
+            ret[k].extend(v)
+        return {**ret}
+
+    @cached_property
+    def phones(self):
+        return sorted(set().union(*[_get_phones(path) for path in self.paths]))
+
+    def _get_phone_symmap(self):
+        # Note that the phone symmap starts from 1, so that 0 can safely be used for padding.
+        return {s: i for i, s in enumerate(self.phones, 1)}
+
+    @cached_property
+    def spkrs(self):
+        return sorted({_get_spkr_name(path) for path in self.paths})
+
+    def _get_spkr_symmap(self):
+        return {s: i for i, s in enumerate(self.spkrs)}
+
+    def sample_prompts(self, spkr_name):
+        prom_list = []
+
+        while (
+            len(prom_list) == 0
+            or random.random() < cfg.p_additional_prompt
+            and len(prom_list) < 10
+        ):
+            path = random.choice(self.paths_by_spkr_name[spkr_name])
+            prom_list.append(_load_quants(path))
+
+        prom = torch.cat(prom_list)
+
+        return prom
+
+    def __getitem__(self, index):
+        if self.training:
+            assert self.sampler is not None
+            path = self.sampler.sample()
+        else:
+            path = self.paths[index]
+
+        spkr_name = _get_spkr_name(path)
+        text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))])
+        proms = self.sample_prompts(spkr_name)
+        resps = _load_quants(path)
+        resp = resps[..., 0]
+
+        return dict(
+            path=path,
+            spkr_name=spkr_name,
+            text=text,
+            proms=proms,
+            resps=resps,
+            resp=resp,
+        )
+
+    def head_(self, n):
+        self._head = n
+
+    def training_(self, value):
+        self.training = value
+
+    def interleaved_reorder_(self, fn):
+        self.paths = [*_interleaved_reorder(self.paths, fn)]
+
+    def __len__(self):
+        return min(len(self.paths), self._head or len(self.paths))
+
+
+def collate_fn(samples: list[dict]):
+    batch: dict[str, Any] = {k: [s[k] for s in samples] for k in samples[0]}
+    return batch
+
+
+def _seed_worker(worker_id):
+    worker_seed = torch.initial_seed() % 2**32
+    np.random.seed(worker_seed)
+    random.seed(worker_seed)
+
+
+def _create_dl(dataset, training):
+    return DataLoader(
+        dataset=dataset,
+        batch_size=cfg.batch_size if training else cfg.eval_batch_size,
+        shuffle=training,
+        drop_last=training,
+        num_workers=cfg.nj,
+        collate_fn=collate_fn,
+        persistent_workers=True,
+        worker_init_fn=_seed_worker,
+    )
+
+
+def _load_train_val_paths():
+    paths = []
+    train_paths = []
+    val_paths = []
+
+    for data_dir in cfg.data_dirs:
+        paths.extend(tqdm(data_dir.rglob("**/*.qnt.pt")))
+
+    if len(paths) == 0:
+        raise RuntimeError(f"Failed to find any .qnt.pt file in {cfg.data_dirs}.")
+
+    pairs = sorted([(_get_spkr_name(p), p) for p in paths])
+    del paths
+
+    for _, group in groupby(pairs, lambda pair: pair[0]):
+        paths = sorted([p for _, p in group])
+        random.seed(0)
+        random.shuffle(paths)
+        n = round(len(paths) * 0.95)
+        train_paths.extend(paths[:n])
+        val_paths.extend(paths[n:])
+
+    train_paths, val_paths = map(sorted, [train_paths, val_paths])
+
+    return train_paths, val_paths
+
+
+def _load_test_paths():
+    test_paths = []
+    for data_dir in cfg.test_data_dirs:
+        test_paths.extend(data_dir.rglob("**/*.asr.txt"))
+    test_paths = sorted(test_paths)
+    return test_paths
+
+
+@cfg.diskcache()
+def create_datasets():
+    train_paths, val_paths = _load_train_val_paths()
+    test_paths = _load_test_paths()
+
+    train_dataset = VALLEDataset(train_paths, training=True)
+
+    val_dataset = VALLEDataset(
+        val_paths,
+        train_dataset.phone_symmap,
+        train_dataset.spkr_symmap,
+        extra_paths_by_spkr_name=train_dataset.paths_by_spkr_name,
+    )
+
+    val_dataset.interleaved_reorder_(_get_spkr_name)
+    val_dataset.head_(200)
+
+    test_dataset = VALLEDataset(
+        test_paths,
+        train_dataset.phone_symmap,
+        train_dataset.spkr_symmap,
+        extra_paths_by_spkr_name=train_dataset.paths_by_spkr_name,
+    )
+
+    return train_dataset, val_dataset, test_dataset
+
+
+def create_train_val_dataloader():
+    train_dataset, val_dataset, test_dataset = create_datasets()
+
+    train_dl = _create_dl(train_dataset, training=True)
+    val_dl = _create_dl(val_dataset, training=False)
+    test_dl = _create_dl(test_dataset, training=False)
+
+    _logger.info(str(train_dataset.phone_symmap))
+    _logger.info(str(train_dataset.spkr_symmap))
+
+    _logger.info(f"#samples (train): {len(train_dataset)}.")
+    _logger.info(f"#samples (val): {len(val_dataset)}.")
+    _logger.info(f"#samples (test): {len(test_dataset)}.")
+
+    train200_dataset = copy.deepcopy(train_dataset)
+    train200_dataset.interleaved_reorder_(_get_spkr_name)
+    train200_dataset.head_(200)
+    train200_dataset.training_(False)
+    train200_dl = _create_dl(train200_dataset, training=False)
+    assert isinstance(train200_dl.dataset, VALLEDataset)
+
+    return train_dl, train200_dl, val_dl, test_dl
+
+
+if __name__ == "__main__":
+    train_dl, train200_dl, val_dl, test_dl = create_train_val_dataloader()
+    sample = train_dl.dataset[0]
+    print(sample)
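
`collate_fn` deliberately neither pads nor stacks: a batch is a plain dict of per-sample lists, and the variable-length `text`/`proms`/`resps` tensors are left for the models to pad. A small self-contained sketch of the resulting batch structure, with made-up shapes for illustration:

import torch

from vall_e.data import collate_fn

samples = [
    dict(text=torch.randint(1, 10, (5,)), proms=torch.zeros(120, 8), resp=torch.zeros(75)),
    dict(text=torch.randint(1, 10, (9,)), proms=torch.zeros(300, 8), resp=torch.zeros(50)),
]
batch = collate_fn(samples)
assert isinstance(batch["text"], list) and len(batch["text"]) == 2
assert batch["proms"][0].shape == (120, 8)  # each sample keeps its own length
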
diff --git a/vall_e/emb/g2p.py b/vall_e/emb/g2p.py
new file mode 100644
index 0000000..c98758a
--- /dev/null
+++ b/vall_e/emb/g2p.py
@@ -0,0 +1,50 @@
+import argparse
+import random
+import string
+from functools import cache
+from pathlib import Path
+
+import torch
+from g2p_en import G2p
+from tqdm import tqdm
+
+
+@cache
+def _get_model():
+    return G2p()
+
+
+@cache
+def _get_graphs(path):
+    with open(path, "r") as f:
+        graphs = f.read()
+    return graphs
+
+
+def encode(graphs: str) -> list[str]:
+    g2p = _get_model()
+    phones = g2p(graphs)
+    ignored = {" ", *string.punctuation}
+    return ["_" if p in ignored else p for p in phones]
+
+
+@torch.no_grad()
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("folder", type=Path)
+    parser.add_argument("--suffix", type=str, default=".normalized.txt")
+    args = parser.parse_args()
+
+    paths = list(args.folder.rglob(f"*{args.suffix}"))
+    random.shuffle(paths)
+
+    for path in tqdm(paths):
+        phone_path = path.with_name(path.stem.split(".")[0] + ".phn.txt")
+        graphs = _get_graphs(path)
+        phones = encode(graphs)
+        with open(phone_path, "w") as f:
+            f.write(" ".join(phones))
+
+
+if __name__ == "__main__":
+    main()
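
A quick sanity check of `encode`; the exact phone sequence depends on the g2p_en model, so the commented output is indicative rather than guaranteed:

from vall_e.emb.g2p import encode

phones = encode("Hello, world.")
print(phones)
# Spaces and punctuation collapse to the "_" placeholder, e.g. something like:
# ['HH', 'AH0', 'L', 'OW1', '_', '_', 'W', 'ER1', 'L', 'D', '_']
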
diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py
index 1cba87e..b70951a 100644
--- a/vall_e/emb/qnt.py
+++ b/vall_e/emb/qnt.py
@@ -2,35 +2,51 @@ import argparse
 from functools import cache
 from pathlib import Path
 
+import soundfile
 import torch
 import torchaudio
+from einops import rearrange
 from encodec import EncodecModel
 from encodec.utils import convert_audio
 from torch import Tensor
 from tqdm import tqdm
 
+from ..config import cfg
+
 
 @cache
 def _load_model(device="cuda"):
     # Instantiate a pretrained EnCodec model
+    assert cfg.sample_rate == 24_000
     model = EncodecModel.encodec_model_24khz()
     model.set_target_bandwidth(6.0)
     model.to(device)
     return model
 
 
+def unload_model():
+    return _load_model.cache_clear()
+
+
 @torch.inference_mode()
 def decode(codes: Tensor, device="cuda"):
     """
     Args:
-        codes: (b k t)
+        codes: (b q t)
     """
     assert codes.dim() == 3
     model = _load_model(device)
     return model.decode([(codes, None)]), model.sample_rate
 
 
-def replace_file_extension(path, suffix):
+def decode_to_file(resps: Tensor, path: Path):
+    assert resps.dim() == 2, f"Require shape (t q), but got {resps.shape}."
+    resps = rearrange(resps, "t q -> 1 q t")
+    wavs, sr = decode(resps)
+    soundfile.write(str(path), wavs.cpu()[0, 0], sr)
+
+
+def _replace_file_extension(path, suffix):
     return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
 
 
@@ -46,7 +62,7 @@ def encode(wav, sr, device="cuda"):
     wav = convert_audio(wav, sr, model.sample_rate, model.channels)
     wav = wav.to(device)
     encoded_frames = model.encode(wav)
-    qnt = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)  # (b k t)
+    qnt = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)  # (b q t)
     return qnt
 
 
@@ -59,7 +75,7 @@ def main():
     paths = [*args.folder.rglob(f"*{args.suffix}")]
 
     for path in tqdm(paths):
-        out_path = replace_file_extension(path, ".qnt.pt")
+        out_path = _replace_file_extension(path, ".qnt.pt")
         wav, sr = torchaudio.load(path)
         if wav.shape[0] == 2:
             wav = wav[:1]
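
With `decode_to_file` in place, a quantize/reconstruct round trip mirrors `main()` above. This is a sketch: the input path is a placeholder, a CUDA device is assumed by the defaults, and `encode` is called the same way `main()` calls it:

from pathlib import Path

import torchaudio

from vall_e.emb.qnt import decode_to_file, encode

wav, sr = torchaudio.load("data/test/test.wav")  # placeholder path
if wav.shape[0] == 2:
    wav = wav[:1]  # mono, as in main() above
codes = encode(wav, sr)  # (b q t); 6 kbps on the 24 kHz model gives q = 8
decode_to_file(codes[0].t(), Path("test.recon.wav"))  # decode_to_file expects (t q)
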
diff --git a/vall_e/sampler.py b/vall_e/sampler.py
new file mode 100644
index 0000000..b979ebb
--- /dev/null
+++ b/vall_e/sampler.py
@@ -0,0 +1,48 @@
+"""
+A sampler that balances data by key_fns.
+
+MIT License
+
+Copyright (c) 2023 Zhe Niu
+
+niuzhe.nz@outlook.com
+"""
+
+import random
+
+
+class Sampler:
+    def __init__(self, l, key_fns):
+        self.tree = self._build(l, key_fns)
+
+    def _build(self, l, key_fns) -> dict | list:
+        if not key_fns:
+            return l
+
+        tree = {}
+
+        key_fn, *key_fns = key_fns
+
+        for x in l:
+            k = key_fn(x)
+
+            if k in tree:
+                tree[k].append(x)
+            else:
+                tree[k] = [x]
+
+        for k in tree:
+            tree[k] = self._build(tree[k], key_fns)
+
+        return tree
+
+    def _sample(self, tree: dict | list):
+        if isinstance(tree, list):
+            ret = random.choice(tree)
+        else:
+            key = random.choice([*tree.keys()])
+            ret = self._sample(tree[key])
+        return ret
+
+    def sample(self):
+        return self._sample(self.tree)
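
The tree sampler picks uniformly among keys at each level before picking an item within the chosen key, so sparsely represented keys (here, rare speakers) are drawn as often as dense ones. A tiny self-contained demonstration:

from collections import Counter

from vall_e.sampler import Sampler

# Speaker "a" owns 99 clips, "b" owns 1; uniform draws over clips would pick "b" ~1% of the time.
data = [("a", i) for i in range(99)] + [("b", 0)]
sampler = Sampler(data, [lambda x: x[0]])
counts = Counter(sampler.sample()[0] for _ in range(10_000))
print(counts)  # roughly 5000 each: speaker first, then a clip within the speaker
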
diff --git a/vall_e/train.py b/vall_e/train.py
new file mode 100644
index 0000000..a5a721f
--- /dev/null
+++ b/vall_e/train.py
@@ -0,0 +1,136 @@
+import json
+import logging
+from collections import defaultdict
+
+import torch
+from tqdm import tqdm
+
+from .config import cfg
+from .data import create_train_val_dataloader
+from .emb import qnt
+from .utils import setup_logging, to_device, trainer
+from .vall_e import AR, NAR
+
+_logger = logging.getLogger(__name__)
+
+
+def load_engines():
+    if cfg.model.lower() == "ar":
+        model = AR(
+            cfg.num_tokens,
+            cfg.d_model,
+            cfg.n_heads,
+            cfg.n_layers,
+            cfg.p_dropout,
+        )
+    elif cfg.model.lower() == "nar":
+        model = NAR(
+            cfg.num_tokens,
+            cfg.d_model,
+            cfg.n_heads,
+            cfg.n_layers,
+            cfg.p_dropout,
+        )
+    else:
+        raise NotImplementedError(cfg.model)
+
+    engines = dict(
+        model=trainer.Engine(
+            model=model,
+            config=cfg.ds_cfg,
+        ),
+    )
+
+    return trainer.load_engines(engines, cfg)
+
+
+def main():
+    setup_logging(cfg.log_dir)
+
+    train_dl, train200_dl, val_dl, test_dl = create_train_val_dataloader()
+
+    def train_feeder(engines, batch, name):
+        model = engines["model"]
+
+        if cfg.model == "ar":
+            _ = model(
+                text_list=batch["text"],
+                proms_list=batch["proms"],
+                resp_list=batch["resp"],
+            )
+        elif cfg.model == "nar":
+            _ = model(
+                text_list=batch["text"],
+                proms_list=batch["proms"],
+                resps_list=batch["resps"],
+            )
+
+        losses = model.gather_attribute("loss")
+
+        loss = torch.stack([*losses.values()]).sum()
+
+        stats = {}
+        stats |= {k: v.item() for k, v in losses.items()}
+        stats |= engines.gather_attribute("scalar")
+
+        return loss, stats
+
+    @torch.inference_mode()
+    def run_eval(engines, name, dl):
+        log_dir = cfg.log_dir / str(engines.global_step) / name
+
+        model = engines["model"]
+        stats = defaultdict(list)
+        for batch in tqdm(dl):
+            batch: dict
+            batch = to_device(batch, cfg.device)
+
+            if cfg.model == "ar":
+                resp_list = model(text_list=batch["text"], proms_list=batch["proms"])
+                resps_list = [r.unsqueeze(-1) for r in resp_list]
+            elif cfg.model == "nar":
+                resps_list = model(
+                    text_list=batch["text"],
+                    proms_list=batch["proms"],
+                    resp_list=batch["resp"],
+                )
+            else:
+                raise NotImplementedError(cfg.model)
+
+            losses = model.gather_attribute("loss")
+            batch_stats = {k: v.item() for k, v in losses.items()}
+            for k, v in batch_stats.items():
+                stats[k].append(v)
+
+            for path, ref, hyp in zip(batch["path"], batch["resps"], resps_list):
+                relpath = path.relative_to(cfg.data_root)
+                hyp_path = (log_dir / "hyp" / relpath).with_suffix(".wav")
+                ref_path = (log_dir / "ref" / relpath).with_suffix(".wav")
+                hyp_path.parent.mkdir(parents=True, exist_ok=True)
+                ref_path.parent.mkdir(parents=True, exist_ok=True)
+                qnt.decode_to_file(ref, ref_path)
+                if len(hyp) > 0:
+                    qnt.decode_to_file(hyp, hyp_path)
+
+        stats = {k: sum(v) / len(v) for k, v in stats.items()}
+        stats["global_step"] = engines.global_step
+        stats["name"] = name
+        _logger.info(f"Eval: {stats}.")
+
+        _logger.info(f"{json.dumps(stats)}.")
+
+    def eval_fn(engines):
+        run_eval(engines, "train200", train200_dl)
+        run_eval(engines, "val", val_dl)
+        run_eval(engines, "test", test_dl)
+
+    trainer.train(
+        engines_loader=load_engines,
+        train_dl=train_dl,
+        train_feeder=train_feeder,
+        eval_fn=eval_fn,
+    )
+
+
+if __name__ == "__main__":
+    main()
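
`run_eval` also hints at how the two models compose at inference time: the AR model emits only the first quantizer level, shape (t,), which is lifted so the NAR model can predict the remaining levels. A sketch of that chain under the call signatures used in `run_eval`; model construction and checkpoint loading are elided, and the assumption that the NAR returns full (t q) sequences follows from how its outputs are decoded above:

from pathlib import Path

import torch

from vall_e.emb import qnt

@torch.inference_mode()
def synthesize(ar, nar, text, proms, out_path=Path("out.wav")):
    # AR: level-0 codes of shape (t,); NAR fills in the remaining quantizer levels.
    resp = ar(text_list=[text], proms_list=[proms])[0]
    resps = nar(text_list=[text], proms_list=[proms], resp_list=[resp])[0]  # (t q)
    qnt.decode_to_file(resps, out_path)

Training itself is presumably launched by pointing `python -m vall_e.train` at one of the new YAML configs; the exact CLI contract lives in the unshown `ConfigBase.from_cli`.
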