Add trainer
This commit is contained in:
parent
d19449f1f1
commit
c3bacebfab
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -2,3 +2,4 @@ __pycache__
|
||||||
/data
|
/data
|
||||||
/logs
|
/logs
|
||||||
/ckpts
|
/ckpts
|
||||||
|
/.cache
|
||||||
|
|
|
@ -2,9 +2,12 @@
|
||||||
|
|
||||||
An unofficial (toy) implementation of VALL-E, based on the [encodec](https://github.com/facebookresearch/encodec) tokenizer.
|
An unofficial (toy) implementation of VALL-E, based on the [encodec](https://github.com/facebookresearch/encodec) tokenizer.
|
||||||
|
|
||||||
|
[Buy Me a Coffee](https://www.buymeacoffee.com/enhuiz)
|
||||||
|
|
||||||
## TODO
|
## TODO
|
||||||
|
|
||||||
- [x] AR model for the first quantizer.
|
- [x] AR model for the first quantizer.
|
||||||
- [x] Audio decoding from tokens.
|
- [x] Audio decoding from tokens.
|
||||||
- [x] NAR model for the rest quantizers.
|
- [x] NAR model for the rest quantizers.
|
||||||
- [ ] Trainers for both models.
|
- [x] Trainers for both models.
|
||||||
|
- [ ] Pre-trained checkpoint.
|
||||||
|
|
4
config/ar.yml
Normal file
4
config/ar.yml
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
data_dirs: [data/test]
|
||||||
|
|
||||||
|
model: ar
|
||||||
|
batch_size: 1
|
4
config/nar.yml
Normal file
4
config/nar.yml
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
data_dirs: [data/test]
|
||||||
|
|
||||||
|
model: nar
|
||||||
|
batch_size: 1
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
77
vall_e/config.py
Normal file
77
vall_e/config.py
Normal file
|
@ -0,0 +1,77 @@
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from functools import cached_property
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import diskcache
|
||||||
|
|
||||||
|
from .utils import Config as ConfigBase
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class Config(ConfigBase):
    """Settings shared by the data pipeline, the models and the trainer.

    Extends ``ConfigBase`` (imported from ``.utils``, not shown here), which
    is expected to provide ``from_cli`` and ``relpath``, both used in this
    file.
    """

    # --- Data locations -------------------------------------------------
    data_root: Path = Path("data")
    data_dirs: list[Path] = field(default_factory=list)
    test_data_dirs: list[Path] = field(default_factory=list)

    # --- Data loading ---------------------------------------------------
    # NOTE: these two fields used to be declared twice (24/12 and later
    # 128/512). Python keeps the last assignment as the default, so 128/512
    # were the effective values; they are now declared once, in their
    # original field position, with those effective defaults.
    batch_size: int = 128
    eval_batch_size: int = 512
    nj: int = 8  # passed as DataLoader ``num_workers``

    @property
    def sample_rate(self):
        """Audio sample rate in Hz (fixed)."""
        return 24_000

    # Probability of appending one more prompt utterance while sampling
    # prompts for a speaker.
    p_additional_prompt: float = 0.5

    token_dim: int = 256
    num_tokens: int = 1024

    # --- Optimisation schedule (consumed by ``ds_cfg``) -----------------
    warmup_min_lr: float = 1e-6
    warmup_max_lr: float = 2e-4
    dis_warmup_max_lr: float = 4e-4
    warmup_num_steps: int = 1_000
    max_iter: int = 10_000
    gradient_clipping: float = 100
    eval_every: int = 2_000
    save_ckpt_every: int = 10_000

    # --- Model ----------------------------------------------------------
    model: str = "ar"  # "ar" or "nar"
    d_model: int = 512
    n_heads: int = 8
    n_layers: int = 12
    p_dropout: float = 0.1

    @property
    def ds_cfg(self):
        """Engine configuration dict built from the fields above.

        NOTE(review): the key names look like a DeepSpeed config - confirm
        against ``trainer.Engine``.
        """
        return {
            "train_micro_batch_size_per_gpu": self.batch_size,
            "gradient_accumulation_steps": 1,
            "optimizer": {"type": "Adam"},
            "scheduler": {
                "type": "WarmupDecayLR",
                "params": {
                    "warmup_min_lr": self.warmup_min_lr,
                    "warmup_max_lr": self.warmup_max_lr,
                    "warmup_num_steps": self.warmup_num_steps,
                    "total_num_steps": self.max_iter,
                    "warmup_type": "linear",
                },
            },
            "gradient_clipping": self.gradient_clipping,
        }

    @property
    def cache_dir(self):
        """Directory of the on-disk cache: ``.cache/<relpath>``."""
        # Wrapping in Path makes this work whether relpath is a str or a Path.
        return Path(".cache") / self.relpath

    @cached_property
    def diskcache(self):
        """``memoize`` method of a ``diskcache.Cache`` rooted at ``cache_dir``.

        Use as ``@cfg.diskcache()`` to obtain a memoizing decorator.
        """
        # Inside a method, ``diskcache`` resolves to the imported module,
        # not to this property.
        return diskcache.Cache(self.cache_dir).memoize
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level configuration singleton, built once at import time by
# ``from_cli`` (expected to come from ConfigBase in .utils, not shown here).
cfg = Config.from_cli()

if __name__ == "__main__":
    # Running this module directly prints the resolved configuration.
    print(cfg)
|
294
vall_e/data.py
Normal file
294
vall_e/data.py
Normal file
|
@ -0,0 +1,294 @@
|
||||||
|
import copy
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
from collections import defaultdict
|
||||||
|
from functools import cache, cached_property
|
||||||
|
from itertools import groupby, zip_longest
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from .config import cfg
|
||||||
|
from .sampler import Sampler
|
||||||
|
|
||||||
|
torch.multiprocessing.set_sharing_strategy("file_system")
|
||||||
|
|
||||||
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _replace_file_extension(path, suffix):
|
||||||
|
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_quant_path(path):
    """Return the ``.qnt.pt`` file path that corresponds to `path`."""
    quant_suffix = ".qnt.pt"
    return _replace_file_extension(path, quant_suffix)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_quants(path) -> Tensor:
    """
    Load the quantized codes stored in the ``.qnt.pt`` file matching `path`.

    Returns:
        quants: (t q)
    """
    quant_path = _get_quant_path(path)
    # NOTE(review): torch.load unpickles the file; only load trusted data.
    stored = torch.load(quant_path)
    # Drop the leading dimension, then transpose to (t q).
    return stored[0].t()
|
||||||
|
|
||||||
|
|
||||||
|
@cache
def _get_phones(path):
    """Read the ``.phn.txt`` file matching `path` as a list of phone tokens.

    The whitespace-separated tokens are wrapped in ``<s>`` / ``</s>``
    markers. Results are memoized per path, so callers share one list
    object and must not mutate it.
    """
    phone_path = _replace_file_extension(path, ".phn.txt")
    with open(phone_path, "r", encoding="utf8") as f:
        tokens = f.read().split()
    return ["<s>", *tokens, "</s>"]
|
||||||
|
|
||||||
|
|
||||||
|
def _interleaved_reorder(l, fn):
|
||||||
|
groups = defaultdict(list)
|
||||||
|
for e in l:
|
||||||
|
groups[fn(e)].append(e)
|
||||||
|
groups = {k: groups[k] for k in sorted(groups)}
|
||||||
|
for interleaved in zip_longest(*groups.values()):
|
||||||
|
for value in interleaved:
|
||||||
|
if value is not None:
|
||||||
|
yield value
|
||||||
|
|
||||||
|
|
||||||
|
@cache
def _validate(path, min_phones, max_phones):
    """Decide whether the utterance at `path` is usable.

    An utterance is rejected when its transcript is empty, consists only of
    the ``_`` placeholder, or has fewer than `min_phones` / more than
    `max_phones` tokens.

    Args:
        path: path whose matching ``.phn.txt`` file is inspected.
        min_phones: minimum token count, boundary markers included.
        max_phones: maximum token count, boundary markers included.

    Returns:
        True if the utterance passes all checks.
    """
    phones = _get_phones(path)

    # _get_phones always wraps the transcript in <s> ... </s>. The content
    # checks must ignore those markers, otherwise they can never trigger.
    spoken = {p for p in phones if p not in ("<s>", "</s>")}
    if not spoken:
        return False
    if spoken == {"_"}:
        return False

    # Length limits keep counting the boundary markers, as before.
    return min_phones <= len(phones) <= max_phones
|
||||||
|
|
||||||
|
|
||||||
|
def _get_spkr_name(path) -> str:
|
||||||
|
return path.parts[-2] # spkr/*.wav
|
||||||
|
|
||||||
|
|
||||||
|
class VALLEDatset(Dataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
paths,
|
||||||
|
phone_symmap=None,
|
||||||
|
spkr_symmap=None,
|
||||||
|
min_phones=10,
|
||||||
|
max_phones=100,
|
||||||
|
training=False,
|
||||||
|
extra_paths_by_spkr_name: dict[str, list] = {},
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self._head = None
|
||||||
|
self.min_phones = min_phones
|
||||||
|
self.max_phones = max_phones
|
||||||
|
self.paths = [
|
||||||
|
path for path in paths if _validate(path, self.min_phones, self.max_phones)
|
||||||
|
]
|
||||||
|
self.spkr_symmap = spkr_symmap or self._get_spkr_symmap()
|
||||||
|
self.phone_symmap = phone_symmap or self._get_phone_symmap()
|
||||||
|
self.training = training
|
||||||
|
|
||||||
|
self.paths_by_spkr_name = self._get_paths_by_spkr_name(extra_paths_by_spkr_name)
|
||||||
|
|
||||||
|
if training:
|
||||||
|
self.sampler = Sampler(self.paths, [_get_spkr_name])
|
||||||
|
else:
|
||||||
|
self.sampler = None
|
||||||
|
|
||||||
|
def _get_paths_by_spkr_name(self, extra_paths_by_spkr_name: dict[str, list]):
|
||||||
|
ret = defaultdict(list)
|
||||||
|
for path in self.paths:
|
||||||
|
if _get_quant_path(path).exists():
|
||||||
|
ret[_get_spkr_name(path)].append(path)
|
||||||
|
for k, v in extra_paths_by_spkr_name.items():
|
||||||
|
ret[k].extend(v)
|
||||||
|
return {**ret}
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def phones(self):
|
||||||
|
return sorted(set().union(*[_get_phones(path) for path in self.paths]))
|
||||||
|
|
||||||
|
def _get_phone_symmap(self):
|
||||||
|
# Note that we use phone symmap starting from 1 so that we can safely pad 0.
|
||||||
|
return {s: i for i, s in enumerate(self.phones, 1)}
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def spkrs(self):
|
||||||
|
return sorted({_get_spkr_name(path) for path in self.paths})
|
||||||
|
|
||||||
|
def _get_spkr_symmap(self):
|
||||||
|
return {s: i for i, s in enumerate(self.spkrs)}
|
||||||
|
|
||||||
|
def sample_prompts(self, spkr_name):
|
||||||
|
prom_list = []
|
||||||
|
|
||||||
|
while (
|
||||||
|
len(prom_list) == 0
|
||||||
|
or random.random() < cfg.p_additional_prompt
|
||||||
|
and len(prom_list) < 10
|
||||||
|
):
|
||||||
|
path = random.choice(self.paths_by_spkr_name[spkr_name])
|
||||||
|
prom_list.append(_load_quants(path))
|
||||||
|
|
||||||
|
prom = torch.cat(prom_list)
|
||||||
|
|
||||||
|
return prom
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
if self.training:
|
||||||
|
assert self.sampler is not None
|
||||||
|
path = self.sampler.sample()
|
||||||
|
else:
|
||||||
|
path = self.paths[index]
|
||||||
|
|
||||||
|
spkr_name = _get_spkr_name(path)
|
||||||
|
text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))])
|
||||||
|
proms = self.sample_prompts(spkr_name)
|
||||||
|
resps = _load_quants(path)
|
||||||
|
resp = resps[..., 0]
|
||||||
|
|
||||||
|
return dict(
|
||||||
|
path=path,
|
||||||
|
spkr_name=spkr_name,
|
||||||
|
text=text,
|
||||||
|
proms=proms,
|
||||||
|
resps=resps,
|
||||||
|
resp=resp,
|
||||||
|
)
|
||||||
|
|
||||||
|
def head_(self, n):
|
||||||
|
self._head = n
|
||||||
|
|
||||||
|
def training_(self, value):
|
||||||
|
self.training = value
|
||||||
|
|
||||||
|
def interleaved_reorder_(self, fn):
|
||||||
|
self.paths = [*_interleaved_reorder(self.paths, fn)]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return min(len(self.paths), self._head or len(self.paths))
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn(samples: list[dict]):
    """Turn a list of sample dicts into a dict of lists.

    The keys are taken from the first sample; values are left unpadded and
    unstacked.
    """
    first = samples[0]
    batch: dict[str, Any] = {}
    for key in first:
        batch[key] = [sample[key] for sample in samples]
    return batch
