From c5475ebc915fb10bcf156e9acd3bd4af355d4176 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sat, 15 Mar 2025 20:18:58 -0500
Subject: [PATCH] another dataloader optimization

---
 vall_e/data.py | 25 ++++++++-----------------
 1 file changed, 8 insertions(+), 17 deletions(-)

diff --git a/vall_e/data.py b/vall_e/data.py
index fb5bfdb..7c9078e 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -828,29 +828,27 @@ class Dataset(_Dataset):
 
 		self.training = training
 		self.dataset_type = "training" if self.training else "validation"
-		self.dataset = sorted(cfg.dataset.training if self.training else cfg.dataset.validation)
 		self.sampler_type = cfg.dataset.sample_type if self.dataset_type == "training" else "path"
 		self.sampler_order = cfg.dataset.sample_order
 		self.sampler_shuffle = cfg.dataset.sample_shuffle if self.dataset_type == "training" else True
 
-		self.dataset_hash_key = cfg.dataset.hash_key(sorted(self.dataset))
+		dataset = sorted(cfg.dataset.training if self.training else cfg.dataset.validation)
+		self.dataset_hash_key = cfg.dataset.hash_key(dataset)
 
 		self.duration = 0
 		self.duration_buckets = {}
 		self.current_index = 0
 		self.batch_size = cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size
 
-		# to-do: do not do validation if there's nothing in the validation
-		# this just makes it be happy
-		if len(self.dataset) == 0:
-			self.dataset = cfg.dataset.training
-
 		# hard error because I kept getting tricked by this myself
 		if self.sampler_order == "duration" and self.sampler_type != "path":
 			raise Exception(f'Requesting sample_type={self.sampler_type} with sample_order={self.sampler_order}, yet combination will not give expected results.')
 
 		# dict that maps [speaker][id] to (duration, similar utterances)
-		self.metadata = _load_dataset_metadata(self.dataset, self.dataset_type, dataset_hash_key=self.dataset_hash_key)
+		self.metadata = _load_dataset_metadata(dataset, self.dataset_type, dataset_hash_key=self.dataset_hash_key)
+
+		if len(self.metadata) == 0:
+			raise Exception(f'Empty dataset for {self.dataset_type}')
 
 		# cull speakers with too little utterances
 		prune_keys = [ speaker for speaker in self.metadata.keys() if len(self.metadata[speaker]) < cfg.dataset.min_utterances ]
@@ -859,20 +857,13 @@ class Dataset(_Dataset):
 
 		self.paths = []
 		self.speakers = list(self.metadata.keys())
-		for speaker_id, speaker in enumerate(self.speakers):
-			utterances = len(self.metadata[speaker])
-			self.paths += [ (speaker_id, utterance_id) for utterance_id in range( utterances ) ]
+		self.paths = [ ((speaker_id, utterance_id), self.metadata[speaker][utterance][0]) for speaker_id, speaker in enumerate(self.speakers) for utterance_id, utterance in enumerate(self.metadata[speaker].keys()) ]
 
 		# split dataset accordingly per GPU
 		if cfg.distributed and self.training:
 			self.paths = [ path for i, path in enumerate(self.paths) if i % world_size() == 0 ]
 
-		# store in corresponding bucket
-		for (speaker_id, utterance_id) in self.paths:
-			speaker = self.speakers[speaker_id]
-			utterance = list(self.metadata[speaker].keys())[utterance_id]
-
-			duration, _ = self.metadata[speaker][utterance]
+		for ((speaker_id, utterance_id), duration) in self.paths:
 			self.duration += duration
 
 			# only calc duration if we're going to order by duration
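
For review context: a minimal standalone sketch of why the new path-building is faster, assuming only the metadata shape the in-code comment describes (speaker -> { utterance: (duration, similar utterances) }). The toy data and names below are illustrative, not the module's actual config plumbing.

	# metadata maps speaker -> { utterance: (duration, similar utterances) },
	# mirroring the shape documented in Dataset.__init__ above
	metadata = {
		"speaker_a": { "utt_0": (1.5, []), "utt_1": (2.0, []) },
		"speaker_b": { "utt_0": (0.75, []) },
	}
	speakers = list(metadata.keys())

	# Before: paths held only (speaker_id, utterance_id), so the later
	# duration pass had to rebuild list(metadata[speaker].keys()) for every
	# utterance -- an O(n^2)-per-speaker lookup on large datasets.
	#
	# After: each path carries its duration from the start, built in a
	# single pass, so accumulation becomes a plain O(n) scan.
	paths = [
		((speaker_id, utterance_id), metadata[speaker][utterance][0])
		for speaker_id, speaker in enumerate(speakers)
		for utterance_id, utterance in enumerate(metadata[speaker].keys())
	]

	total_duration = sum(duration for _, duration in paths)
	assert total_duration == 4.25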