another dataloader optimization

This commit is contained in:
mrq 2025-03-15 20:18:58 -05:00
parent bee2688dea
commit c5475ebc91

View File

@ -828,29 +828,27 @@ class Dataset(_Dataset):
self.training = training self.training = training
self.dataset_type = "training" if self.training else "validation" self.dataset_type = "training" if self.training else "validation"
self.dataset = sorted(cfg.dataset.training if self.training else cfg.dataset.validation)
self.sampler_type = cfg.dataset.sample_type if self.dataset_type == "training" else "path" self.sampler_type = cfg.dataset.sample_type if self.dataset_type == "training" else "path"
self.sampler_order = cfg.dataset.sample_order self.sampler_order = cfg.dataset.sample_order
self.sampler_shuffle = cfg.dataset.sample_shuffle if self.dataset_type == "training" else True self.sampler_shuffle = cfg.dataset.sample_shuffle if self.dataset_type == "training" else True
self.dataset_hash_key = cfg.dataset.hash_key(sorted(self.dataset)) dataset = sorted(cfg.dataset.training if self.training else cfg.dataset.validation)
self.dataset_hash_key = cfg.dataset.hash_key(dataset)
self.duration = 0 self.duration = 0
self.duration_buckets = {} self.duration_buckets = {}
self.current_index = 0 self.current_index = 0
self.batch_size = cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size self.batch_size = cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size
# to-do: do not do validation if there's nothing in the validation
# this just makes it be happy
if len(self.dataset) == 0:
self.dataset = cfg.dataset.training
# hard error because I kept getting tricked by this myself # hard error because I kept getting tricked by this myself
if self.sampler_order == "duration" and self.sampler_type != "path": if self.sampler_order == "duration" and self.sampler_type != "path":
raise Exception(f'Requesting sample_type={self.sampler_type} with sample_order={self.sampler_order}, yet combination will not give expected results.') raise Exception(f'Requesting sample_type={self.sampler_type} with sample_order={self.sampler_order}, yet combination will not give expected results.')
# dict that maps [speaker][id] to (duration, similar utterances) # dict that maps [speaker][id] to (duration, similar utterances)
self.metadata = _load_dataset_metadata(self.dataset, self.dataset_type, dataset_hash_key=self.dataset_hash_key) self.metadata = _load_dataset_metadata(dataset, self.dataset_type, dataset_hash_key=self.dataset_hash_key)
if len(self.metadata) == 0:
raise Exception(f'Empty dataset for {self.dataset_type}')
# cull speakers with too little utterances # cull speakers with too little utterances
prune_keys = [ speaker for speaker in self.metadata.keys() if len(self.metadata[speaker]) < cfg.dataset.min_utterances ] prune_keys = [ speaker for speaker in self.metadata.keys() if len(self.metadata[speaker]) < cfg.dataset.min_utterances ]
@ -859,20 +857,13 @@ class Dataset(_Dataset):
self.paths = [] self.paths = []
self.speakers = list(self.metadata.keys()) self.speakers = list(self.metadata.keys())
for speaker_id, speaker in enumerate(self.speakers): self.paths = [ ((speaker_id, utterance_id), self.metadata[speaker][utterance][0]) for speaker_id, speaker in enumerate(self.speakers) for utterance_id, utterance in enumerate(self.metadata[speaker].keys()) ]
utterances = len(self.metadata[speaker])
self.paths += [ (speaker_id, utterance_id) for utterance_id in range( utterances ) ]
# split dataset accordingly per GPU # split dataset accordingly per GPU
if cfg.distributed and self.training: if cfg.distributed and self.training:
self.paths = [ path for i, path in enumerate(self.paths) if i % world_size() == 0 ] self.paths = [ path for i, path in enumerate(self.paths) if i % world_size() == 0 ]
# store in corresponding bucket for ((speaker_id, utterance_id), duration) in self.paths:
for (speaker_id, utterance_id) in self.paths:
speaker = self.speakers[speaker_id]
utterance = list(self.metadata[speaker].keys())[utterance_id]
duration, _ = self.metadata[speaker][utterance]
self.duration += duration self.duration += duration
# only calc duration if we're going to order by duration # only calc duration if we're going to order by duration