another dataloader optimization

commit c5475ebc91
parent bee2688dea
mrq 2025-03-15 20:18:58 -05:00

@@ -828,29 +828,27 @@ class Dataset(_Dataset):
 		self.training = training
 		self.dataset_type = "training" if self.training else "validation"
-		self.dataset = sorted(cfg.dataset.training if self.training else cfg.dataset.validation)
 		self.sampler_type = cfg.dataset.sample_type if self.dataset_type == "training" else "path"
 		self.sampler_order = cfg.dataset.sample_order
 		self.sampler_shuffle = cfg.dataset.sample_shuffle if self.dataset_type == "training" else True
-		self.dataset_hash_key = cfg.dataset.hash_key(sorted(self.dataset))
+		dataset = sorted(cfg.dataset.training if self.training else cfg.dataset.validation)
+		self.dataset_hash_key = cfg.dataset.hash_key(dataset)
 		self.duration = 0
 		self.duration_buckets = {}
 		self.current_index = 0
 		self.batch_size = cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size
 		# to-do: do not do validation if there's nothing in the validation
 		# this just makes it be happy
-		if len(self.dataset) == 0:
-			self.dataset = cfg.dataset.training
+		if len(dataset) == 0:
+			dataset = cfg.dataset.training
 		# hard error because I kept getting tricked by this myself
 		if self.sampler_order == "duration" and self.sampler_type != "path":
 			raise Exception(f'Requesting sample_type={self.sampler_type} with sample_order={self.sampler_order}, yet combination will not give expected results.')
 		# dict that maps [speaker][id] to (duration, similar utterances)
-		self.metadata = _load_dataset_metadata(self.dataset, self.dataset_type, dataset_hash_key=self.dataset_hash_key)
+		self.metadata = _load_dataset_metadata(dataset, self.dataset_type, dataset_hash_key=self.dataset_hash_key)
 		if len(self.metadata) == 0:
 			raise Exception(f'Empty dataset for {self.dataset_type}')
 		# cull speakers with too little utterances
 		prune_keys = [ speaker for speaker in self.metadata.keys() if len(self.metadata[speaker]) < cfg.dataset.min_utterances ]
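The first hunk drops the sorted file list from the instance: it is consumed exactly twice, once to derive a hash key identifying the dataset contents and once to load the metadata cached under that key, so a local suffices. A minimal sketch of the cache-keying idea, assuming a hashlib digest (hypothetical; the repo's actual cfg.dataset.hash_key implementation may differ):

    import hashlib
    import json

    def hash_key(paths):
        # Stable digest over the sorted dataset list: any change to the
        # dataset's composition produces a new key, so metadata cached
        # under an old digest is never reused for a different file list.
        blob = json.dumps([ str(p) for p in sorted(paths) ]).encode("utf-8")
        return hashlib.md5(blob).hexdigest()

Keying the cache this way lets repeated runs skip re-scanning the dataset as long as the file list is unchanged.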
@@ -859,20 +857,13 @@ class Dataset(_Dataset):
-		self.paths = []
 		self.speakers = list(self.metadata.keys())
-		for speaker_id, speaker in enumerate(self.speakers):
-			utterances = len(self.metadata[speaker])
-			self.paths += [ (speaker_id, utterance_id) for utterance_id in range( utterances ) ]
+		self.paths = [ ((speaker_id, utterance_id), self.metadata[speaker][utterance][0]) for speaker_id, speaker in enumerate(self.speakers) for utterance_id, utterance in enumerate(self.metadata[speaker].keys()) ]
 		# split dataset accordingly per GPU
 		if cfg.distributed and self.training:
 			self.paths = [ path for i, path in enumerate(self.paths) if i % world_size() == global_rank() ]
 		# store in corresponding bucket
-		for (speaker_id, utterance_id) in self.paths:
-			speaker = self.speakers[speaker_id]
-			utterance = list(self.metadata[speaker].keys())[utterance_id]
-			duration, _ = self.metadata[speaker][utterance]
+		for ((speaker_id, utterance_id), duration) in self.paths:
 			self.duration += duration
 		# only calc duration if we're going to order by duration
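The second hunk is where the optimization actually lives. The old loop resolved each utterance key with list(self.metadata[speaker].keys())[utterance_id], which rebuilds the whole key list on every access, so totalling durations was quadratic in the number of utterances per speaker. The new comprehension enumerates each speaker's keys exactly once and carries the duration alongside the (speaker_id, utterance_id) index pair, so downstream passes never touch the metadata dict again; the per-GPU split then reduces to a strided slice where rank r keeps items with i % world_size() == r. A standalone sketch of the before/after access pattern, using hypothetical data shaped like the diff's comment describes (metadata maps [speaker][utterance] to (duration, similar)):

    # metadata: { speaker: { utterance: (duration, similar), ... }, ... }
    metadata = {
        "speaker_a": { "utt_0": (2.5, []), "utt_1": (3.1, []) },
        "speaker_b": { "utt_0": (1.7, []) },
    }
    speakers = list(metadata.keys())

    # before: index pairs only; durations re-resolved later, O(n) per lookup
    paths_old = [ (sid, uid) for sid, s in enumerate(speakers) for uid in range(len(metadata[s])) ]
    total_old = 0
    for sid, uid in paths_old:
        utterance = list(metadata[speakers[sid]].keys())[uid]  # rebuilds the key list every time
        total_old += metadata[speakers[sid]][utterance][0]

    # after: one pass pairs each index with its duration up front
    paths_new = [ ((sid, uid), metadata[s][u][0]) for sid, s in enumerate(speakers) for uid, u in enumerate(metadata[s].keys()) ]
    total_new = sum(d for _, d in paths_new)
    assert total_old == total_new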