Add trainer
This commit is contained in:
parent
d19449f1f1
commit
c3bacebfab
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -2,3 +2,4 @@ __pycache__
|
|||
/data
|
||||
/logs
|
||||
/ckpts
|
||||
/.cache
|
||||
|
|
|
@ -2,9 +2,12 @@
|
|||
|
||||
An unofficial (toy) implementation of VALL-E, based on the [encodec](https://github.com/facebookresearch/encodec) tokenizer.
|
||||
|
||||
[Buy Me a Coffee](https://www.buymeacoffee.com/enhuiz)
|
||||
|
||||
## TODO
|
||||
|
||||
- [x] AR model for the first quantizer.
|
||||
- [x] Audio decoding from tokens.
|
||||
- [x] NAR model for the remaining quantizers.
|
||||
- [ ] Trainers for both models.
|
||||
- [x] Trainers for both models.
|
||||
- [ ] Pre-trained checkpoint.
|
||||
|
|
4
config/ar.yml
Normal file
4
config/ar.yml
Normal file
|
@ -0,0 +1,4 @@
|
|||
data_dirs: [data/test]
|
||||
|
||||
model: ar
|
||||
batch_size: 1
|
4
config/nar.yml
Normal file
4
config/nar.yml
Normal file
|
@ -0,0 +1,4 @@
|
|||
data_dirs: [data/test]
|
||||
|
||||
model: nar
|
||||
batch_size: 1
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
77
vall_e/config.py
Normal file
77
vall_e/config.py
Normal file
|
@ -0,0 +1,77 @@
|
|||
from dataclasses import dataclass, field
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
|
||||
import diskcache
|
||||
|
||||
from .utils import Config as ConfigBase
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Config(ConfigBase):
    """Experiment configuration: data locations, model size and training schedule.

    Extends ``ConfigBase`` from ``.utils``. ``self.relpath`` (read by
    ``cache_dir``) is not defined in this class and is presumably supplied by
    the base class -- TODO confirm.
    """

    # Data locations.
    data_root: Path = Path("data")
    data_dirs: list[Path] = field(default_factory=list)
    test_data_dirs: list[Path] = field(default_factory=list)

    # Batch sizes. These two fields used to be declared twice in this class
    # (24 / 12 here, then 128 / 512 further down). Only the later defaults took
    # effect, so they are kept as the single declaration, at the original
    # field position.
    batch_size: int = 128
    eval_batch_size: int = 512
    nj: int = 8

    @property
    def sample_rate(self):
        """Audio sample rate in Hz (fixed at 24 kHz)."""
        return 24_000

    # Probability of appending one more prompt utterance while sampling prompts.
    p_additional_prompt: float = 0.5

    token_dim: int = 256
    num_tokens: int = 1024

    # Learning-rate schedule and training loop settings (see ``ds_cfg``).
    warmup_min_lr: float = 1e-6
    warmup_max_lr: float = 2e-4
    dis_warmup_max_lr: float = 4e-4
    warmup_num_steps: int = 1_000
    max_iter: int = 10_000
    gradient_clipping: float = 100
    eval_every: int = 2_000
    save_ckpt_every: int = 10_000

    # Model selection ("ar" or "nar") and transformer hyper-parameters.
    model: str = "ar"
    d_model: int = 512
    n_heads: int = 8
    n_layers: int = 12
    p_dropout: float = 0.1

    @property
    def ds_cfg(self):
        """Engine configuration dict built from the training fields above.

        The key names look like a DeepSpeed config -- TODO confirm against the
        trainer that consumes it.
        """
        return {
            "train_micro_batch_size_per_gpu": self.batch_size,
            "gradient_accumulation_steps": 1,
            "optimizer": {"type": "Adam"},
            "scheduler": {
                "type": "WarmupDecayLR",
                "params": {
                    "warmup_min_lr": self.warmup_min_lr,
                    "warmup_max_lr": self.warmup_max_lr,
                    "warmup_num_steps": self.warmup_num_steps,
                    "total_num_steps": self.max_iter,
                    "warmup_type": "linear",
                },
            },
            "gradient_clipping": self.gradient_clipping,
        }

    @property
    def cache_dir(self):
        """Directory of the on-disk cache: ``.cache/<relpath>``."""
        return ".cache" / self.relpath

    @cached_property
    def diskcache(self):
        """``memoize`` decorator factory of a ``diskcache.Cache`` at ``cache_dir``."""
        return diskcache.Cache(self.cache_dir).memoize
|
||||
|
||||
|
||||
cfg = Config.from_cli()
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(cfg)
|
294
vall_e/data.py
Normal file
294
vall_e/data.py
Normal file
|
@ -0,0 +1,294 @@
|
|||
import copy
|
||||
import logging
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from functools import cache, cached_property
|
||||
from itertools import groupby, zip_longest
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from .config import cfg
|
||||
from .sampler import Sampler
|
||||
|
||||
torch.multiprocessing.set_sharing_strategy("file_system")
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _replace_file_extension(path, suffix):
|
||||
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
|
||||
|
||||
|
||||
def _get_quant_path(path):
    """Return the ``.qnt.pt`` file that sits next to ``path``."""
    quant_suffix = ".qnt.pt"
    return _replace_file_extension(path, quant_suffix)
|
||||
|
||||
|
||||
def _load_quants(path) -> Tensor:
    """Load the quantized codes stored in the ``.qnt.pt`` file next to ``path``.

    The first element of the loaded object is taken and transposed.

    Returns:
        quants: (t q)
    """
    quant_path = _get_quant_path(path)
    loaded = torch.load(quant_path)
    first_item = loaded[0]
    return first_item.t()
|
||||
|
||||
|
||||
@cache
def _get_phones(path):
    """Read the whitespace-separated phones from the ``.phn.txt`` next to ``path``.

    The result is wrapped in ``<s>`` / ``</s>`` markers and memoized per path.
    """
    phone_path = _replace_file_extension(path, ".phn.txt")
    with open(phone_path, "r", encoding="utf8") as f:
        tokens = f.read().split()
    return ["<s>", *tokens, "</s>"]
|
||||
|
||||
|
||||
def _interleaved_reorder(l, fn):
|
||||
groups = defaultdict(list)
|
||||
for e in l:
|
||||
groups[fn(e)].append(e)
|
||||
groups = {k: groups[k] for k in sorted(groups)}
|
||||
for interleaved in zip_longest(*groups.values()):
|
||||
for value in interleaved:
|
||||
if value is not None:
|
||||
yield value
|
||||
|
||||
|
||||
@cache
def _validate(path, min_phones, max_phones):
    """Decide whether the utterance at ``path`` is usable.

    An utterance is rejected when it has no phones, when it consists only of
    the ``"_"`` symbol, or when its phone count (including the ``<s>`` /
    ``</s>`` markers added by ``_get_phones``) is outside
    ``[min_phones, max_phones]``.

    Args:
        path: utterance path; its ``.phn.txt`` sibling is read.
        min_phones: minimum accepted sequence length.
        max_phones: maximum accepted sequence length.

    Returns:
        True if the utterance passes all checks, False otherwise.
    """
    phones = _get_phones(path)

    # ``_get_phones`` always adds the sequence markers, so they must be
    # excluded here; otherwise the "empty" and "only '_'" checks never fire.
    content = set(phones) - {"<s>", "</s>"}
    if not content:
        return False
    if content == {"_"}:
        return False

    return min_phones <= len(phones) <= max_phones
|
||||
|
||||
|
||||
def _get_spkr_name(path) -> str:
|
||||
return path.parts[-2] # spkr/*.wav
|
||||
|
||||
|
||||
class VALLEDatset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
paths,
|
||||
phone_symmap=None,
|
||||
spkr_symmap=None,
|
||||
min_phones=10,
|
||||
max_phones=100,
|
||||
training=False,
|
||||
extra_paths_by_spkr_name: dict[str, list] = {},
|
||||
):
|
||||
super().__init__()
|
||||
self._head = None
|
||||
self.min_phones = min_phones
|
||||
self.max_phones = max_phones
|
||||
self.paths = [
|
||||
path for path in paths if _validate(path, self.min_phones, self.max_phones)
|
||||
]
|
||||
self.spkr_symmap = spkr_symmap or self._get_spkr_symmap()
|
||||
self.phone_symmap = phone_symmap or self._get_phone_symmap()
|
||||
self.training = training
|
||||
|
||||
self.paths_by_spkr_name = self._get_paths_by_spkr_name(extra_paths_by_spkr_name)
|
||||
|
||||
if training:
|
||||
self.sampler = Sampler(self.paths, [_get_spkr_name])
|
||||
else:
|
||||
self.sampler = None
|
||||
|
||||
def _get_paths_by_spkr_name(self, extra_paths_by_spkr_name: dict[str, list]):
|
||||
ret = defaultdict(list)
|
||||
for path in self.paths:
|
||||
if _get_quant_path(path).exists():
|
||||
ret[_get_spkr_name(path)].append(path)
|
||||
for k, v in extra_paths_by_spkr_name.items():
|
||||
ret[k].extend(v)
|
||||
return {**ret}
|
||||
|
||||
@cached_property
|
||||
def phones(self):
|
||||
return sorted(set().union(*[_get_phones(path) for path in self.paths]))
|
||||
|
||||
def _get_phone_symmap(self):
|
||||
# Note that we use phone symmap starting from 1 so that we can safely pad 0.
|
||||
return {s: i for i, s in enumerate(self.phones, 1)}
|
||||
|
||||
@cached_property
|
||||
def spkrs(self):
|
||||
return sorted({_get_spkr_name(path) for path in self.paths})
|
||||
|
||||
def _get_spkr_symmap(self):
|
||||
return {s: i for i, s in enumerate(self.spkrs)}
|
||||
|
||||
def sample_prompts(self, spkr_name):
|
||||
prom_list = []
|
||||
|
||||
while (
|
||||
len(prom_list) == 0
|
||||
or random.random() < cfg.p_additional_prompt
|
||||
and len(prom_list) < 10
|
||||
):
|
||||
path = random.choice(self.paths_by_spkr_name[spkr_name])
|
||||
prom_list.append(_load_quants(path))
|
||||
|
||||
prom = torch.cat(prom_list)
|
||||
|
||||
return prom
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.training:
|
||||
assert self.sampler is not None
|
||||
path = self.sampler.sample()
|
||||
else:
|
||||
path = self.paths[index]
|
||||
|
||||
spkr_name = _get_spkr_name(path)
|
||||
text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))])
|
||||
proms = self.sample_prompts(spkr_name)
|
||||
resps = _load_quants(path)
|
||||
resp = resps[..., 0]
|
||||
|
||||
return dict(
|
||||
path=path,
|
||||
spkr_name=spkr_name,
|
||||
text=text,
|
||||
proms=proms,
|
||||
resps=resps,
|
||||
resp=resp,
|
||||
)
|
||||
|
||||
def head_(self, n):
|
||||
self._head = n
|
||||
|
||||
def training_(self, value):
|
||||
self.training = value
|
||||
|
||||
def interleaved_reorder_(self, fn):
|
||||
self.paths = [*_interleaved_reorder(self.paths, fn)]
|
||||
|
||||
def __len__(self):
|
||||
return min(len(self.paths), self._head or len(self.paths))
|
||||
|
||||
|
||||
def collate_fn(samples: list[dict]):
    """Turn a list of sample dicts into a dict of lists.

    The keys are taken from the first sample; each value is the list of that
    key's values across all samples, in order. Nothing is padded or stacked.
    """
    batch: dict[str, Any] = {}
    for key in samples[0]:
        batch[key] = [sample[key] for sample in samples]
    return batch
