store dataset hash alongside state dict so it can be ignored if mismatched

parent f7b8b1e825
commit 354f8e059d
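In short: the sampler state dict now carries a fingerprint of the dataset it was built from, and loading bails out instead of restoring positions that no longer line up with the files on disk. A minimal sketch of the idea; the `fingerprint` helper and the path list are illustrative stand-ins, not the repo's actual key derivation:

```python
import hashlib

def fingerprint(paths: list[str]) -> str:
    # hypothetical stand-in for dataset_hash_key: hash the sorted file list,
    # so any added, removed, or renamed file changes the key
    digest = hashlib.sha256()
    for path in sorted(paths):
        digest.update(path.encode("utf-8"))
    return digest.hexdigest()

paths = ["data/spk0/a.wav", "data/spk0/b.wav", "data/spk1/c.wav"]

# on save: the key rides along with the sampler state
state_dict = {"position": 42, "dataset_hash_key": fingerprint(paths)}

# on load: a mismatched key means the saved positions index a different
# dataset, so the stale state is dropped and the fresh sampler is kept
if state_dict.get("dataset_hash_key") != fingerprint(paths):
    state_dict = None
```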
```diff
@@ -6,6 +6,14 @@ Most of these settings live under `cfg.dataset`.
 ## Dataset
 
+The provided reference model was trained on `?`k hours of audio with a mix of:
+
+* LibriTTS-R's entire dataset
+* `small`+`medium`+`duplicate` portions of LibriVox
+* Emilia's German, French, and Japanese datasets
+* 12K hours of a privately sourced corpus of 425 audiobooks
+* a small portion of Emilia's English dataset
+* a small personal corpus of transcribed utterances from a selection of video games
 
 ### Leverage Your Own Dataset
 
 If you already have a dataset you want to use, for example your own large corpus or one for finetuning, you can use it instead.
```
```diff
@@ -832,16 +832,15 @@ class Dataset(_Dataset):
                     max_duration=cfg.dataset.sample_max_duration_batch,
                     max_batch_size=cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size,
                     shuffle=self.sampler_shuffle,
-                    dataset_hash=self.dataset_hash_key,
                 )
             else:
-                self.sampler = OrderedSampler( len(self), dataset_hash=self.dataset_hash_key ) if not self.sampler_shuffle else RandomSampler( len(self), dataset_hash=self.dataset_hash_key )
+                self.sampler = OrderedSampler( len(self) ) if not self.sampler_shuffle else RandomSampler( len(self) )
             self.samplers = {}
             self.spkr_samplers = {}
         else:
-            self.sampler = RandomSampler( len(self), dataset_hash=self.dataset_hash_key )
-            self.samplers = { name: PoolSampler( paths, keep_all=True, shuffle=self.sampler_shuffle, dataset_hash=self.dataset_hash_key ) for name, paths in self.paths_by_spkr_name.items() }
-            self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle, dataset_hash=self.dataset_hash_key ) for name, speakers in self.spkrs_by_spkr_group.items() }
+            self.sampler = RandomSampler( len(self) )
+            self.samplers = { name: PoolSampler( paths, keep_all=True, shuffle=self.sampler_shuffle ) for name, paths in self.paths_by_spkr_name.items() }
+            self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle ) for name, speakers in self.spkrs_by_spkr_group.items() }
 
         # dereference buckets
         self.duration_map = None
```
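The branch structure above amounts to a small decision table for which sampler gets built. A sketch of that table; the outer `sampler_type == "path"` condition is inferred from the `load_state_dict` hunk further down, and the wiring is an assumption rather than repo code:

```python
def pick_sampler(sampler_type: str, max_duration_batch: float, shuffle: bool) -> str:
    # assumed reading of the branches above
    if sampler_type == "path":
        if max_duration_batch > 0:
            return "BatchedOrderedSampler"  # duration-bucketed batches
        return "RandomSampler" if shuffle else "OrderedSampler"
    # speaker/group sampling: a RandomSampler plus per-speaker PoolSamplers
    return "RandomSampler + PoolSamplers"
```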
```diff
@@ -896,6 +895,10 @@ class Dataset(_Dataset):
             "samplers": { name: sampler.get_state() for name, sampler in self.samplers.items() },
             "spkr_samplers": { name: sampler.get_state() for name, sampler in self.spkr_samplers.items() },
         }
+
+        if "dataset_hash_key" not in state_dict:
+            state_dict["dataset_hash_key"] = self.dataset_hash_key
+
         torch_save(state_dict, path)
 
     def load_state_dict(self, path = None):
```
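With this change, the payload handed to `torch_save` looks roughly like the following; field values are illustrative, and the exact top-level sampler fields depend on which sampler is active:

```python
# illustrative checkpoint contents; torch_save is the repo's torch.save wrapper
state_dict = {
    "samplers": {"speaker_a": {}},    # per-speaker PoolSampler states
    "spkr_samplers": {"group_0": {}}, # per-speaker-group PoolSampler states
    "dataset_hash_key": "3f2a9c",     # new: fingerprint of the dataset
    # (plus the active sampler's own state, per the surrounding code)
}
```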
```diff
@@ -909,6 +912,9 @@ class Dataset(_Dataset):
             return
 
         state_dict = torch_load(path)
+        if "dataset_hash_key" in state_dict:
+            if self.dataset_hash_key != state_dict["dataset_hash_key"]:
+                return
 
         if self.sampler_type == "path":
             state_dict = self.sampler.set_state(state_dict)
```
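The early `return` above is the point of the commit: a saved sampler position is only meaningful against the exact dataset it was computed from. A toy illustration of the failure mode it guards against, with hypothetical file names:

```python
paths_then = ["a.wav", "b.wav", "c.wav"]  # dataset when the state was saved
paths_now  = ["a.wav", "c.wav"]           # a file has since been removed

position = 2                              # sampler position in the checkpoint
assert paths_then[position] == "c.wav"    # what the position used to mean
# restoring blindly would make paths_now[position] raise IndexError (or,
# worse, silently point at the wrong utterance), so the mismatched state
# dict is ignored and the sampler starts fresh
```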
```diff
@@ -9,12 +9,11 @@ from .distributed import global_rank, local_rank, world_size
 
 # Randomly picks an index from an array of indices
 class PoolSampler():
-    def __init__( self, pool = [], keep_all = False, shuffle = False, dataset_hash = None ):
+    def __init__( self, pool = [], keep_all = False, shuffle = False ):
         self.length = len(pool)
         self.shuffle = shuffle
         self.global_pool = pool if keep_all else None
         self.global_indices = [ i for i in range(self.length) ]
-        self.dataset_hash = dataset_hash
         self.reset()
 
     def reset(self):
```
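For context, `PoolSampler` draws from its pool without replacement and refills via `reset()` once exhausted; calling the instance forwards to `sample()`, per the next hunk. A hedged usage sketch of that behavior:

```python
paths = ["a.wav", "b.wav", "c.wav"]
sampler = PoolSampler( paths, keep_all=True, shuffle=True )

epoch_1 = [ sampler() for _ in range(len(paths)) ]  # each pool entry drawn once
epoch_2 = [ sampler() for _ in range(len(paths)) ]  # pool refilled by reset()
```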
```diff
@@ -46,25 +45,21 @@ class PoolSampler():
         return self.sample(*args, **kwargs)
 
     def get_state(self):
-        return { "length": self.length, "global_pool": self.global_pool, "global_indices": self.global_indices, "current_pool": self.current_pool, "dataset_hash": self.dataset_hash }
+        return { "length": self.length, "global_pool": self.global_pool, "global_indices": self.global_indices, "current_pool": self.current_pool }
 
     def set_state(self, state):
         self.length = state["length"]
         self.global_pool = state["global_pool"]
         self.global_indices = state["global_indices"]
         self.current_pool = state["current_pool"]
-        # could .pop()
-        if "dataset_hash" in state:
-            self.dataset_hash = state["dataset_hash"]
 
 # "Samples" through a fixed sequence from 0 to length
 # Necessary for our "shuffle+sort by duration+interleave" sampling method
 # Allows saving and loading state
 class OrderedSampler(Sampler):
-    def __init__( self, length, dataset_hash=None ):
+    def __init__( self, length ):
         self.position = 0
         self.length = length
-        self.dataset_hash = dataset_hash
 
     def __len__(self):
         return self.length
```
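The comments above name the repo's "shuffle+sort by duration+interleave" ordering that `OrderedSampler` steps through. One plausible reading of that scheme, sketched under the assumption that the order is precomputed elsewhere:

```python
import random

def shuffle_sort_interleave(durations: list[float], chunk: int = 4) -> list[int]:
    order = list(range(len(durations)))
    random.shuffle(order)                   # break duration ties differently each time
    order.sort(key=lambda i: durations[i])  # neighbours have similar length -> less padding
    chunks = [order[i:i + chunk] for i in range(0, len(order), chunk)]
    random.shuffle(chunks)                  # interleave so an epoch isn't strictly short->long
    return [i for c in chunks for i in c]

# OrderedSampler then just walks this fixed order,
# checkpointing only its integer position
```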
```diff
@@ -78,22 +73,18 @@ class OrderedSampler(Sampler):
         self.position += 1
 
     def get_state(self):
-        return { "position": self.position, "length": self.length, "dataset_hash": self.dataset_hash }
+        return { "position": self.position, "length": self.length }
 
     def set_state(self, state):
         self.position = state["position"]
         self.length = state["length"]
-        # could .pop()
-        if "dataset_hash" in state:
-            self.dataset_hash = state["dataset_hash"]
 
 # Like the above, but will batch based on token count
 class BatchedOrderedSampler(Sampler):
-    def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False, dataset_hash=None ):
+    def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False ):
         self.position = 0
         self.batches = []
         self.shuffle = shuffle
-        self.dataset_hash = dataset_hash
 
         assert max_duration != 0 or max_batch_size != 0, "max_duration and max_batch_size cannot both be 0"
```
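The constructor above implies a packing policy: walk indices in duration order and cut a batch whenever the next item would push past `max_duration` seconds or `max_batch_size` items. A sketch under an assumed bucket layout of duration to indices:

```python
# assumed shape: buckets maps a duration (seconds) to the indices of that length
buckets = { 2.0: [0, 3], 4.5: [1, 4], 7.0: [2] }
max_duration, max_batch_size = 10.0, 4

batches, current, total = [], [], 0.0
for duration, indices in sorted(buckets.items()):
    for idx in indices:
        over_duration = max_duration and total + duration > max_duration
        over_size = max_batch_size and len(current) >= max_batch_size
        if current and (over_duration or over_size):
            batches.append(current)     # cut the batch before it overflows
            current, total = [], 0.0
        current.append(idx)
        total += duration
if current:
    batches.append(current)
# e.g. [[0, 3, 1], [4], [2]]: each batch stays under ~10s of audio
```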
```diff
@@ -135,22 +126,18 @@ class BatchedOrderedSampler(Sampler):
         self.position += 1
 
     def get_state(self):
-        return { "position": self.position, "batches": self.batches, "dataset_hash": self.dataset_hash }
+        return { "position": self.position, "batches": self.batches }
 
     def set_state(self, state):
         self.position = state["position"]
         self.batches = state["batches"]
-        # could .pop()
-        if "dataset_hash" in state:
-            self.dataset_hash = state["dataset_hash"]
 
 # Randomly samples indices from a given sequence from 0 to length
 # Allows saving and loading state
 class RandomSampler(Sampler):
-    def __init__( self, length, dataset_hash=None ):
+    def __init__( self, length ):
         self.position = 0
         self.length = length
-        self.dataset_hash = dataset_hash
 
         self.generator = torch.Generator()
         self.perm = torch.randperm(self.length, generator=self.generator)
```
```diff
@@ -168,13 +155,10 @@ class RandomSampler(Sampler):
         self.position += 1
 
     def get_state(self):
-        return { "position": self.position, "length": self.length, "perm": self.perm, "generator": self.generator.get_state(), "dataset_hash": self.dataset_hash }
+        return { "position": self.position, "length": self.length, "perm": self.perm, "generator": self.generator.get_state() }
 
     def set_state(self, state):
         self.position = state["position"]
         self.length = state["length"]
         self.perm = state["perm"]
-        self.generator.set_state(state["generator"])
-        # could .pop()
-        if "dataset_hash" in state:
-            self.dataset_hash = state["dataset_hash"]
+        self.generator.set_state(state["generator"])
```
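`RandomSampler` stays resumable because it checkpoints both the permutation and the `torch.Generator` state that produced it. A small standalone illustration of that pattern:

```python
import torch

generator = torch.Generator().manual_seed(0)
perm = torch.randperm(10, generator=generator)

# what get_state() captures, minus the object wrapper
state = { "position": 4, "perm": perm, "generator": generator.get_state() }

# later, possibly in a new process: restore and continue mid-epoch
resumed = torch.Generator()
resumed.set_state(state["generator"])
remaining = state["perm"][state["position"]:]   # same order, picked up at index 4
```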