From 31f71fa134014095b0a4389bdc99c740c92c85ff Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 14 Jun 2024 16:55:40 -0500
Subject: [PATCH] sampler update (some brainworm just never actually had a
 sampler for sample_type=path)

---
 vall_e/data.py          | 71 ++++++++++++++++++++++------------
 vall_e/utils/sampler.py | 85 +++++++++++++++++++++++++++++++++++++++--
 vall_e/utils/trainer.py |  4 +-
 3 files changed, 129 insertions(+), 31 deletions(-)

diff --git a/vall_e/data.py b/vall_e/data.py
index 793b299..3fd0252 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -12,7 +12,7 @@ import itertools
 
 from .config import cfg
 from .emb.qnt import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file
-from .utils.sampler import Sampler
+from .utils.sampler import PoolSampler, OrderedSampler, RandomSampler
 from .utils.distributed import global_rank, local_rank, world_size
 
 from collections import defaultdict
@@ -424,6 +424,7 @@ class Dataset(_Dataset):
 	):
 		super().__init__()
 		self._head = None
+		self.shuffle = False
 		self.sampler = None
 
 		self.paths = []
@@ -503,7 +504,6 @@ class Dataset(_Dataset):
 			# just interleave
 			self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
 
-		self.samplers = { name: Sampler( paths, keep_all=True ) for name, paths in self.paths_by_spkr_name.items() }
 
 		# dict of speakers keyed by speaker group
 		self.spkrs_by_spkr_group = {}
@@ -521,8 +521,6 @@ class Dataset(_Dataset):
 
 		self.spkr_groups = list(self.spkrs_by_spkr_group.keys())
 
-		self.spkr_samplers = { name: Sampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() }
-
 		self.noise_paths = _load_paths(cfg.dataset.noise, "noise")
 		self.noise_paths = list(itertools.chain.from_iterable(self.noise_paths.values()))
 
@@ -539,6 +537,20 @@ class Dataset(_Dataset):
 
 		if len(self.paths) == 0:
 			raise ValueError(f"No valid path is found for {self.dataset_type}")
+
+		sampler_path = cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
+
+		if self.sampler_type == "path":
+			self.sampler = OrderedSampler( len(self) )
+			self.samplers = {}
+			self.spkr_samplers = {}
+		else:
+			self.sampler = RandomSampler( len(self) )
+			self.samplers = { name: PoolSampler( paths, keep_all=True ) for name, paths in self.paths_by_spkr_name.items() }
+			self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() }
+
+		self.load_state_dict()
+
 	def get_speaker(self, path):
 		if isinstance(path, str):
 			path = Path(path)
@@ -568,21 +580,39 @@ class Dataset(_Dataset):
 	def tasks(self):
 		return cfg.dataset.tasks_list # ["tts", "tts", "ns", "sr", "tse", "tts", "tts"] # , "cse", "nse"
 
-	def save_state_dict(self, path):
-		state_dict = {
-			"samplers": { name: sampler.current_pool for name, sampler in self.samplers.items() }
-		}
+	def save_state_dict(self, path = None):
+		if path is None:
+			path = cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
+
+		if self.sampler_type == "path":
+			state_dict = self.sampler.get_state()
+		else:
+			state_dict = {
+				"samplers": { name: sampler.get_state() for name, sampler in self.samplers.items() },
+				"spkr_samplers": { name: sampler.get_state() for name, sampler in self.spkr_samplers.items() },
+			}
 		torch.save(state_dict, path)
 
-	def load_state_dict(self, path):
-		state_dict = torch.load(path)
+	def load_state_dict(self, path = None):
+		if path is None:
+			path = cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
 
-		if "samplers" in state_dict:
-			# better than naively setting the entire object
+		if not path.exists():
+			return
+
+		state_dict = torch.load(path)
+		if self.sampler_type == "path":
+			state_dict = self.sampler.set_state(state_dict)
+		else:
 			for name, sampler in state_dict["samplers"].items():
 				if name not in self.samplers:
 					continue
-				self.samplers[name].current_pool = sampler
+				self.samplers[name].set_state( sampler )
+
+			for name, sampler in state_dict["spkr_samplers"].items():
+				if name not in self.spkr_samplers:
+					continue
+				self.spkr_samplers[name].set_state( sampler )
 
 	def _get_phone_symmap(self):
 		return get_phone_symmap()
@@ -965,36 +995,29 @@ def _seed_worker(worker_id):
 
 def _create_dataloader(dataset, training):
-	sampler = None
-	shuffle = True
-
 	"""
 	if cfg.distributed and training:
 		sampler = DistributedSampler(dataset)
 		shuffle = False
-	"""
+	"""
 
 	return DataLoader(
 		dataset=dataset,
 		batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
-		shuffle=shuffle,
+		shuffle=dataset.shuffle,
 		drop_last=training,
 		num_workers=cfg.dataset.workers,
 		collate_fn=collate_fn,
 		persistent_workers=cfg.dataset.workers > 1,
 		pin_memory=False, # True,
 		worker_init_fn=_seed_worker,
-		sampler=sampler,
+		sampler=dataset.sampler,
 	)
 
 def create_datasets():
 	train_dataset = Dataset( training=True )
 	val_dataset = Dataset( phone_symmap=train_dataset.phone_symmap, training=False )
 
-	train_state_path = cfg.rel_path / f"sampler.rank{global_rank()}.pt"
-	if train_state_path.exists():
-		train_dataset.load_state_dict( train_state_path )
-
 	return train_dataset, val_dataset
 
@@ -1312,8 +1335,6 @@ if __name__ == "__main__":
 			for i in range(len(v)):
 				print(f'{k}[{i}]:', v[i])
 
-		#train_dl.dataset.save_state_dict(cfg.rel_path / "train_dataset.pt")
-
 	elif args.action == "tasks":
 		index = 0
 		cfg.dataset.tasks_list = args.tasks.split(",")
diff --git a/vall_e/utils/sampler.py b/vall_e/utils/sampler.py
index 59e2ab0..b584bd5 100644
--- a/vall_e/utils/sampler.py
+++ b/vall_e/utils/sampler.py
@@ -2,11 +2,15 @@ from dataclasses import dataclass
 from typing import Any
 import random
 
-@dataclass
-class Sampler():
+import torch
+from torch.utils.data import Sampler
+
+# Randomly picks an index from an array of indices
+class PoolSampler():
 	def __init__( self, pool = [], keep_all = False ):
+		self.length = len(pool)
 		self.global_pool = pool if keep_all else None
-		self.global_indices = [ i for i in range(len(pool)) ]
+		self.global_indices = [ i for i in range(self.length) ]
 		self.reset()
 
 	def reset(self):
@@ -25,5 +29,78 @@
 		# map indices to our real values
 		return pool[index] if pool is not None else index
 
+	def __len__(self):
+		return self.length # len(self.current_pool)
+
+	def __iter__(self):
+		while len(self.current_pool) > 0:
+			yield self.sample()
+
 	def __call__(self, *args, **kwargs):
-		return self.sample(*args, **kwargs)
\ No newline at end of file
+		return self.sample(*args, **kwargs)
+
+	def get_state(self):
+		return { "length": self.length, "global_pool": self.global_pool, "global_indices": self.global_indices, "current_pool": self.current_pool }
+
+	def set_state(self, state):
+		self.length = state["length"]
+		self.global_pool = state["global_pool"]
+		self.global_indices = state["global_indices"]
+		self.current_pool = state["current_pool"]
+
+# "Samples" through a fixed sequence from 0 to length
+# Necessary for our "shuffle+sort by duration+interleave" sampling method
+# Allows saving and loading state
+class OrderedSampler(Sampler):
+	def __init__( self, length ):
+		self.position = 0
+		self.length = length
+
+	def __len__(self):
+		return self.length
+
+	def __iter__(self):
+		if self.position >= self.length:
+			self.position = 0
+
+		while self.position < self.length:
+			yield self.position
+			self.position += 1
+
+	def get_state(self):
+		return { "position": self.position, "length": self.length }
+
+	def set_state(self, state):
+		self.position = state["position"]
+		self.length = state["length"]
+
+# Randomly samples indices from a given sequence from 0 to length
+# Allows saving and loading state
+class RandomSampler(Sampler):
+	def __init__( self, length ):
+		self.position = 0
+		self.length = length
+
+		self.generator = torch.Generator()
+		self.perm = torch.randperm(self.length, generator=self.generator)
+
+	def __len__(self):
+		return self.length
+
+	def __iter__(self):
+		if self.position >= self.length:
+			self.position = 0
+			self.perm = torch.randperm(self.length, generator=self.generator)
+
+		while self.position < self.length:
+			yield self.perm[self.position]
+			self.position += 1
+
+	def get_state(self):
+		return { "position": self.position, "length": self.length, "perm": self.perm, "generator": self.generator.get_state() }
+
+	def set_state(self, state):
+		self.position = state["position"]
+		self.length = state["length"]
+		self.perm = state["perm"]
+		self.generator.set_state(state["generator"])
\ No newline at end of file
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index 909832a..d5af171 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -218,7 +218,7 @@ def train(
 				print("Failed to set LR rate to:", rate, str(e))
 
 		if "export" in command:
-			train_dl.dataset.save_state_dict(cfg.rel_path / f"sampler.rank{global_rank()}.pt")
+			train_dl.dataset.save_state_dict()
 			engines.save_checkpoint()
 			last_save_step = engines.global_step
 
@@ -241,7 +241,7 @@ def train(
 
 		if engines.global_step != last_save_step:
 			if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
-				train_dl.dataset.save_state_dict(cfg.rel_path / f"sampler.rank{global_rank()}.pt")
+				train_dl.dataset.save_state_dict()
 				engines.save_checkpoint()
 
 				last_save_step = engines.global_step
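
Usage note: below is a minimal sketch of how the stateful samplers introduced above are meant to be checkpointed and resumed. It assumes the patched vall_e.utils.sampler is importable; the lengths, the example pool contents, and the sampler.toy.pt path are illustrative stand-ins. In the patch itself the samplers are handed to the DataLoader via sampler=dataset.sampler, and their state is persisted per rank to cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt" through Dataset.save_state_dict() and Dataset.load_state_dict(); they are consumed directly here only for brevity.

import torch

from vall_e.utils.sampler import OrderedSampler, PoolSampler

# OrderedSampler walks indices 0..length-1 in order and tracks how far it got.
sampler = OrderedSampler(10)
it = iter(sampler)
seen = [next(it) for _ in range(4)]   # [0, 1, 2, 3]

# Checkpoint the sampler, as trainer.py now does through dataset.save_state_dict().
torch.save(sampler.get_state(), "sampler.toy.pt")

# A restarted process rebuilds the sampler and restores the saved state,
# mirroring Dataset.load_state_dict(); iteration continues instead of restarting.
resumed = OrderedSampler(10)
resumed.set_state(torch.load("sampler.toy.pt"))
remaining = list(resumed)             # picks up from the saved position

# PoolSampler draws indices from its pool without replacement, one speaker's
# paths per sampler in data.py; its saved state includes the remaining current_pool.
pool = PoolSampler(["a.qnt.pt", "b.qnt.pt", "c.qnt.pt"], keep_all=True)
picked = pool()                       # equivalent to pool.sample()
pool_state = pool.get_state()

RandomSampler follows the same get_state()/set_state() pattern but also saves its permutation and the torch.Generator state, so a resumed run keeps the shuffle order it started the epoch with.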