From 354f8e059d4cee14f18266bd90aedbdf32338e92 Mon Sep 17 00:00:00 2001
From: mrq
Date: Mon, 11 Nov 2024 18:16:56 -0600
Subject: [PATCH] store dataset hash alongside state dict so it can be ignored if mismatched

---
 docs/data.md            |  8 ++++++++
 vall_e/data.py          | 16 +++++++++++-----
 vall_e/utils/sampler.py | 34 +++++++++-------------------------
 3 files changed, 28 insertions(+), 30 deletions(-)

diff --git a/docs/data.md b/docs/data.md
index c9474c8..47ace4b 100644
--- a/docs/data.md
+++ b/docs/data.md
@@ -6,6 +6,14 @@ Most of these settings live under `cfg.dataset`.
 
 ## Dataset
 
+The provided reference model was trained on `?`k hours of audio with a mix of:
+* LibriTTS-R's entire dataset
+* `small`+`medium`+`duplicate` portions of LibriVox
+* Emilia's German, French, and Japanese dataset
+* 12K hours of a privately sourced corpus of 425 audiobooks
+* a small portion of Emilia's English dataset
+* a personal small corpus of transcribed utterances from a selection of video games
+
 ### Leverage Your Own Dataset
 
 If you already have a dataset you want, for example, your own large corpus or for finetuning, you can use your own dataset instead.
diff --git a/vall_e/data.py b/vall_e/data.py
index b9ded90..72e85ff 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -832,16 +832,15 @@ class Dataset(_Dataset):
 					max_duration=cfg.dataset.sample_max_duration_batch,
 					max_batch_size=cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size,
 					shuffle=self.sampler_shuffle,
-					dataset_hash=self.dataset_hash_key,
 				)
 			else:
-				self.sampler = OrderedSampler( len(self), dataset_hash=self.dataset_hash_key ) if not self.sampler_shuffle else RandomSampler( len(self), dataset_hash=self.dataset_hash_key )
+				self.sampler = OrderedSampler( len(self) ) if not self.sampler_shuffle else RandomSampler( len(self) )
 			self.samplers = {}
 			self.spkr_samplers = {}
 		else:
-			self.sampler = RandomSampler( len(self), dataset_hash=self.dataset_hash_key )
-			self.samplers = { name: PoolSampler( paths, keep_all=True, shuffle=self.sampler_shuffle, dataset_hash=self.dataset_hash_key ) for name, paths in self.paths_by_spkr_name.items() }
-			self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle, dataset_hash=self.dataset_hash_key ) for name, speakers in self.spkrs_by_spkr_group.items() }
+			self.sampler = RandomSampler( len(self) )
+			self.samplers = { name: PoolSampler( paths, keep_all=True, shuffle=self.sampler_shuffle ) for name, paths in self.paths_by_spkr_name.items() }
+			self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle ) for name, speakers in self.spkrs_by_spkr_group.items() }
 
 		# dereference buckets
 		self.duration_map = None
@@ -896,6 +895,10 @@ class Dataset(_Dataset):
 			"samplers": { name: sampler.get_state() for name, sampler in self.samplers.items() },
 			"spkr_samplers": { name: sampler.get_state() for name, sampler in self.spkr_samplers.items() },
 		}
+
+		if "dataset_hash_key" not in state_dict:
+			state_dict["dataset_hash_key"] = self.dataset_hash_key
+
 		torch_save(state_dict, path)
 
 	def load_state_dict(self, path = None):
@@ -909,6 +912,9 @@ class Dataset(_Dataset):
 			return
 
 		state_dict = torch_load(path)
+		if "dataset_hash_key" in state_dict:
+			if self.dataset_hash_key != state_dict["dataset_hash_key"]:
+				return
 
 		if self.sampler_type == "path":
 			state_dict = self.sampler.set_state(state_dict)
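The hunks above read `self.dataset_hash_key`, but its derivation sits outside this diff's context. As a point of reference, here is a minimal sketch of one way such a key could be computed; the helper name and hashing scheme are illustrative assumptions, not code from the repository:

```python
import hashlib

def compute_dataset_hash_key(paths):
	# Hash the sorted path list so the key changes whenever entries are
	# added, removed, or renamed, but not when the list is merely shuffled.
	# (Illustrative assumption; the repository may derive its key differently.)
	joined = "\n".join(sorted(str(p) for p in paths))
	return hashlib.md5(joined.encode("utf-8")).hexdigest()
```

Any derivation works as long as it is stable across runs (so a resumed run reproduces the same key) and sensitive to exactly the dataset changes that would invalidate saved sampler positions.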
diff --git a/vall_e/utils/sampler.py b/vall_e/utils/sampler.py
index 2e072ec..d5b6b76 100644
--- a/vall_e/utils/sampler.py
+++ b/vall_e/utils/sampler.py
@@ -9,12 +9,11 @@ from .distributed import global_rank, local_rank, world_size
 
 # Randomly picks an index from an array of indices
 class PoolSampler():
-	def __init__( self, pool = [], keep_all = False, shuffle = False, dataset_hash = None ):
+	def __init__( self, pool = [], keep_all = False, shuffle = False ):
 		self.length = len(pool)
 		self.shuffle = shuffle
 		self.global_pool = pool if keep_all else None
 		self.global_indices = [ i for i in range(self.length) ]
-		self.dataset_hash = dataset_hash
 		self.reset()
 
 	def reset(self):
@@ -46,25 +45,21 @@
 		return self.sample(*args, **kwargs)
 
 	def get_state(self):
-		return { "length": self.length, "global_pool": self.global_pool, "global_indices": self.global_indices, "current_pool": self.current_pool, "dataset_hash": self.dataset_hash }
+		return { "length": self.length, "global_pool": self.global_pool, "global_indices": self.global_indices, "current_pool": self.current_pool }
 
 	def set_state(self, state):
 		self.length = state["length"]
 		self.global_pool = state["global_pool"]
 		self.global_indices = state["global_indices"]
 		self.current_pool = state["current_pool"]
-		# could .pop()
-		if "dataset_hash" in state:
-			self.dataset_hash = state["dataset_hash"]
 
 # "Samples" through a fixed sequence from 0 to length
 # Necessary for our "shuffle+sort by duration+interleave" sampling method
 # Allows saving and loading state
 class OrderedSampler(Sampler):
-	def __init__( self, length, dataset_hash=None ):
+	def __init__( self, length ):
 		self.position = 0
 		self.length = length
-		self.dataset_hash = dataset_hash
 
 	def __len__(self):
 		return self.length
@@ -78,22 +73,18 @@
 			self.position += 1
 
 	def get_state(self):
-		return { "position": self.position, "length": self.length, "dataset_hash": self.dataset_hash }
+		return { "position": self.position, "length": self.length }
 
 	def set_state(self, state):
 		self.position = state["position"]
 		self.length = state["length"]
-		# could .pop()
-		if "dataset_hash" in state:
-			self.dataset_hash = state["dataset_hash"]
 
 # Like the above, but will batch based on token count
 class BatchedOrderedSampler(Sampler):
-	def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False, dataset_hash=None ):
+	def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False ):
 		self.position = 0
 		self.batches = []
 		self.shuffle = shuffle
-		self.dataset_hash = dataset_hash
 
 		assert max_duration != 0 and max_batch_size != 0, "max_duration and max_batch_size cannot both be 0"
 
@@ -135,22 +126,18 @@
 			self.position += 1
 
 	def get_state(self):
-		return { "position": self.position, "batches": self.batches, "dataset_hash": self.dataset_hash }
+		return { "position": self.position, "batches": self.batches }
 
 	def set_state(self, state):
 		self.position = state["position"]
 		self.batches = state["batches"]
-		# could .pop()
-		if "dataset_hash" in state:
-			self.dataset_hash = state["dataset_hash"]
 
 # Randomly samples indices from a given sequence from 0 to length
 # Allows saving and loading state
 class RandomSampler(Sampler):
-	def __init__( self, length, dataset_hash=None ):
+	def __init__( self, length ):
 		self.position = 0
 		self.length = length
-		self.dataset_hash = dataset_hash
 
 		self.generator = torch.Generator()
 		self.perm = torch.randperm(self.length, generator=self.generator)
"generator": self.generator.get_state(), "dataset_hash": self.dataset_hash } + return { "position": self.position, "length": self.length, "perm": self.perm, "generator": self.generator.get_state() } def set_state(self, state): self.position = state["position"] self.length = state["length"] self.perm = state["perm"] - self.generator.set_state(state["generator"]) - # could .pop() - if "dataset_hash" in state: - self.dataset_hash = state["dataset_hash"] \ No newline at end of file + self.generator.set_state(state["generator"]) \ No newline at end of file