|
||||
|
||||
|
||||
def _seed_worker(worker_id):
|
||||
worker_seed = torch.initial_seed() % 2**32
|
||||
np.random.seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
|
||||
|
||||
def _create_dl(dataset, training):
    """Wrap ``dataset`` in a ``DataLoader`` configured from ``cfg``.

    Args:
        dataset: the dataset to load from.
        training: selects ``cfg.batch_size`` (vs ``cfg.eval_batch_size``) and
            enables shuffling and dropping of the last incomplete batch.

    Returns:
        A ``DataLoader`` using ``collate_fn`` and ``cfg.nj`` worker processes.
    """
    num_workers = cfg.nj
    return DataLoader(
        dataset=dataset,
        batch_size=cfg.batch_size if training else cfg.eval_batch_size,
        shuffle=training,
        drop_last=training,
        num_workers=num_workers,
        collate_fn=collate_fn,
        # DataLoader rejects persistent_workers=True when num_workers is 0,
        # so only request it when worker processes are actually used.
        persistent_workers=num_workers > 0,
        worker_init_fn=_seed_worker,
    )
|
||||
|
||||
|
||||
def _load_train_val_paths():
    """Collect ``.qnt.pt`` files under ``cfg.data_dirs`` and split them per speaker.

    For every speaker the paths are shuffled with a fixed seed; the first 95%
    (rounded) go to the training split and the rest to the validation split.

    Returns:
        (train_paths, val_paths), each a sorted list.

    Raises:
        RuntimeError: if no ``.qnt.pt`` file is found.
    """
    paths = []
    for data_dir in cfg.data_dirs:
        # ``rglob`` is already recursive, so no "**/" prefix is needed.
        paths.extend(tqdm(data_dir.rglob("*.qnt.pt")))

    if not paths:
        raise RuntimeError(f"Failed to find any .qnt.pt file in {cfg.data_dirs}.")

    # groupby needs its input sorted by the grouping key (the speaker name).
    pairs = sorted((_get_spkr_name(p), p) for p in paths)

    train_paths = []
    val_paths = []

    for _, group in groupby(pairs, lambda pair: pair[0]):
        spkr_paths = sorted(p for _, p in group)
        # A local generator seeded with 0 gives the same permutation as
        # ``random.seed(0); random.shuffle(...)`` without resetting the
        # global RNG state.
        random.Random(0).shuffle(spkr_paths)
        n = round(len(spkr_paths) * 0.95)
        train_paths.extend(spkr_paths[:n])
        val_paths.extend(spkr_paths[n:])

    return sorted(train_paths), sorted(val_paths)
