store dataset hash alongside state dict so it can be ignored if mismatched

mrq 2024-11-11 18:16:56 -06:00
parent f7b8b1e825
commit 354f8e059d
3 changed files with 28 additions and 30 deletions
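In effect, the dataset hash acts as a compatibility key for resumable sampler state: it is saved next to the sampler states and compared on load, and a mismatch means the checkpoint was produced against a different dataset, so the stale state is ignored instead of restored. A minimal sketch of the pattern, assuming a hash derived from the dataset's file list (this diff does not show how `dataset_hash_key` is actually computed, so the derivation below is an illustration only):

from hashlib import md5

# Hypothetical derivation: any change to the dataset's composition
# (files added or removed) changes the key.
def dataset_hash_key(paths):
    return md5(str(sorted(paths)).encode("utf-8")).hexdigest()

def save_state(sampler, paths):
    state_dict = sampler.get_state()
    # store the hash alongside the sampler state, as this commit does
    state_dict["dataset_hash_key"] = dataset_hash_key(paths)
    return state_dict

def load_state(sampler, state_dict, paths):
    # a mismatched hash means the saved positions/permutations refer to a
    # different dataset; restoring them would be wrong, so start fresh instead
    if state_dict.get("dataset_hash_key") != dataset_hash_key(paths):
        return
    sampler.set_state(state_dict)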

View File

@@ -6,6 +6,14 @@ Most of these settings live under `cfg.dataset`.
 ## Dataset
+The provided reference model was trained on `?`k hours of audio with a mix of:
+* LibriTTS-R's entire dataset
+* `small`+`medium`+`duplicate` portions of LibriVox
+* Emilia's German, French, and Japanese dataset
+* 12K hours of a privately sourced corpus of 425 audiobooks
+* a small portion of Emilia's English dataset
+* a personal small corpus of transcribed utterances from a selection of video games
+
 ### Leverage Your Own Dataset
 If you already have a dataset you want, for example, your own large corpus or for finetuning, you can use your own dataset instead.

View File

@@ -832,16 +832,15 @@ class Dataset(_Dataset):
                     max_duration=cfg.dataset.sample_max_duration_batch,
                     max_batch_size=cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size,
                     shuffle=self.sampler_shuffle,
-                    dataset_hash=self.dataset_hash_key,
                 )
             else:
-                self.sampler = OrderedSampler( len(self), dataset_hash=self.dataset_hash_key ) if not self.sampler_shuffle else RandomSampler( len(self), dataset_hash=self.dataset_hash_key )
+                self.sampler = OrderedSampler( len(self) ) if not self.sampler_shuffle else RandomSampler( len(self) )
             self.samplers = {}
             self.spkr_samplers = {}
         else:
-            self.sampler = RandomSampler( len(self), dataset_hash=self.dataset_hash_key )
-            self.samplers = { name: PoolSampler( paths, keep_all=True, shuffle=self.sampler_shuffle, dataset_hash=self.dataset_hash_key ) for name, paths in self.paths_by_spkr_name.items() }
-            self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle, dataset_hash=self.dataset_hash_key ) for name, speakers in self.spkrs_by_spkr_group.items() }
+            self.sampler = RandomSampler( len(self) )
+            self.samplers = { name: PoolSampler( paths, keep_all=True, shuffle=self.sampler_shuffle ) for name, paths in self.paths_by_spkr_name.items() }
+            self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle ) for name, speakers in self.spkrs_by_spkr_group.items() }

         # dereference buckets
         self.duration_map = None
@@ -896,6 +895,10 @@ class Dataset(_Dataset):
             "samplers": { name: sampler.get_state() for name, sampler in self.samplers.items() },
             "spkr_samplers": { name: sampler.get_state() for name, sampler in self.spkr_samplers.items() },
         }
+
+        if "dataset_hash_key" not in state_dict:
+            state_dict["dataset_hash_key"] = self.dataset_hash_key
+
         torch_save(state_dict, path)

     def load_state_dict(self, path = None):
@@ -909,6 +912,9 @@ class Dataset(_Dataset):
             return
         state_dict = torch_load(path)
+
+        if "dataset_hash_key" in state_dict:
+            if self.dataset_hash_key != state_dict["dataset_hash_key"]:
+                return
         if self.sampler_type == "path":
             state_dict = self.sampler.set_state(state_dict)
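Why discard on mismatch rather than restore anyway? Sampler state saved against one dataset is meaningless against another of a different size. A standalone sketch, using a simplified stand-in for the OrderedSampler from the sampler diff below (not the repo's exact class):

class OrderedSampler:
    def __init__(self, length):
        self.position = 0
        self.length = length

    def __iter__(self):
        while self.position < self.length:
            yield self.position
            self.position += 1

    def get_state(self):
        return {"position": self.position, "length": self.length}

    def set_state(self, state):
        self.position = state["position"]
        self.length = state["length"]

# State checkpointed at position 800 of a 1000-sample dataset...
saved = {"position": 800, "length": 1000}

# ...blindly restored after the dataset shrank to 500 samples:
sampler = OrderedSampler(500)
sampler.set_state(saved)
print(list(iter(sampler)))  # yields indices 800..999, none of which exist

With the hash check, `load_state_dict` above simply returns and the sampler keeps its freshly initialized state for the new dataset.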

View File

@@ -9,12 +9,11 @@ from .distributed import global_rank, local_rank, world_size

 # Randomly picks an index from an array of indices
 class PoolSampler():
-    def __init__( self, pool = [], keep_all = False, shuffle = False, dataset_hash = None ):
+    def __init__( self, pool = [], keep_all = False, shuffle = False ):
         self.length = len(pool)
         self.shuffle = shuffle
         self.global_pool = pool if keep_all else None
         self.global_indices = [ i for i in range(self.length) ]
-        self.dataset_hash = dataset_hash
         self.reset()

     def reset(self):
@@ -46,25 +45,21 @@ class PoolSampler():
         return self.sample(*args, **kwargs)

     def get_state(self):
-        return { "length": self.length, "global_pool": self.global_pool, "global_indices": self.global_indices, "current_pool": self.current_pool, "dataset_hash": self.dataset_hash }
+        return { "length": self.length, "global_pool": self.global_pool, "global_indices": self.global_indices, "current_pool": self.current_pool }

     def set_state(self, state):
         self.length = state["length"]
         self.global_pool = state["global_pool"]
         self.global_indices = state["global_indices"]
         self.current_pool = state["current_pool"]
-        # could .pop()
-        if "dataset_hash" in state:
-            self.dataset_hash = state["dataset_hash"]

 # "Samples" through a fixed sequence from 0 to length
 # Necessary for our "shuffle+sort by duration+interleave" sampling method
 # Allows saving and loading state
 class OrderedSampler(Sampler):
-    def __init__( self, length, dataset_hash=None ):
+    def __init__( self, length ):
         self.position = 0
         self.length = length
-        self.dataset_hash = dataset_hash

     def __len__(self):
         return self.length
@@ -78,22 +73,18 @@ class OrderedSampler(Sampler):
             self.position += 1

     def get_state(self):
-        return { "position": self.position, "length": self.length, "dataset_hash": self.dataset_hash }
+        return { "position": self.position, "length": self.length }

     def set_state(self, state):
         self.position = state["position"]
         self.length = state["length"]
-        # could .pop()
-        if "dataset_hash" in state:
-            self.dataset_hash = state["dataset_hash"]

 # Like the above, but will batch based on token count
 class BatchedOrderedSampler(Sampler):
-    def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False, dataset_hash=None ):
+    def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False ):
         self.position = 0
         self.batches = []
         self.shuffle = shuffle
-        self.dataset_hash = dataset_hash

         assert max_duration != 0 and max_batch_size != 0, "max_duration and max_batch_size cannot both be 0"
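The BatchedOrderedSampler above groups indices so each batch stays under both a duration budget and a size cap. A hedged sketch of that bucketing idea; the `buckets` layout (duration mapped to a list of sample indices) is an assumption based on the constructor shown, not the repo's actual structure:

# Assumed layout: duration in seconds -> list of sample indices.
buckets = {2.0: [0, 1, 2], 4.0: [3, 4], 8.0: [5]}

def build_batches(buckets, max_duration, max_batch_size):
    batches, current, total = [], [], 0.0
    for duration, indices in sorted(buckets.items()):
        for idx in indices:
            # start a new batch when adding this sample would bust either budget
            if current and (total + duration > max_duration or len(current) >= max_batch_size):
                batches.append(current)
                current, total = [], 0.0
            current.append(idx)
            total += duration
    if current:
        batches.append(current)
    return batches

print(build_batches(buckets, max_duration=8.0, max_batch_size=4))
# [[0, 1, 2], [3, 4], [5]]

Because samples are taken in bucket order, each batch holds similarly sized utterances, which keeps padding waste low.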
@@ -135,22 +126,18 @@ class BatchedOrderedSampler(Sampler):
             self.position += 1

     def get_state(self):
-        return { "position": self.position, "batches": self.batches, "dataset_hash": self.dataset_hash }
+        return { "position": self.position, "batches": self.batches }

     def set_state(self, state):
         self.position = state["position"]
         self.batches = state["batches"]
-        # could .pop()
-        if "dataset_hash" in state:
-            self.dataset_hash = state["dataset_hash"]

 # Randomly samples indices from a given sequence from 0 to length
 # Allows saving and loading state
 class RandomSampler(Sampler):
-    def __init__( self, length, dataset_hash=None ):
+    def __init__( self, length ):
         self.position = 0
         self.length = length
-        self.dataset_hash = dataset_hash

         self.generator = torch.Generator()
         self.perm = torch.randperm(self.length, generator=self.generator)
@@ -168,13 +155,10 @@ class RandomSampler(Sampler):
             self.position += 1

     def get_state(self):
-        return { "position": self.position, "length": self.length, "perm": self.perm, "generator": self.generator.get_state(), "dataset_hash": self.dataset_hash }
+        return { "position": self.position, "length": self.length, "perm": self.perm, "generator": self.generator.get_state() }

     def set_state(self, state):
         self.position = state["position"]
         self.length = state["length"]
         self.perm = state["perm"]
-        self.generator.set_state(state["generator"])
-        # could .pop()
-        if "dataset_hash" in state:
-            self.dataset_hash = state["dataset_hash"]
+        self.generator.set_state(state["generator"])
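Finally, a round-trip of the simplified RandomSampler state across a save/restore, using a standalone re-implementation of the class as it reads after this commit (assumed simplification; the real class also subclasses torch.utils.data.Sampler):

import torch

class RandomSampler:
    def __init__(self, length):
        self.position = 0
        self.length = length
        self.generator = torch.Generator()
        self.perm = torch.randperm(self.length, generator=self.generator)

    def __iter__(self):
        while self.position < self.length:
            yield int(self.perm[self.position])
            self.position += 1

    def get_state(self):
        return { "position": self.position, "length": self.length, "perm": self.perm, "generator": self.generator.get_state() }

    def set_state(self, state):
        self.position = state["position"]
        self.length = state["length"]
        self.perm = state["perm"]
        self.generator.set_state(state["generator"])

a = RandomSampler(10)
it = iter(a)
first = [next(it) for _ in range(5)]  # consume half an epoch
state = a.get_state()                 # checkpoint mid-epoch

b = RandomSampler(10)                 # e.g., a fresh process after a restart
b.set_state(state)
rest = list(iter(b))                  # resumes the same permutation at position 5
assert first + rest == [int(i) for i in a.perm]

Carrying both `perm` and the generator state is what makes the resume exact; with the dataset hash check in `load_state_dict`, that exactness is only ever applied to the dataset the permutation was built for.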