Add trainer

Author: enhuiz
Date: 2023-01-12 14:41:44 +08:00
Parent: d19449f1f1
Commit: c3bacebfab
13 changed files with 639 additions and 5 deletions

.gitignore (1 line changed)

@@ -2,3 +2,4 @@ __pycache__
 /data
 /logs
 /ckpts
+/.cache

README.md

@@ -2,9 +2,12 @@
 An unofficial (toy) implementation of VALL-E, based on the [encodec](https://github.com/facebookresearch/encodec) tokenizer.
 [!["Buy Me A Coffee"](https://www.buymeacoffee.com/assets/img/custom_images/orange_img.png)](https://www.buymeacoffee.com/enhuiz)
 ## TODO
 - [x] AR model for the first quantizer.
 - [x] Audio decoding from tokens.
 - [x] NAR model for the remaining quantizers.
-- [ ] Trainers for both models.
+- [x] Trainers for both models.
 - [ ] Pre-trained checkpoint.

config/ar.yml (new file, 4 lines)

@@ -0,0 +1,4 @@
data_dirs: [data/test]
model: ar
batch_size: 1

config/nar.yml (new file, 4 lines)

@@ -0,0 +1,4 @@
data_dirs: [data/test]
model: nar
batch_size: 1
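The two configs differ only in `model`; every field not listed falls back to the defaults in vall_e/config.py below. A minimal sketch of the YAML-over-dataclass layering this implies (the real loader is `ConfigBase.from_cli` in vall_e/utils, which this commit does not include, so the loading logic here is an assumption):

# Hypothetical illustration of YAML values overriding frozen-dataclass defaults;
# the actual ConfigBase.from_cli may differ.
from dataclasses import dataclass, fields, replace

import yaml  # pip install pyyaml

@dataclass(frozen=True)
class TinyConfig:
    model: str = "ar"
    batch_size: int = 128

def load_config(path: str) -> TinyConfig:
    with open(path) as f:
        overrides = yaml.safe_load(f) or {}
    known = {f.name for f in fields(TinyConfig)}
    return replace(TinyConfig(), **{k: v for k, v in overrides.items() if k in known})

# load_config("config/nar.yml") -> TinyConfig(model='nar', batch_size=1)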

3 binary files not shown.

vall_e/config.py (new file, 77 lines)

@@ -0,0 +1,77 @@
from dataclasses import dataclass, field
from functools import cached_property
from pathlib import Path
import diskcache
from .utils import Config as ConfigBase
@dataclass(frozen=True)
class Config(ConfigBase):
data_root: Path = Path("data")
data_dirs: list[Path] = field(default_factory=lambda: [])
test_data_dirs: list[Path] = field(default_factory=lambda: [])
nj: int = 8  # number of dataloader workers
@property
def sample_rate(self):
return 24_000
p_additional_prompt: float = 0.5
token_dim: int = 256
num_tokens: int = 1024
batch_size: int = 128
eval_batch_size: int = 512
warmup_min_lr: float = 1e-6
warmup_max_lr: float = 2e-4
dis_warmup_max_lr: float = 4e-4
warmup_num_steps: int = 1_000
max_iter: int = 10_000
gradient_clipping: float = 100
eval_every: int = 2_000
save_ckpt_every: int = 10_000
model: str = "ar"
d_model: int = 512
n_heads: int = 8
n_layers: int = 12
p_dropout: float = 0.1
@property
def ds_cfg(self):
return {
"train_micro_batch_size_per_gpu": self.batch_size,
"gradient_accumulation_steps": 1,
"optimizer": {"type": "Adam"},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": self.warmup_min_lr,
"warmup_max_lr": self.warmup_max_lr,
"warmup_num_steps": self.warmup_num_steps,
"total_num_steps": self.max_iter,
"warmup_type": "linear",
},
},
"gradient_clipping": self.gradient_clipping,
}
@property
def cache_dir(self):
return ".cache" / self.relpath
@cached_property
def diskcache(self):
return diskcache.Cache(self.cache_dir).memoize
cfg = Config.from_cli()
if __name__ == "__main__":
print(cfg)
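`ds_cfg` is shaped as a DeepSpeed config dict (consumed by the `trainer.Engine` wrapper in vall_e/train.py), and `diskcache` returns a decorator that memoizes a function's result on disk under the config-specific `.cache` directory; `create_datasets` in vall_e/data.py depends on it. A small self-contained demo of the memoize behavior (the cache directory is illustrative):

import diskcache

cache = diskcache.Cache("/tmp/demo_cache")

@cache.memoize()
def expensive(n: int) -> int:
    print("computing...")  # printed only on a cache miss
    return sum(i * i for i in range(n))

print(expensive(10_000))  # first call computes and stores the result
print(expensive(10_000))  # served from disk, even in a fresh process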

vall_e/data.py (new file, 294 lines)

@@ -0,0 +1,294 @@
import copy
import logging
import random
from collections import defaultdict
from functools import cache, cached_property
from itertools import groupby, zip_longest
from typing import Any
import numpy as np
import torch
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from .config import cfg
from .sampler import Sampler
torch.multiprocessing.set_sharing_strategy("file_system")
_logger = logging.getLogger(__name__)
def _replace_file_extension(path, suffix):
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
def _get_quant_path(path):
return _replace_file_extension(path, ".qnt.pt")
def _load_quants(path) -> Tensor:
"""
Returns:
quants: (t q)
"""
path = _get_quant_path(path)
return torch.load(path)[0].t()
@cache
def _get_phones(path):
path = _replace_file_extension(path, ".phn.txt")
with open(path, "r", encoding="utf8") as f:
content = f.read()
return ["<s>"] + content.split() + ["</s>"]
def _interleaved_reorder(l, fn):
groups = defaultdict(list)
for e in l:
groups[fn(e)].append(e)
groups = {k: groups[k] for k in sorted(groups)}
for interleaved in zip_longest(*groups.values()):
for value in interleaved:
if value is not None:
yield value
@cache
def _validate(path, min_phones, max_phones):
phones = _get_phones(path)
unique_phones = list(set(phones))
if len(unique_phones) == 0:
return False
if len(unique_phones) == 1 and unique_phones[0] == "_":
return False
if len(phones) < min_phones:
return False
if len(phones) > max_phones:
return False
return True
def _get_spkr_name(path) -> str:
return path.parts[-2] # spkr/*.wav
class VALLEDataset(Dataset):
def __init__(
self,
paths,
phone_symmap=None,
spkr_symmap=None,
min_phones=10,
max_phones=100,
training=False,
extra_paths_by_spkr_name: dict[str, list] = {},
):
super().__init__()
self._head = None
self.min_phones = min_phones
self.max_phones = max_phones
self.paths = [
path for path in paths if _validate(path, self.min_phones, self.max_phones)
]
self.spkr_symmap = spkr_symmap or self._get_spkr_symmap()
self.phone_symmap = phone_symmap or self._get_phone_symmap()
self.training = training
self.paths_by_spkr_name = self._get_paths_by_spkr_name(extra_paths_by_spkr_name)
if training:
self.sampler = Sampler(self.paths, [_get_spkr_name])
else:
self.sampler = None
def _get_paths_by_spkr_name(self, extra_paths_by_spkr_name: dict[str, list]):
ret = defaultdict(list)
for path in self.paths:
if _get_quant_path(path).exists():
ret[_get_spkr_name(path)].append(path)
for k, v in extra_paths_by_spkr_name.items():
ret[k].extend(v)
return {**ret}
@cached_property
def phones(self):
return sorted(set().union(*[_get_phones(path) for path in self.paths]))
def _get_phone_symmap(self):
# Note that we use phone symmap starting from 1 so that we can safely pad 0.
return {s: i for i, s in enumerate(self.phones, 1)}
@cached_property
def spkrs(self):
return sorted({_get_spkr_name(path) for path in self.paths})
def _get_spkr_symmap(self):
return {s: i for i, s in enumerate(self.spkrs)}
def sample_prompts(self, spkr_name):
    # Always include at least one prompt; then, with probability
    # cfg.p_additional_prompt per step, append another utterance from
    # the same speaker, capped at 10 prompts in total.
    prom_list = []
    while len(prom_list) == 0 or (
        random.random() < cfg.p_additional_prompt and len(prom_list) < 10
    ):
        path = random.choice(self.paths_by_spkr_name[spkr_name])
        prom_list.append(_load_quants(path))
    prom = torch.cat(prom_list)
    return prom
def __getitem__(self, index):
if self.training:
assert self.sampler is not None
path = self.sampler.sample()
else:
path = self.paths[index]
spkr_name = _get_spkr_name(path)
text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))])
proms = self.sample_prompts(spkr_name)
resps = _load_quants(path)
resp = resps[..., 0]
return dict(
path=path,
spkr_name=spkr_name,
text=text,
proms=proms,
resps=resps,
resp=resp,
)
def head_(self, n):
self._head = n
def training_(self, value):
self.training = value
def interleaved_reorder_(self, fn):
self.paths = [*_interleaved_reorder(self.paths, fn)]
def __len__(self):
return min(len(self.paths), self._head or len(self.paths))
def collate_fn(samples: list[dict]):
batch: dict[str, Any] = {k: [s[k] for s in samples] for k in samples[0]}
return batch
def _seed_worker(worker_id):
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
random.seed(worker_seed)
def _create_dl(dataset, training):
return DataLoader(
dataset=dataset,
batch_size=cfg.batch_size if training else cfg.eval_batch_size,
shuffle=training,
drop_last=training,
num_workers=cfg.nj,
collate_fn=collate_fn,
persistent_workers=True,
worker_init_fn=_seed_worker,
)
def _load_train_val_paths():
paths = []
train_paths = []
val_paths = []
for data_dir in cfg.data_dirs:
paths.extend(tqdm(data_dir.rglob("**/*.qnt.pt")))
if len(paths) == 0:
raise RuntimeError(f"Failed to find any .qnt.pt file in {cfg.data_dirs}.")
pairs = sorted([(_get_spkr_name(p), p) for p in paths])
del paths
for _, group in groupby(pairs, lambda pair: pair[0]):
paths = sorted([p for _, p in group])
random.seed(0)
random.shuffle(paths)
n = round(len(paths) * 0.95)
train_paths.extend(paths[:n])
val_paths.extend(paths[n:])
train_paths, val_paths = map(sorted, [train_paths, val_paths])
return train_paths, val_paths
def _load_test_paths():
test_paths = []
for data_dir in cfg.test_data_dirs:
test_paths.extend(data_dir.rglob("**/*.asr.txt"))
test_paths = sorted(test_paths)
return test_paths
@cfg.diskcache()
def create_datasets():
train_paths, val_paths = _load_train_val_paths()
test_paths = _load_test_paths()
train_dataset = VALLEDataset(train_paths, training=True)
val_dataset = VALLEDataset(
val_paths,
train_dataset.phone_symmap,
train_dataset.spkr_symmap,
extra_paths_by_spkr_name=train_dataset.paths_by_spkr_name,
)
val_dataset.interleaved_reorder_(_get_spkr_name)
val_dataset.head_(200)
test_dataset = VALLEDataset(
test_paths,
train_dataset.phone_symmap,
train_dataset.spkr_symmap,
extra_paths_by_spkr_name=train_dataset.paths_by_spkr_name,
)
return train_dataset, val_dataset, test_dataset
def create_train_val_dataloader():
train_dataset, val_dataset, test_dataset = create_datasets()
train_dl = _create_dl(train_dataset, training=True)
val_dl = _create_dl(val_dataset, training=False)
test_dl = _create_dl(test_dataset, training=False)
_logger.info(str(train_dataset.phone_symmap))
_logger.info(str(train_dataset.spkr_symmap))
_logger.info(f"#samples (train): {len(train_dataset)}.")
_logger.info(f"#samples (val): {len(val_dataset)}.")
_logger.info(f"#samples (test): {len(test_dataset)}.")
train200_dataset = copy.deepcopy(train_dataset)
train200_dataset.interleaved_reorder_(_get_spkr_name)
train200_dataset.head_(200)
train200_dataset.training_(False)
train200_dl = _create_dl(train200_dataset, training=False)
assert isinstance(train200_dl.dataset, VALLEDataset)
return train_dl, train200_dl, val_dl, test_dl
if __name__ == "__main__":
train_dl, train200_dl, val_dl, test_dl = create_train_val_dataloader()
sample = train_dl.dataset[0]
print(sample)
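The `interleaved_reorder_` plus `head_(200)` combination above is what keeps the evaluation subsets speaker-balanced: paths are round-robined across speakers before the head is taken. A standalone replica of the helper for illustration:

from collections import defaultdict
from itertools import zip_longest

def interleaved_reorder(items, key_fn):
    # Group items by key, then emit one item per group in rotation.
    groups = defaultdict(list)
    for item in items:
        groups[key_fn(item)].append(item)
    for interleaved in zip_longest(*(groups[k] for k in sorted(groups))):
        yield from (x for x in interleaved if x is not None)

paths = ["a/1.wav", "a/2.wav", "a/3.wav", "b/1.wav", "c/1.wav", "c/2.wav"]
print([*interleaved_reorder(paths, lambda p: p.split("/")[0])])
# ['a/1.wav', 'b/1.wav', 'c/1.wav', 'a/2.wav', 'c/2.wav', 'a/3.wav']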

vall_e/emb/g2p.py (new file, 50 lines)

@@ -0,0 +1,50 @@
import argparse
import random
import string
from functools import cache
from pathlib import Path
import torch
from g2p_en import G2p
from tqdm import tqdm
@cache
def _get_model():
return G2p()
@cache
def _get_graphs(path):
with open(path, "r") as f:
graphs = f.read()
return graphs
def encode(graphs: str) -> list[str]:
g2p = _get_model()
phones = g2p(graphs)
ignored = {" ", *string.punctuation}
return ["_" if p in ignored else p for p in phones]
@torch.no_grad()
def main():
parser = argparse.ArgumentParser()
parser.add_argument("folder", type=Path)
parser.add_argument("--suffix", type=str, default=".normalized.txt")
args = parser.parse_args()
paths = list(args.folder.rglob(f"*{args.suffix}"))
random.shuffle(paths)
for path in tqdm(paths):
phone_path = path.with_name(path.stem.split(".")[0] + ".phn.txt")
graphs = _get_graphs(path)
phones = encode(graphs)
with open(phone_path, "w") as f:
f.write(" ".join(phones))
if __name__ == "__main__":
main()
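For reference, the script walks a folder for `*.normalized.txt` transcripts and writes a sibling `*.phn.txt` per file, with spaces and punctuation collapsed to `_` placeholders. A usage sketch (the example string is illustrative, and g2p_en's exact ARPAbet output can vary by version):

# Batch mode, from the repo root (suffix defaults to ".normalized.txt"):
#   python -m vall_e.emb.g2p data/test
from vall_e.emb.g2p import encode

print(encode("Hello."))  # e.g. ['HH', 'AH0', 'L', 'OW1', '_']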

vall_e/emb/qnt.py

@@ -2,35 +2,51 @@ import argparse
 from functools import cache
 from pathlib import Path
+import soundfile
 import torch
 import torchaudio
+from einops import rearrange
 from encodec import EncodecModel
 from encodec.utils import convert_audio
 from torch import Tensor
 from tqdm import tqdm
+from ..config import cfg
 @cache
 def _load_model(device="cuda"):
     # Instantiate a pretrained EnCodec model
+    assert cfg.sample_rate == 24_000
     model = EncodecModel.encodec_model_24khz()
     model.set_target_bandwidth(6.0)
     model.to(device)
     return model
+def unload_model():
+    return _load_model.cache_clear()
 @torch.inference_mode()
 def decode(codes: Tensor, device="cuda"):
     """
     Args:
-        codes: (b k t)
+        codes: (b q t)
     """
     assert codes.dim() == 3
     model = _load_model(device)
     return model.decode([(codes, None)]), model.sample_rate
-def replace_file_extension(path, suffix):
+def decode_to_file(resps: Tensor, path: Path):
+    assert resps.dim() == 2, f"Require shape (t q), but got {resps.shape}."
+    resps = rearrange(resps, "t q -> 1 q t")
+    wavs, sr = decode(resps)
+    soundfile.write(str(path), wavs.cpu()[0, 0], sr)
+def _replace_file_extension(path, suffix):
     return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
@ -46,7 +62,7 @@ def encode(wav, sr, device="cuda"):
wav = convert_audio(wav, sr, model.sample_rate, model.channels)
wav = wav.to(device)
encoded_frames = model.encode(wav)
qnt = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # (b k t)
qnt = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # (b q t)
return qnt
@@ -59,7 +75,7 @@ def main():
     paths = [*args.folder.rglob(f"*{args.suffix}")]
     for path in tqdm(paths):
-        out_path = replace_file_extension(path, ".qnt.pt")
+        out_path = _replace_file_extension(path, ".qnt.pt")
         wav, sr = torchaudio.load(path)
         if wav.shape[0] == 2:
             wav = wav[:1]
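`decode_to_file` is what `run_eval` in vall_e/train.py uses to dump reference and hypothesis wavs; it expects codes shaped (t q), matching `_load_quants` in vall_e/data.py. A usage sketch (the input path is illustrative, and decoding runs on "cuda" by default, so a GPU is assumed):

from pathlib import Path
import torch
from vall_e.emb.qnt import decode_to_file

codes = torch.load("data/test/example.qnt.pt")[0].t()  # (t q), as in data._load_quants
decode_to_file(codes.to("cuda"), Path("example.wav"))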

vall_e/sampler.py (new file, 48 lines)

@@ -0,0 +1,48 @@
"""
A sampler that balances data by key_fns.
MIT License
Copyright (c) 2023 Zhe Niu
niuzhe.nz@outlook.com
"""
import random
class Sampler:
def __init__(self, l, key_fns):
self.tree = self._build(l, key_fns)
def _build(self, l, key_fns) -> dict | list:
if not key_fns:
return l
tree = {}
key_fn, *key_fns = key_fns
for x in l:
k = key_fn(x)
if k in tree:
tree[k].append(x)
else:
tree[k] = [x]
for k in tree:
tree[k] = self._build(tree[k], key_fns)
return tree
def _sample(self, tree: dict | list):
if isinstance(tree, list):
ret = random.choice(tree)
else:
key = random.choice([*tree.keys()])
ret = self._sample(tree[key])
return ret
def sample(self):
return self._sample(self.tree)
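The sampler draws a key uniformly at each tree level before drawing an item within the leaf, so with `[_get_spkr_name]` every speaker is equally likely regardless of how many utterances it has. A quick self-contained check:

import random
from vall_e.sampler import Sampler

random.seed(0)
items = ["spk_a/1", "spk_a/2", "spk_a/3", "spk_a/4", "spk_b/1"]
sampler = Sampler(items, [lambda x: x.split("/")[0]])

draws = [sampler.sample() for _ in range(10_000)]
print(sum(d.startswith("spk_b") for d in draws) / len(draws))  # ~0.5, not 1/5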

vall_e/train.py (new file, 137 lines)

@@ -0,0 +1,137 @@
import json
import logging
from collections import defaultdict
import torch
from tqdm import tqdm
from .config import cfg
from .data import create_train_val_dataloader
from .emb import qnt
from .utils import setup_logging, to_device, trainer
from .vall_e import AR, NAR
_logger = logging.getLogger(__name__)
def load_engines():
if cfg.model.lower() == "ar":
model = AR(
cfg.num_tokens,
cfg.d_model,
cfg.n_heads,
cfg.n_layers,
cfg.p_dropout,
)
elif cfg.model.lower() == "nar":
model = NAR(
cfg.num_tokens,
cfg.d_model,
cfg.n_heads,
cfg.n_layers,
cfg.p_dropout,
)
else:
raise NotImplementedError(cfg.model)
engines = dict(
model=trainer.Engine(
model=model,
config=cfg.ds_cfg,
),
)
return trainer.load_engines(engines, cfg)
def main():
setup_logging(cfg.log_dir)
train_dl, train200_dl, val_dl, test_dl = create_train_val_dataloader()
def train_feeder(engines, batch, name):
model = engines["model"]
if cfg.model == "ar":
_ = model(
text_list=batch["text"],
proms_list=batch["proms"],
resp_list=batch["resp"],
)
elif cfg.model == "nar":
_ = model(
text_list=batch["text"],
proms_list=batch["proms"],
resps_list=batch["resps"],
)
losses = model.gather_attribute("loss")
loss = torch.stack([*losses.values()]).sum()
stats = {}
stats |= {k: v.item() for k, v in losses.items()}
stats |= engines.gather_attribute("scalar")
return loss, stats
@torch.inference_mode()
def run_eval(engines, name, dl):
model = engines["model"]
log_dir = cfg.log_dir / str(engines.global_step) / name
stats = defaultdict(list)
for batch in tqdm(dl):
batch: dict
batch = to_device(batch, cfg.device)
if cfg.model == "ar":
resp_list = model(text_list=batch["text"], proms_list=batch["proms"])
resps_list = [r.unsqueeze(-1) for r in resp_list]
elif cfg.model == "nar":
resps_list = model(
text_list=batch["text"],
proms_list=batch["proms"],
resp_list=batch["resp"],
)
else:
raise NotImplementedError(cfg.model)
losses = model.gather_attribute("loss")
batch_stats = {k: v.item() for k, v in losses.items()}
for k, v in batch_stats.items():
stats[k].append(v)
for path, ref, hyp in zip(batch["path"], batch["resps"], resps_list):
relpath = path.relative_to(cfg.data_root)
hyp_path = (log_dir / "hyp" / relpath).with_suffix(".wav")
ref_path = (log_dir / "ref" / relpath).with_suffix(".wav")
hyp_path.parent.mkdir(parents=True, exist_ok=True)
ref_path.parent.mkdir(parents=True, exist_ok=True)
qnt.decode_to_file(ref, ref_path)
if len(hyp) > 0:
qnt.decode_to_file(hyp, hyp_path)
stats = {k: sum(v) / len(v) for k, v in stats.items()}
stats["global_step"] = engines.global_step
stats["name"] = name
_logger.info(f"Eval: {stats}.")
_logger.info(f"{json.dumps(stats)}.")
def eval_fn(engines):
run_eval(engines, "train200", train200_dl)
run_eval(engines, "val", val_dl)
run_eval(engines, "test", test_dl)
trainer.train(
engines_loader=load_engines,
train_dl=train_dl,
train_feeder=train_feeder,
eval_fn=eval_fn,
)
if __name__ == "__main__":
main()
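Each stage is trained separately by pointing this entry point at the AR or NAR config. The `yaml=` override syntax below is an assumption about `ConfigBase.from_cli` (defined in vall_e/utils, which this diff does not include):

python -m vall_e.train yaml=config/ar.yml
python -m vall_e.train yaml=config/nar.yml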