|
||||
|
||||
|
||||
def _load_test_paths():
    """Return the sorted ``*.asr.txt`` files found under ``cfg.test_data_dirs``."""
    found = [
        path
        for data_dir in cfg.test_data_dirs
        for path in data_dir.rglob("**/*.asr.txt")
    ]
    return sorted(found)
|
||||
|
||||
|
||||
@cfg.diskcache()
def create_datasets():
    """Build the train, validation and test datasets.

    The validation and test datasets reuse the training symbol maps and get
    the training paths as extra prompt candidates. The validation set is
    interleaved by speaker and capped at 200 items.

    Returns:
        (train_dataset, val_dataset, test_dataset)
    """
    train_paths, val_paths = _load_train_val_paths()
    test_paths = _load_test_paths()

    train_dataset = VALLEDatset(train_paths, training=True)

    # Settings derived from the training set and shared by the eval datasets.
    shared_kwargs = dict(
        phone_symmap=train_dataset.phone_symmap,
        spkr_symmap=train_dataset.spkr_symmap,
        extra_paths_by_spkr_name=train_dataset.paths_by_spkr_name,
    )

    val_dataset = VALLEDatset(val_paths, **shared_kwargs)
    val_dataset.interleaved_reorder_(_get_spkr_name)
    val_dataset.head_(200)

    test_dataset = VALLEDatset(test_paths, **shared_kwargs)

    return train_dataset, val_dataset, test_dataset
