From 8a6c20327757fe8fd42ac17a135f3b180a9bc37a Mon Sep 17 00:00:00 2001
From: mrq
Date: Sun, 3 Sep 2023 21:27:13 -0500
Subject: [PATCH] added per-speaker samplers

---
 .gitignore                |  4 +++-
 vall_e/config.py          |  1 +
 vall_e/data.py            | 27 ++++++++++++++++++++++++++-
 vall_e/models/__init__.py |  1 +
 vall_e/models/ar.py       | 12 ++++++------
 vall_e/models/nar.py      |  6 +++---
 vall_e/utils/sampler.py   | 29 ++++++++++++++++++++++++++++-
 vall_e/utils/trainer.py   |  2 ++
 8 files changed, 70 insertions(+), 12 deletions(-)

diff --git a/.gitignore b/.gitignore
index b6dd66a..2cd1480 100755
--- a/.gitignore
+++ b/.gitignore
@@ -7,4 +7,6 @@ __pycache__
 /*.egg-info
 /vall_e/version.py
 /build
-/.cache
\ No newline at end of file
+/.cache
+
+/vall_e/ext/interleaver.py
\ No newline at end of file
diff --git a/vall_e/config.py b/vall_e/config.py
index 8474189..2afac34 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -162,6 +162,7 @@ class Model:
 	tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
 	arch_type: str = "transformer"
 	training: bool = True
+	interleave_pattern: str | None = None
 
 	@property
 	def full_name(self):
diff --git a/vall_e/data.py b/vall_e/data.py
index 7e381d5..c2d8a48 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -12,6 +12,7 @@ import itertools
 
 from .config import cfg
 from .emb.qnt import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file
+from .utils.sampler import Sampler
 
 from collections import defaultdict
 from functools import cache, cached_property
@@ -173,6 +174,8 @@
 		self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type)
 		self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
 
+		self.samplers = { name: Sampler( paths, keep_all=True ) for name, paths in self.paths_by_spkr_name.items() }
+
 		if cfg.dataset.sample_type == "path":
 			self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
 
@@ -215,6 +218,22 @@
 	def tasks(self):
 		return cfg.dataset.tasks_list # ["tts", "tts", "ns", "sr", "tse", "tts", "tts"] # , "cse", "nse"
 
+	def save_state_dict(self, path):
+		state_dict = {
+			"samplers": { name: sampler.current_pool for name, sampler in self.samplers.items() }
+		}
+		torch.save(state_dict, path)
+
+	def load_state_dict(self, path):
+		state_dict = torch.load(path)
+
+		if "samplers" in state_dict:
+			# better than naively setting the entire object
+			for name, sampler in state_dict["samplers"].items():
+				if name not in self.samplers:
+					continue
+				self.samplers[name].current_pool = sampler
+
 	def _get_phone_symmap(self):
 		return get_phone_symmap()
 
@@ -290,7 +309,7 @@
 		if cfg.dataset.sample_type == "speaker":
 			spkr_name = self.spkrs[index]
 			spkr_id = self.spkr_symmap[spkr_name]
-			path = random.choice([*set(self.paths_by_spkr_name[spkr_name])])
+			path = self.samplers[spkr_name].sample()
 		else:
 			path = self.paths[index]
 			spkr_name = self.get_speaker(path)
@@ -543,6 +562,10 @@
 def create_datasets():
 	train_dataset = Dataset( training=True )
 	val_dataset = Dataset( phone_symmap=train_dataset.phone_symmap, training=False )
 
+	train_state_path = cfg.relpath / "train_dataset.pt"
+	if train_state_path.exists():
+		train_dataset.load_state_dict( train_state_path )
+
 	return train_dataset, val_dataset
 
@@ -752,6 +775,8 @@ if __name__ == "__main__":
 				del v[i]['resps']
 			print(f'{k}:', v)
 
+		train_dl.dataset.save_state_dict(cfg.relpath / "train_dataset.pt")
+
 	elif args.action == "tasks":
 		index = 0
 		cfg.dataset.tasks_list = args.tasks.split(",")
diff --git a/vall_e/models/__init__.py b/vall_e/models/__init__.py
index a1d4b46..16355cc 100755
--- a/vall_e/models/__init__.py
+++ b/vall_e/models/__init__.py
@@ -15,6 +15,7 @@ def get_model(cfg):
 		d_model=cfg.dim,
 		n_heads=cfg.heads,
 		n_layers=cfg.layers,
+		config = cfg
 	)
 
 	model._cfg = cfg
diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py
index dcf1e72..363c94e 100755
--- a/vall_e/models/ar.py
+++ b/vall_e/models/ar.py
@@ -22,8 +22,8 @@ class AR(Base):
 		return "ln"
 
 	@property
-	def arch_type(self) -> bool:
-		if hasattr(self, "_cfg"):
+	def arch_type(self) -> str:
+		if hasattr(self, "_cfg") and self._cfg:
 			return self._cfg.arch_type
 		return cfg.models.ar.arch_type
 
@@ -33,7 +33,7 @@
 
 	@property
 	def n_resp_levels(self) -> int:
-		if hasattr(self, "_cfg"):
+		if hasattr(self, "_cfg") and self._cfg:
 			return self._cfg.resp_levels
 		return cfg.models.ar.resp_levels
 
@@ -146,8 +146,8 @@ def example_usage():
 		tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
 	]
 	proms_list = [
-		x8(torch.tensor([1, 2, 3], device=device)),
-		#qnt.to(device),
+		#x8(torch.tensor([1, 2, 3], device=device)),
+		qnt.to(device),
 	]
 	resps_list = [
 		qnt.to(device),
@@ -161,7 +161,7 @@
 		'n_tokens': 1024,
 		'd_model': 1024,
 		'n_heads': 16,
-		'n_layers': 12,
+		'n_layers': 24,
 	}
 	model = AR(**kwargs).to(device)
 	engine = Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4))
diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index aed1bf9..531720b 100755
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -16,8 +16,8 @@ class NAR(Base):
 		return False
 
 	@property
-	def arch_type(self) -> bool:
-		if hasattr(self, "_cfg"):
+	def arch_type(self) -> str:
+		if hasattr(self, "_cfg") and self._cfg:
 			return self._cfg.arch_type
 		return cfg.models.nar.arch_type
 
@@ -31,7 +31,7 @@
 
 	@property
 	def n_resp_levels(self) -> int:
-		if hasattr(self, "_cfg"):
+		if hasattr(self, "_cfg") and self._cfg:
 			return self._cfg.resp_levels
 		return cfg.models.nar.resp_levels
 
diff --git a/vall_e/utils/sampler.py b/vall_e/utils/sampler.py
index ad22dec..59e2ab0 100644
--- a/vall_e/utils/sampler.py
+++ b/vall_e/utils/sampler.py
@@ -1,2 +1,29 @@
+from dataclasses import dataclass
+from typing import Any
+import random
+
+@dataclass
 class Sampler():
-	...
\ No newline at end of file
+	def __init__( self, pool = [], keep_all = False ):
+		self.global_pool = pool if keep_all else None
+		self.global_indices = [ i for i in range(len(pool)) ]
+		self.reset()
+
+	def reset(self):
+		self.current_pool = [ i for i in self.global_indices ]
+
+	def sample(self, pool = None):
+		if pool is None:
+			pool = self.global_pool
+		# check if we need to reset
+		index = random.choice( self.current_pool )
+		# remove from pool
+		self.current_pool.remove(index)
+		# reset if needed
+		if len(self.current_pool) == 0:
+			self.reset()
+		# map indices to our real values
+		return pool[index] if pool is not None else index
+
+	def __call__(self, *args, **kwargs):
+		return self.sample(*args, **kwargs)
\ No newline at end of file
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index cfb4c7c..10fcca5 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -311,6 +311,7 @@ def train(
 				print("Failed to set LR rate to:", rate, str(e))
 
 		if "export" in command:
+			train_dl.dataset.save_state_dict(cfg.relpath / "train_dataset.pt")
 			engines.save_checkpoint()
 			last_save_step = engines.global_step
 
@@ -333,6 +334,7 @@
 
 		if engines.global_step != last_save_step:
			if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
+				train_dl.dataset.save_state_dict(cfg.relpath / "train_dataset.pt")
 				engines.save_checkpoint()
 				last_save_step = engines.global_step
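
Usage sketch (not part of the commit): the point of the new sampler over the old random.choice is that each speaker's utterances are drawn without replacement, and the pool only refills once it empties, so no utterance repeats before every other utterance from that speaker has been seen; save_state_dict()/load_state_dict() persist just the remaining-index pools so this property survives a training restart. The Sampler below is the one from the vall_e/utils/sampler.py hunk, minus the inert @dataclass decorator and with comments clarified; the speaker names and paths are hypothetical.

import random

class Sampler():
	def __init__( self, pool = [], keep_all = False ):
		# keep_all retains the original values so sample() can return them;
		# otherwise only indices are tracked and returned
		self.global_pool = pool if keep_all else None
		self.global_indices = [ i for i in range(len(pool)) ]
		self.reset()

	def reset(self):
		# refill the pool of not-yet-drawn indices
		self.current_pool = [ i for i in self.global_indices ]

	def sample(self, pool = None):
		if pool is None:
			pool = self.global_pool
		# draw a random remaining index and remove it from the pool
		index = random.choice( self.current_pool )
		self.current_pool.remove(index)
		# once everything has been drawn, start a fresh pass
		if len(self.current_pool) == 0:
			self.reset()
		# map the index back to a real value if values were kept
		return pool[index] if pool is not None else index

	def __call__(self, *args, **kwargs):
		return self.sample(*args, **kwargs)

# Hypothetical per-speaker pools, mirroring Dataset.__init__'s
#   self.samplers = { name: Sampler( paths, keep_all=True ) ... }
paths_by_spkr_name = {
	"speaker_a": [ "a_0.qnt.pt", "a_1.qnt.pt", "a_2.qnt.pt" ],
	"speaker_b": [ "b_0.qnt.pt", "b_1.qnt.pt" ],
}
samplers = { name: Sampler( paths, keep_all=True ) for name, paths in paths_by_spkr_name.items() }

# Six draws from a three-utterance speaker: the first three are always a
# permutation of the full pool; repeats can only begin after the reset.
print([ samplers["speaker_a"].sample() for _ in range(6) ])

# Dataset.save_state_dict()/load_state_dict() persist exactly this per-speaker
# state (the patch wraps it in torch.save/torch.load):
state = { name: sampler.current_pool for name, sampler in samplers.items() }
for name, pool in state.items():
	samplers[name].current_pool = pool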