From 312a8e3ead1b9d35c3ce4efeb17aac5f033da1be Mon Sep 17 00:00:00 2001
From: mrq
Date: Sun, 30 Jun 2024 11:36:46 -0500
Subject: [PATCH] add shuffle to samplers that can support it

---
 vall_e/config.py        |  3 ++-
 vall_e/data.py          | 52 +++++++++++++----------------------------
 vall_e/utils/sampler.py | 13 +++++++++--
 3 files changed, 29 insertions(+), 39 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index 0f08ac1..c0aa001 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -157,9 +157,10 @@ class Dataset:
 	p_resp_append: float = 1.0
 
 	sample_type: str = "path" # path | speaker
-	sample_order: str = "shuffle" # duration
+	sample_order: str = "interleaved" # interleaved | duration
 	sample_max_duration_batch: float = 0.0 # total number of seconds of utterances per batch, 0 to disable
 	# for a full sized model with 12GiB of VRAM for Encodec, 120 seconds is just enough
+	sample_shuffle: bool = True
 
 	tasks_list: list[str] = field(default_factory=lambda: ["tts"])
 
diff --git a/vall_e/data.py b/vall_e/data.py
index 477112f..b2537c7 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -424,7 +424,6 @@ class Dataset(_Dataset):
 	):
 		super().__init__()
 		self._head = None
-		self.shuffle = False
 		self.sampler = None
 
 		self.paths = []
@@ -434,6 +433,7 @@
 		self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
 		self.sampler_type = cfg.dataset.sample_type # if self.dataset_type == "training" else "group"
 		self.sampler_order = cfg.dataset.sample_order
+		self.sampler_shuffle = cfg.dataset.sample_shuffle
 
 		# to-do: do not do validation if there's nothing in the validation set
 		# this just keeps it happy
@@ -510,7 +510,7 @@
 				flattened[bucket] = [*_interleaved_reorder(flattened[bucket], self.get_speaker)]
 			# flatten paths
 			self.paths = list(itertools.chain.from_iterable(flattened.values()))
-		elif self.sampler_order == "shuffle":
+		else:
 			# just interleave
 			self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
 
@@ -547,29 +547,28 @@
 		if len(self.paths) == 0:
 			raise ValueError(f"No valid path found for {self.dataset_type}")
 
-		sampler_path = cfg.rel_path / self.sampler_state_dict_path
-
 		if self.sampler_type == "path":
 			if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
 				self.sampler = BatchedOrderedSampler(
-					self.duration_buckets if not sampler_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
-					cfg.dataset.sample_max_duration_batch,
-					cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size
+					self.duration_buckets if not self.sampler_state_dict_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
+					max_duration=cfg.dataset.sample_max_duration_batch,
+					max_batch_size=cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size,
+					shuffle=self.sampler_shuffle
 				)
 			else:
-				self.sampler = OrderedSampler( len(self) )
+				self.sampler = OrderedSampler( len(self) ) if not self.sampler_shuffle else RandomSampler( len(self) )
 			self.samplers = {}
 			self.spkr_samplers = {}
 		else:
 			self.sampler = RandomSampler( len(self) )
-			self.samplers = { name: PoolSampler( paths, keep_all=True ) for name, paths in self.paths_by_spkr_name.items() }
-			self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() }
+			self.samplers = { name: PoolSampler( paths, keep_all=True, shuffle=self.sampler_shuffle ) for name, paths in self.paths_by_spkr_name.items() }
+			self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle ) for name, speakers in self.spkrs_by_spkr_group.items() }
 
 		self.load_state_dict()
 
 	@cached_property
 	def sampler_state_dict_path(self):
-		return f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
+		return cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
 
 	def get_speaker(self, path):
 		if isinstance(path, str):
@@ -602,7 +601,7 @@
 
 	def save_state_dict(self, path = None):
 		if path is None:
-			path = cfg.rel_path / self.sampler_state_dict_path
+			path = self.sampler_state_dict_path
 
 		if self.sampler_type == "path":
 			state_dict = self.sampler.get_state()
@@ -615,7 +614,7 @@
 
 	def load_state_dict(self, path = None):
 		if path is None:
-			path = cfg.rel_path / self.sampler_state_dict_path
+			path = self.sampler_state_dict_path
 
 		if not path.exists():
 			return
@@ -652,13 +651,6 @@
 	def _get_task_symmap(self):
 		return get_task_symmap()
 
-	"""
-	def get_task_token( self, token, levels=cfg.model.max_levels ):
-		if not hasattr(self, "task_symmap"):
-			self.task_symmap = self._get_task_symmap()
-		return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(levels) ]]).to(dtype=torch.int16)
-	"""
-
 	def sample_noise(self):
 		path = random.choice(self.noise_paths)
 
@@ -756,12 +748,13 @@
 		else:
 			resps, metadata = _load_quants(path, return_metadata=True)
 			text = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
-			#text = torch.tensor(tokenize( _get_phones( path ) )).to(self.text_dtype)
 
 		lang = torch.tensor([ self.lang_symmap[ self.get_language(spkr_group) ]]).to(torch.uint8)
 
 		# append additional prompts in an attempt to artificially increase lengths / offer new data
 		"""
+		# disabled because I haven't actually needed to use it myself, and I can't be assed to validate if it still works
+		# it probably is better to pad with silence instead of just stitching utterances and ruining things
 		if cfg.dataset.max_resps > 1 and random.random() < cfg.dataset.p_resp_append:
 			choices = [*(set(self.paths_by_spkr_name[spkr_name]) - {path})]
@@ -997,13 +990,6 @@
 			return min(len(self.spkrs), self._head or len(self.spkrs))
 		return min(len(self.paths), self._head or len(self.paths))
 
-	def pin_memory(self):
-		self.text = self.text.pin_memory()
-		self.proms = self.proms.pin_memory()
-		self.resps = self.resps.pin_memory()
-		self.resp = self.resp.pin_memory()
-		return self
-
 
 def collate_fn(samples: list[dict]):
 	batch: dict[str, Any] = {k: [s[k] for s in samples] for k in samples[0]}
@@ -1017,14 +1003,8 @@ def _seed_worker(worker_id):
 
 def _create_dataloader(dataset, training):
-	"""
-	if cfg.distributed and training:
-		sampler = DistributedSampler(dataset)
-		shuffle = False
-	"""
-
 	kwargs = dict(
-		shuffle=dataset.shuffle,
+		shuffle=False, # shuffling is now handled by the samplers themselves
 		batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
 		drop_last=training,
 		sampler=dataset.sampler,
 	)
@@ -1037,7 +1017,7 @@
 		num_workers=cfg.dataset.workers,
 		collate_fn=collate_fn,
 		persistent_workers=cfg.dataset.workers > 1,
-		pin_memory=False, # True,
+		pin_memory=False,
 		worker_init_fn=_seed_worker,
 		**kwargs,
 	)
diff --git a/vall_e/utils/sampler.py b/vall_e/utils/sampler.py
index 9e5dc1a..6c8a1dc 100644
--- a/vall_e/utils/sampler.py
+++ b/vall_e/utils/sampler.py
@@ -9,14 +9,17 @@ from .distributed import global_rank, local_rank, world_size
 
 # Randomly picks an index from an array of indices
 class PoolSampler():
-	def __init__( self, pool = [], keep_all = False ):
+	def __init__( self, pool = [], keep_all = False, shuffle = False ):
 		self.length = len(pool)
+		self.shuffle = shuffle
 		self.global_pool = pool if keep_all else None
 		self.global_indices = [ i for i in range(self.length) ]
 		self.reset()
 
 	def reset(self):
 		self.current_pool = [ i for i in self.global_indices ]
+		if self.shuffle:
+			random.shuffle(self.current_pool)
 
 	def sample(self, pool = None):
 		if pool is None:
@@ -78,9 +81,10 @@ class OrderedSampler(Sampler):
 
 # Like the above, but will batch based on token count
 class BatchedOrderedSampler(Sampler):
-	def __init__( self, buckets, max_duration=0, max_batch_size=0 ):
+	def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False ):
 		self.position = 0
 		self.batches = []
+		self.shuffle = shuffle
 
 		assert max_duration != 0 and max_batch_size != 0, "max_duration and max_batch_size cannot be 0"
 
@@ -105,12 +109,17 @@
 				current_index += 1
 				current_size += duration
 
+		if self.shuffle:
+			random.shuffle(self.batches)
+
 	def __len__(self):
 		return len(self.batches)
 
 	def __iter__(self):
 		if self.position >= len(self.batches):
 			self.position = 0
+			if self.shuffle:
+				random.shuffle(self.batches)
 
 		while self.position < len(self.batches):
 			yield self.batches[self.position]
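
---

Note on the intended behavior: shuffling now lives inside the samplers rather than the DataLoader (hence the hard-coded shuffle=False in _create_dataloader), so sampler state can still be saved and restored through sampler_state_dict_path. With sample_shuffle enabled, PoolSampler reshuffles its index pool each time the pool is exhausted and reset, while BatchedOrderedSampler shuffles whole duration-bucketed batches: once when the batches are built, and again whenever iteration wraps around to a new epoch.

If it helps to see that pattern in isolation, here is a minimal, self-contained sketch. It is not code from this repo: ToyBatchedSampler, the bucket contents, and the budgets are all made up for illustration, and the real packing logic in vall_e/utils/sampler.py may differ in details.

import random

# Toy stand-in for a duration-bucketed batch sampler (all names and numbers
# here are hypothetical): indices are packed into batches under a duration
# budget, then the *batches* are shuffled, so each batch stays roughly
# homogeneous in utterance length while the batch order varies per epoch.
class ToyBatchedSampler:
	def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False ):
		assert max_duration != 0 and max_batch_size != 0, "max_duration and max_batch_size cannot be 0"
		self.position = 0
		self.batches = []
		self.shuffle = shuffle

		# buckets maps an utterance duration (seconds) to the indices with that duration
		for duration, indices in buckets.items():
			batch, size = [], 0.0
			for index in indices:
				# close the current batch before it would exceed either budget
				if batch and (size + duration > max_duration or len(batch) >= max_batch_size):
					self.batches.append(batch)
					batch, size = [], 0.0
				batch.append(index)
				size += duration
			if batch:
				self.batches.append(batch)

		if self.shuffle:
			random.shuffle(self.batches)

	def __len__(self):
		return len(self.batches)

	def __iter__(self):
		if self.position >= len(self.batches):
			self.position = 0
			if self.shuffle:
				random.shuffle(self.batches) # reshuffle batch order on epoch wrap

		while self.position < len(self.batches):
			yield self.batches[self.position]
			self.position += 1

buckets = { 2.0: [0, 1, 2, 3], 5.0: [4, 5, 6] } # hypothetical duration -> indices
sampler = ToyBatchedSampler(buckets, max_duration=6.0, max_batch_size=8, shuffle=True)
for epoch in range(2):
	print(f"epoch {epoch}:", list(iter(sampler))) # same batches, different order each epoch

Shuffling at batch granularity rather than index granularity is what keeps the duration budget meaningful: every batch still fits under max_duration, and only the order in which batches are served changes between epochs.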