|
||||
|
||||
|
||||
def create_train_val_dataloader():
    """Create the data loaders used for training and evaluation.

    Besides the train / val / test loaders, a fourth loader is built over a
    deep copy of the training set that is interleaved by speaker, capped at
    200 items and switched out of training mode.

    Returns:
        (train_dl, train200_dl, val_dl, test_dl)
    """
    train_dataset, val_dataset, test_dataset = create_datasets()

    train_dl = _create_dl(train_dataset, training=True)
    val_dl, test_dl = (
        _create_dl(eval_dataset, training=False)
        for eval_dataset in (val_dataset, test_dataset)
    )

    for symmap in (train_dataset.phone_symmap, train_dataset.spkr_symmap):
        _logger.info(str(symmap))

    _logger.info(f"#samples (train): {len(train_dataset)}.")
    _logger.info(f"#samples (val): {len(val_dataset)}.")
    _logger.info(f"#samples (test): {len(test_dataset)}.")

    # Evaluation view of the training data; the copy leaves train_dataset as is.
    train200_dataset = copy.deepcopy(train_dataset)
    train200_dataset.interleaved_reorder_(_get_spkr_name)
    train200_dataset.head_(200)
    train200_dataset.training_(False)
    train200_dl = _create_dl(train200_dataset, training=False)
    assert isinstance(train200_dl.dataset, VALLEDatset)

    return train_dl, train200_dl, val_dl, test_dl
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_dl, train200_dl, val_dl, test_dl = create_train_val_dataloader()
|
||||
sample = train_dl.dataset[0]
|
||||
print(sample)
|
50
vall_e/emb/g2p.py
Normal file
50
vall_e/emb/g2p.py
Normal file
|
@ -0,0 +1,50 @@
|
|||
import argparse
|
||||
import random
|
||||
import string
|
||||
from functools import cache
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from g2p_en import G2p
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
@cache
def _get_model():
    """Create the ``G2p`` converter once; later calls reuse the cached instance."""
    converter = G2p()
    return converter
