another dataloader optimization
This commit is contained in:
parent
bee2688dea
commit
c5475ebc91
|
@ -828,29 +828,27 @@ class Dataset(_Dataset):
|
|||
|
||||
self.training = training
|
||||
self.dataset_type = "training" if self.training else "validation"
|
||||
self.dataset = sorted(cfg.dataset.training if self.training else cfg.dataset.validation)
|
||||
self.sampler_type = cfg.dataset.sample_type if self.dataset_type == "training" else "path"
|
||||
self.sampler_order = cfg.dataset.sample_order
|
||||
self.sampler_shuffle = cfg.dataset.sample_shuffle if self.dataset_type == "training" else True
|
||||
|
||||
self.dataset_hash_key = cfg.dataset.hash_key(sorted(self.dataset))
|
||||
dataset = sorted(cfg.dataset.training if self.training else cfg.dataset.validation)
|
||||
self.dataset_hash_key = cfg.dataset.hash_key(dataset)
|
||||
|
||||
self.duration = 0
|
||||
self.duration_buckets = {}
|
||||
self.current_index = 0
|
||||
self.batch_size = cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size
|
||||
|
||||
# to-do: do not do validation if there's nothing in the validation
|
||||
# this just makes it be happy
|
||||
if len(self.dataset) == 0:
|
||||
self.dataset = cfg.dataset.training
|
||||
|
||||
# hard error because I kept getting tricked by this myself
|
||||
if self.sampler_order == "duration" and self.sampler_type != "path":
|
||||
raise Exception(f'Requesting sample_type={self.sampler_type} with sample_order={self.sampler_order}, yet combination will not give expected results.')
|
||||
|
||||
# dict that maps [speaker][id] to (duration, similar utterances)
|
||||
self.metadata = _load_dataset_metadata(self.dataset, self.dataset_type, dataset_hash_key=self.dataset_hash_key)
|
||||
self.metadata = _load_dataset_metadata(dataset, self.dataset_type, dataset_hash_key=self.dataset_hash_key)
|
||||
|
||||
if len(self.metadata) == 0:
|
||||
raise Exception(f'Empty dataset for {self.dataset_type}')
|
||||
|
||||
# cull speakers with too little utterances
|
||||
prune_keys = [ speaker for speaker in self.metadata.keys() if len(self.metadata[speaker]) < cfg.dataset.min_utterances ]
|
||||
|
@ -859,20 +857,13 @@ class Dataset(_Dataset):
|
|||
|
||||
self.paths = []
|
||||
self.speakers = list(self.metadata.keys())
|
||||
for speaker_id, speaker in enumerate(self.speakers):
|
||||
utterances = len(self.metadata[speaker])
|
||||
self.paths += [ (speaker_id, utterance_id) for utterance_id in range( utterances ) ]
|
||||
self.paths = [ ((speaker_id, utterance_id), self.metadata[speaker][utterance][0]) for speaker_id, speaker in enumerate(self.speakers) for utterance_id, utterance in enumerate(self.metadata[speaker].keys()) ]
|
||||
|
||||
# split dataset accordingly per GPU
|
||||
if cfg.distributed and self.training:
|
||||
self.paths = [ path for i, path in enumerate(self.paths) if i % world_size() == 0 ]
|
||||
|
||||
# store in corresponding bucket
|
||||
for (speaker_id, utterance_id) in self.paths:
|
||||
speaker = self.speakers[speaker_id]
|
||||
utterance = list(self.metadata[speaker].keys())[utterance_id]
|
||||
|
||||
duration, _ = self.metadata[speaker][utterance]
|
||||
for ((speaker_id, utterance_id), duration) in self.paths:
|
||||
self.duration += duration
|
||||
|
||||
# only calc duration if we're going to order by duration
|
||||
|
|
Loading…
Reference in New Issue
Block a user