|
||||||
|
|
||||||
|
|
||||||
|
def _seed_worker(worker_id):
|
||||||
|
worker_seed = torch.initial_seed() % 2**32
|
||||||
|
np.random.seed(worker_seed)
|
||||||
|
random.seed(worker_seed)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_dl(dataset, training):
    """Build a DataLoader for `dataset`.

    Args:
        dataset: the dataset to wrap.
        training: selects ``cfg.batch_size`` with shuffling and
            ``drop_last`` (True) or ``cfg.eval_batch_size`` with neither
            (False).

    Returns:
        A ``DataLoader`` using ``collate_fn`` and ``cfg.nj`` workers.
    """
    return DataLoader(
        dataset=dataset,
        batch_size=cfg.batch_size if training else cfg.eval_batch_size,
        shuffle=training,
        drop_last=training,
        num_workers=cfg.nj,
        collate_fn=collate_fn,
        # DataLoader rejects persistent_workers=True when num_workers == 0,
        # so only request it when worker processes are actually used.
        persistent_workers=cfg.nj > 0,
        worker_init_fn=_seed_worker,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _load_train_val_paths():
    """Collect all ``.qnt.pt`` files and split them per speaker into train/val.

    For every speaker the sorted paths are shuffled with a fixed seed; the
    first ``round(0.95 * n)`` go to training and the rest to validation.

    Returns:
        (train_paths, val_paths): two sorted lists of paths.

    Raises:
        RuntimeError: if no ``.qnt.pt`` file exists under ``cfg.data_dirs``.
    """
    paths = []
    for data_dir in cfg.data_dirs:
        # rglob is already recursive; no explicit "**/" prefix is needed.
        paths.extend(tqdm(data_dir.rglob("*.qnt.pt")))

    if not paths:
        raise RuntimeError(f"Failed to find any .qnt.pt file in {cfg.data_dirs}.")

    # groupby needs its input sorted by the grouping key.
    pairs = sorted((_get_spkr_name(p), p) for p in paths)

    train_paths = []
    val_paths = []

    for _, group in groupby(pairs, key=lambda pair: pair[0]):
        spkr_paths = sorted(p for _, p in group)
        # A private RNG seeded with 0 gives the same permutation as
        # ``random.seed(0); random.shuffle(...)`` without resetting the
        # process-wide random state as a side effect.
        random.Random(0).shuffle(spkr_paths)
        n = round(len(spkr_paths) * 0.95)
        train_paths.extend(spkr_paths[:n])
        val_paths.extend(spkr_paths[n:])

    return sorted(train_paths), sorted(val_paths)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_test_paths():
    """Return the sorted ``*.asr.txt`` paths under ``cfg.test_data_dirs``."""
    return sorted(
        path
        for data_dir in cfg.test_data_dirs
        for path in data_dir.rglob("**/*.asr.txt")
    )
|
||||||
|
|
||||||
|
|
||||||
|
@cfg.diskcache()
def create_datasets():
    """Build the train, validation and test datasets.

    Validation and test reuse the training symbol maps and may draw prompts
    from the training utterances. Validation is interleaved by speaker and
    capped at 200 items.

    NOTE(review): the result is memoized through ``cfg.diskcache`` and the
    function takes no arguments - confirm how the cache is invalidated when
    the data on disk changes.

    Returns:
        (train_dataset, val_dataset, test_dataset)
    """
    train_paths, val_paths = _load_train_val_paths()
    test_paths = _load_test_paths()

    train_dataset = VALLEDatset(train_paths, training=True)

    # Everything the evaluation datasets inherit from the training set.
    shared = dict(
        phone_symmap=train_dataset.phone_symmap,
        spkr_symmap=train_dataset.spkr_symmap,
        extra_paths_by_spkr_name=train_dataset.paths_by_spkr_name,
    )

    val_dataset = VALLEDatset(val_paths, **shared)
    val_dataset.interleaved_reorder_(_get_spkr_name)
    val_dataset.head_(200)

    test_dataset = VALLEDatset(test_paths, **shared)

    return train_dataset, val_dataset, test_dataset
|
||||||
|
|
||||||
|
|
||||||
|
def create_train_val_dataloader():
    """Create the data loaders used for training and evaluation.

    Besides the train/val/test loaders, a ``train200`` loader is built from
    a deep copy of the training set: interleaved by speaker, capped at 200
    items and switched to non-training mode.

    Returns:
        (train_dl, train200_dl, val_dl, test_dl)
    """
    train_dataset, val_dataset, test_dataset = create_datasets()

    train_dl = _create_dl(train_dataset, training=True)
    val_dl, test_dl = (
        _create_dl(dataset, training=False) for dataset in (val_dataset, test_dataset)
    )

    for symmap in (train_dataset.phone_symmap, train_dataset.spkr_symmap):
        _logger.info(str(symmap))

    named_datasets = (
        ("train", train_dataset),
        ("val", val_dataset),
        ("test", test_dataset),
    )
    for split, dataset in named_datasets:
        _logger.info(f"#samples ({split}): {len(dataset)}.")

    # Copy first so that the real training set keeps its order and mode.
    train200_dataset = copy.deepcopy(train_dataset)
    train200_dataset.interleaved_reorder_(_get_spkr_name)
    train200_dataset.head_(200)
    train200_dataset.training_(False)
    train200_dl = _create_dl(train200_dataset, training=False)
    assert isinstance(train200_dl.dataset, VALLEDatset)

    return train_dl, train200_dl, val_dl, test_dl
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Smoke test: build the loaders and print the first training sample.
    train_dl, train200_dl, val_dl, test_dl = create_train_val_dataloader()
    print(train_dl.dataset[0])
|
50
vall_e/emb/g2p.py
Normal file
50
vall_e/emb/g2p.py
Normal file
|
@ -0,0 +1,50 @@
|
||||||
|
import argparse
|
||||||
|
import random
|
||||||
|
import string
|
||||||
|
from functools import cache
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from g2p_en import G2p
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
@cache
def _get_model():
    """Create the ``G2p`` converter once and reuse it on later calls."""
    g2p_model = G2p()
    return g2p_model
|
||||||
|
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def _get_graphs(path):
|
||||||
|
with open(path, "r") as f:
|
||||||
|
graphs = f.read()
|
||||||
|
return graphs
|
||||||
|
|
||||||
|
|
||||||
|
def encode(graphs: str) -> list[str]:
    """Convert text to phones, mapping spaces and punctuation to ``_``."""
    ignored = {" ", *string.punctuation}
    converter = _get_model()
    encoded = []
    for phone in converter(graphs):
        encoded.append("_" if phone in ignored else phone)
    return encoded
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def main():
    """Write a ``.phn.txt`` phone file next to every transcript in a folder.

    Command-line arguments:
        folder: directory searched recursively for transcripts.
        --suffix: transcript file suffix (default ``.normalized.txt``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("folder", type=Path)
    parser.add_argument("--suffix", type=str, default=".normalized.txt")
    args = parser.parse_args()

    paths = list(args.folder.rglob(f"*{args.suffix}"))
    random.shuffle(paths)

    for path in tqdm(paths):
        # Keep only the part of the file name before the first dot.
        phone_path = path.with_name(path.stem.split(".")[0] + ".phn.txt")
        graphs = _get_graphs(path)
        phones = encode(graphs)
        # Write UTF-8 explicitly: the data pipeline reads these files back
        # as UTF-8 regardless of the platform's default encoding.
        with open(phone_path, "w", encoding="utf8") as f:
            f.write(" ".join(phones))


if __name__ == "__main__":
    main()
|
|
@ -2,35 +2,51 @@ import argparse
|
||||||
from functools import cache
|
from functools import cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import soundfile
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
|
from einops import rearrange
|
||||||
from encodec import EncodecModel
|
from encodec import EncodecModel
|
||||||
from encodec.utils import convert_audio
|
from encodec.utils import convert_audio
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from ..config import cfg
|
||||||
|
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def _load_model(device="cuda"):
|
def _load_model(device="cuda"):
|
||||||
# Instantiate a pretrained EnCodec model
|
# Instantiate a pretrained EnCodec model
|
||||||
|
assert cfg.sample_rate == 24_000
|
||||||
model = EncodecModel.encodec_model_24khz()
|
model = EncodecModel.encodec_model_24khz()
|
||||||
model.set_target_bandwidth(6.0)
|
model.set_target_bandwidth(6.0)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def unload_model():
    """Clear the cache of ``_load_model`` so cached models are released."""
    result = _load_model.cache_clear()
    return result
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def decode(codes: Tensor, device="cuda"):
|
def decode(codes: Tensor, device="cuda"):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
codes: (b k t)
|
codes: (b q t)
|
||||||
"""
|
"""
|
||||||
assert codes.dim() == 3
|
assert codes.dim() == 3
|
||||||
model = _load_model(device)
|
model = _load_model(device)
|
||||||
return model.decode([(codes, None)]), model.sample_rate
|
return model.decode([(codes, None)]), model.sample_rate
|
||||||
|
|
||||||
|
|
||||||
def replace_file_extension(path, suffix):
|
def decode_to_file(resps: Tensor, path: Path):
    """Decode codes of shape (t q) to audio and write it to `path`.

    Args:
        resps: 2-D tensor of codes, laid out as (t q).
        path: output audio file path.
    """
    assert resps.dim() == 2, f"Require shape (t q), but got {resps.shape}."
    # decode expects (b q t): add a batch axis and swap the two axes.
    codes = rearrange(resps, "t q -> 1 q t")
    wavs, sr = decode(codes)
    first_wav = wavs.cpu()[0, 0]
    soundfile.write(str(path), first_wav, sr)
|
||||||
|
|
||||||
|
|
||||||
|
def _replace_file_extension(path, suffix):
|
||||||
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
|
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
|
||||||
|
|
||||||
|
|
||||||
|
@ -46,7 +62,7 @@ def encode(wav, sr, device="cuda"):
|
||||||
wav = convert_audio(wav, sr, model.sample_rate, model.channels)
|
wav = convert_audio(wav, sr, model.sample_rate, model.channels)
|
||||||
wav = wav.to(device)
|
wav = wav.to(device)
|
||||||
encoded_frames = model.encode(wav)
|
encoded_frames = model.encode(wav)
|
||||||
qnt = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # (b k t)
|
qnt = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # (b q t)
|
||||||
return qnt
|
return qnt
|
||||||
|
|
||||||
|
|
||||||
|
@ -59,7 +75,7 @@ def main():
|
||||||
paths = [*args.folder.rglob(f"*{args.suffix}")]
|
paths = [*args.folder.rglob(f"*{args.suffix}")]
|
||||||
|
|
||||||
for path in tqdm(paths):
|
for path in tqdm(paths):
|
||||||
out_path = replace_file_extension(path, ".qnt.pt")
|
out_path = _replace_file_extension(path, ".qnt.pt")
|
||||||
wav, sr = torchaudio.load(path)
|
wav, sr = torchaudio.load(path)
|
||||||
if wav.shape[0] == 2:
|
if wav.shape[0] == 2:
|
||||||
wav = wav[:1]
|
wav = wav[:1]
|
||||||
|
|
48
vall_e/sampler.py
Normal file
48
vall_e/sampler.py
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
"""
|
||||||
|
A sampler that balances data by key_fns.
|
||||||
|
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2023 Zhe Niu
|
||||||
|
|
||||||
|
niuzhe.nz@outlook.com
|
||||||
|
"""
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
|
class Sampler:
|
||||||
|
def __init__(self, l, key_fns):
|
||||||
|
self.tree = self._build(l, key_fns)
|
||||||
|
|
||||||
|
def _build(self, l, key_fns) -> dict[dict, list]:
|
||||||
|
if not key_fns:
|
||||||
|
return l
|
||||||
|
|
||||||
|
tree = {}
|
||||||
|
|
||||||
|
key_fn, *key_fns = key_fns
|
||||||
|
|
||||||
|
for x in l:
|
||||||
|
k = key_fn(x)
|
||||||
|
|
||||||
|
if k in tree:
|
||||||
|
tree[k].append(x)
|
||||||
|
else:
|
||||||
|
tree[k] = [x]
|
||||||
|
|
||||||
|
for k in tree:
|
||||||
|
tree[k] = self._build(tree[k], key_fns)
|
||||||
|
|
||||||
|
return tree
|
||||||
|
|
||||||
|
def _sample(self, tree: dict | list):
|
||||||
|
if isinstance(tree, list):
|
||||||
|
ret = random.choice(tree)
|
||||||
|
else:
|
||||||
|
key = random.choice([*tree.keys()])
|
||||||
|
ret = self._sample(tree[key])
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def sample(self):
|
||||||
|
return self._sample(self.tree)
|
137
vall_e/train.py
Normal file
137
vall_e/train.py
Normal file
|
@ -0,0 +1,137 @@
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from .config import cfg
|
||||||
|
from .data import create_train_val_dataloader
|
||||||
|
from .emb import qnt
|
||||||
|
from .utils import setup_logging, to_device, trainer
|
||||||
|
from .vall_e import AR, NAR
|
||||||
|
|
||||||
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def load_engines():
    """Build the model chosen by ``cfg.model`` and wrap it in an engine.

    Raises:
        NotImplementedError: if ``cfg.model`` is neither "ar" nor "nar"
            (compared case-insensitively).
    """
    model_classes = {"ar": AR, "nar": NAR}
    model_cls = model_classes.get(cfg.model.lower())
    if model_cls is None:
        raise NotImplementedError(cfg.model)

    # Both models take the same constructor arguments.
    model = model_cls(
        cfg.num_tokens,
        cfg.d_model,
        cfg.n_heads,
        cfg.n_layers,
        cfg.p_dropout,
    )

    engine = trainer.Engine(model=model, config=cfg.ds_cfg)
    return trainer.load_engines(dict(model=engine), cfg)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Train the model selected by ``cfg.model`` with periodic evaluation.

    Sets up logging, builds the data loaders and hands the training and
    evaluation callbacks to ``trainer.train`` (project helper, not shown
    here).
    """
    setup_logging(cfg.log_dir)

    train_dl, train200_dl, val_dl, test_dl = create_train_val_dataloader()

    # load_engines compares the model name case-insensitively; do the same
    # here so that e.g. "AR" is handled consistently everywhere.
    model_name = cfg.model.lower()

    def train_feeder(engines, batch, name):
        """Run one forward pass; return the summed loss and scalar stats."""
        model = engines["model"]

        if model_name == "ar":
            _ = model(
                text_list=batch["text"],
                proms_list=batch["proms"],
                resp_list=batch["resp"],
            )
        elif model_name == "nar":
            _ = model(
                text_list=batch["text"],
                proms_list=batch["proms"],
                resps_list=batch["resps"],
            )
        else:
            # Previously an unknown name skipped the forward pass silently.
            raise NotImplementedError(cfg.model)

        losses = model.gather_attribute("loss")

        loss = torch.stack([*losses.values()]).sum()

        stats = {}
        stats |= {k: v.item() for k, v in losses.items()}
        stats |= engines.gather_attribute("scalar")

        return loss, stats

    @torch.inference_mode()
    def run_eval(engines, name, dl):
        """Evaluate on `dl`, write ref/hyp audio and log the mean losses."""
        log_dir = cfg.log_dir / str(engines.global_step) / name

        model = engines["model"]
        stats = defaultdict(list)
        for batch in tqdm(dl):
            batch: dict
            batch = to_device(batch, cfg.device)

            if model_name == "ar":
                resp_list = model(text_list=batch["text"], proms_list=batch["proms"])
                # Add a trailing axis so the output has the (t q) layout
                # that decode_to_file requires.
                resps_list = [r.unsqueeze(-1) for r in resp_list]
            elif model_name == "nar":
                resps_list = model(
                    text_list=batch["text"],
                    proms_list=batch["proms"],
                    resp_list=batch["resp"],
                )
            else:
                raise NotImplementedError(cfg.model)

            losses = model.gather_attribute("loss")
            batch_stats = {k: v.item() for k, v in losses.items()}
            for k, v in batch_stats.items():
                stats[k].append(v)

            for path, ref, hyp in zip(batch["path"], batch["resps"], resps_list):
                relpath = path.relative_to(cfg.data_root)
                hyp_path = (log_dir / "hyp" / relpath).with_suffix(".wav")
                ref_path = (log_dir / "ref" / relpath).with_suffix(".wav")
                hyp_path.parent.mkdir(parents=True, exist_ok=True)
                ref_path.parent.mkdir(parents=True, exist_ok=True)
                qnt.decode_to_file(ref, ref_path)
                # Skip empty hypotheses.
                if len(hyp) > 0:
                    qnt.decode_to_file(hyp, hyp_path)

        # Average every loss over the evaluated batches.
        stats = {k: sum(v) / len(v) for k, v in stats.items()}
        stats["global_step"] = engines.global_step
        stats["name"] = name
        _logger.info(f"Eval: {stats}.")

        _logger.info(f"{json.dumps(stats)}.")

    def eval_fn(engines):
        """Evaluate on the train200, validation and test loaders in turn."""
        run_eval(engines, "train200", train200_dl)
        run_eval(engines, "val", val_dl)
        run_eval(engines, "test", test_dl)

    trainer.train(
        engines_loader=load_engines,
        train_dl=train_dl,
        train_feeder=train_feeder,
        eval_fn=eval_fn,
    )


if __name__ == "__main__":
    main()
|
Loading…
Reference in New Issue
Block a user