|
||||
|
||||
|
||||
@cache
|
||||
def _get_graphs(path):
|
||||
with open(path, "r") as f:
|
||||
graphs = f.read()
|
||||
return graphs
|
||||
|
||||
|
||||
def encode(graphs: str) -> list[str]:
    """Convert text to a phone list, mapping spaces and punctuation to ``"_"``."""
    skipped = set(string.punctuation) | {" "}
    converter = _get_model()
    result = []
    for phone in converter(graphs):
        result.append("_" if phone in skipped else phone)
    return result
|
||||
|
||||
|
||||
@torch.no_grad()
def main():
    """Write a ``.phn.txt`` phone file next to every text file under a folder.

    Command line:
        folder: directory searched recursively.
        --suffix: file-name suffix of the text files (default ``.normalized.txt``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("folder", type=Path)
    parser.add_argument("--suffix", type=str, default=".normalized.txt")
    args = parser.parse_args()

    paths = list(args.folder.rglob(f"*{args.suffix}"))
    random.shuffle(paths)

    for path in tqdm(paths):
        # Output name: everything before the first dot of the name + ".phn.txt".
        phone_path = path.with_name(path.stem.split(".")[0] + ".phn.txt")
        graphs = _get_graphs(path)
        phones = encode(graphs)
        # Written as UTF-8 explicitly because the dataset code reads these
        # files back with encoding="utf8".
        with open(phone_path, "w", encoding="utf8") as f:
            f.write(" ".join(phones))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -2,35 +2,51 @@ import argparse
|
|||
from functools import cache
|
||||
from pathlib import Path
|
||||
|
||||
import soundfile
|
||||
import torch
|
||||
import torchaudio
|
||||
from einops import rearrange
|
||||
from encodec import EncodecModel
|
||||
from encodec.utils import convert_audio
|
||||
from torch import Tensor
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..config import cfg
|
||||
|
||||
|
||||
@cache
def _load_model(device="cuda"):
    """Build the 24 kHz EnCodec model on ``device``; cached per ``device`` value."""
    # The pretrained model below is the 24 kHz one, so the config must agree.
    assert cfg.sample_rate == 24_000
    bandwidth = 6.0
    codec = EncodecModel.encodec_model_24khz()
    codec.set_target_bandwidth(bandwidth)
    codec.to(device)
    return codec
|
||||
|
||||
|
||||
def unload_model():
    """Clear the ``_load_model`` cache so cached models are no longer held by it."""
    cleared = _load_model.cache_clear()
    return cleared
|
||||
|
||||
|
||||
@torch.inference_mode()
def decode(codes: Tensor, device="cuda"):
    """Decode quantized codes with the cached EnCodec model.

    Args:
        codes: (b q t)
        device: device the model is loaded on; ``codes`` is moved there.

    Returns:
        A tuple of the model's decoded output and ``model.sample_rate``.
    """
    assert codes.dim() == 3
    model = _load_model(device)
    # Put the codes on the model's device, as ``encode`` does for its input.
    codes = codes.to(device)
    return model.decode([(codes, None)]), model.sample_rate
|
||||
|
||||
|
||||
def replace_file_extension(path, suffix):
|
||||
def decode_to_file(resps: Tensor, path: Path):
    """Decode ``(t q)`` codes and write the resulting audio to ``path``."""
    assert resps.dim() == 2, f"Require shape (t q), but got {resps.shape}."
    # Add a batch axis and swap to the (b q t) layout that ``decode`` expects.
    codes = rearrange(resps, "t q -> 1 q t")
    wavs, sr = decode(codes)
    # Index [0, 0] of the decoded output is what gets written.
    samples = wavs.cpu()[0, 0]
    soundfile.write(str(path), samples, sr)
|
||||
|
||||
|
||||
def _replace_file_extension(path, suffix):
|
||||
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
|
||||
|
||||
|
||||
|
@ -46,7 +62,7 @@ def encode(wav, sr, device="cuda"):
|
|||
wav = convert_audio(wav, sr, model.sample_rate, model.channels)
|
||||
wav = wav.to(device)
|
||||
encoded_frames = model.encode(wav)
|
||||
qnt = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # (b k t)
|
||||
qnt = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # (b q t)
|
||||
return qnt
|
||||
|
||||
|
||||
|
@ -59,7 +75,7 @@ def main():
|
|||
paths = [*args.folder.rglob(f"*{args.suffix}")]
|
||||
|
||||
for path in tqdm(paths):
|
||||
out_path = replace_file_extension(path, ".qnt.pt")
|
||||
out_path = _replace_file_extension(path, ".qnt.pt")
|
||||
wav, sr = torchaudio.load(path)
|
||||
if wav.shape[0] == 2:
|
||||
wav = wav[:1]
|
||||
|
|
48
vall_e/sampler.py
Normal file
48
vall_e/sampler.py
Normal file
|
@ -0,0 +1,48 @@
|
|||
"""
|
||||
A sampler that balances data by key_fns.
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Zhe Niu
|
||||
|
||||
niuzhe.nz@outlook.com
|
||||
"""
|
||||
|
||||
import random
|
||||
|
||||
|
||||
class Sampler:
|
||||
def __init__(self, l, key_fns):
|
||||
self.tree = self._build(l, key_fns)
|
||||
|
||||
def _build(self, l, key_fns) -> dict[dict, list]:
|
||||
if not key_fns:
|
||||
return l
|
||||
|
||||
tree = {}
|
||||
|
||||
key_fn, *key_fns = key_fns
|
||||
|
||||
for x in l:
|
||||
k = key_fn(x)
|
||||
|
||||
if k in tree:
|
||||
tree[k].append(x)
|
||||
else:
|
||||
tree[k] = [x]
|
||||
|
||||
for k in tree:
|
||||
tree[k] = self._build(tree[k], key_fns)
|
||||
|
||||
return tree
|
||||
|
||||
def _sample(self, tree: dict | list):
|
||||
if isinstance(tree, list):
|
||||
ret = random.choice(tree)
|
||||
else:
|
||||
key = random.choice([*tree.keys()])
|
||||
ret = self._sample(tree[key])
|
||||
return ret
|
||||
|
||||
def sample(self):
|
||||
return self._sample(self.tree)
|
137
vall_e/train.py
Normal file
137
vall_e/train.py
Normal file
|
@ -0,0 +1,137 @@
|
|||
import json
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from .config import cfg
|
||||
from .data import create_train_val_dataloader
|
||||
from .emb import qnt
|
||||
from .utils import setup_logging, to_device, trainer
|
||||
from .vall_e import AR, NAR
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_engines():
    """Build the model selected by ``cfg.model`` and wrap it in trainer engines.

    Raises:
        NotImplementedError: if ``cfg.model`` is neither "ar" nor "nar"
            (compared case-insensitively).
    """
    model_classes = {"ar": AR, "nar": NAR}
    model_cls = model_classes.get(cfg.model.lower())
    if model_cls is None:
        raise NotImplementedError(cfg.model)

    # Both model classes take the same positional arguments.
    model = model_cls(
        cfg.num_tokens,
        cfg.d_model,
        cfg.n_heads,
        cfg.n_layers,
        cfg.p_dropout,
    )

    engine = trainer.Engine(model=model, config=cfg.ds_cfg)
    return trainer.load_engines(dict(model=engine), cfg)
|
||||
|
||||
|
||||
def main():
    """Train the model selected by ``cfg.model``, evaluating on three splits.

    Sets up logging, builds the data loaders, and hands a training feeder and
    an evaluation callback to ``trainer.train``.
    """
    setup_logging(cfg.log_dir)

    train_dl, train200_dl, val_dl, test_dl = create_train_val_dataloader()

    # Normalize once so the feeders agree with ``load_engines``, which matches
    # the model name case-insensitively.
    model_kind = cfg.model.lower()

    def train_feeder(engines, batch, name):
        """Run one forward pass and return (summed loss, stats dict)."""
        model = engines["model"]

        if model_kind == "ar":
            _ = model(
                text_list=batch["text"],
                proms_list=batch["proms"],
                resp_list=batch["resp"],
            )
        elif model_kind == "nar":
            _ = model(
                text_list=batch["text"],
                proms_list=batch["proms"],
                resps_list=batch["resps"],
            )
        else:
            # Fail loudly instead of gathering losses from a skipped forward.
            raise NotImplementedError(cfg.model)

        losses = model.gather_attribute("loss")

        loss = torch.stack([*losses.values()]).sum()

        stats = {}
        stats |= {k: v.item() for k, v in losses.items()}
        stats |= engines.gather_attribute("scalar")

        return loss, stats

    @torch.inference_mode()
    def run_eval(engines, name, dl):
        """Generate on ``dl``, write ref/hyp audio, and log mean losses."""
        log_dir = cfg.log_dir / str(engines.global_step) / name

        model = engines["model"]
        stats = defaultdict(list)
        for batch in tqdm(dl):
            batch: dict
            batch = to_device(batch, cfg.device)

            if model_kind == "ar":
                resp_list = model(text_list=batch["text"], proms_list=batch["proms"])
                # Add a trailing axis so each item is 2-D, as
                # ``qnt.decode_to_file`` requires.
                resps_list = [r.unsqueeze(-1) for r in resp_list]
            elif model_kind == "nar":
                resps_list = model(
                    text_list=batch["text"],
                    proms_list=batch["proms"],
                    resp_list=batch["resp"],
                )
            else:
                raise NotImplementedError(cfg.model)

            losses = model.gather_attribute("loss")
            batch_stats = {k: v.item() for k, v in losses.items()}
            for k, v in batch_stats.items():
                stats[k].append(v)

            for path, ref, hyp in zip(batch["path"], batch["resps"], resps_list):
                relpath = path.relative_to(cfg.data_root)
                hyp_path = (log_dir / "hyp" / relpath).with_suffix(".wav")
                ref_path = (log_dir / "ref" / relpath).with_suffix(".wav")
                hyp_path.parent.mkdir(parents=True, exist_ok=True)
                ref_path.parent.mkdir(parents=True, exist_ok=True)
                qnt.decode_to_file(ref, ref_path)
                # Skip empty generations.
                if len(hyp) > 0:
                    qnt.decode_to_file(hyp, hyp_path)

        stats = {k: sum(v) / len(v) for k, v in stats.items()}
        stats["global_step"] = engines.global_step
        stats["name"] = name
        _logger.info(f"Eval: {stats}.")

        _logger.info(f"{json.dumps(stats)}.")

    def eval_fn(engines):
        """Evaluate on the train200, val and test loaders in turn."""
        run_eval(engines, "train200", train200_dl)
        run_eval(engines, "val", val_dl)
        run_eval(engines, "test", test_dl)

    trainer.train(
        engines_loader=load_engines,
        train_dl=train_dl,
        train_feeder=train_feeder,
        eval_fn=eval_fn,
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in New Issue
Block a user