From c24682c668b0d423be9415829953483f5120f12f Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 7 Feb 2022 15:45:38 -0700 Subject: [PATCH] Record load times in fast_paired_dataset --- codes/data/audio/fast_paired_dataset.py | 30 +++++++++++++++++++------ 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/codes/data/audio/fast_paired_dataset.py b/codes/data/audio/fast_paired_dataset.py index 11d9377f..6dca25fd 100644 --- a/codes/data/audio/fast_paired_dataset.py +++ b/codes/data/audio/fast_paired_dataset.py @@ -2,6 +2,7 @@ import hashlib import os import random import sys +import time from itertools import groupby import torch @@ -69,6 +70,9 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset): self.tokenizer = CharacterTokenizer() self.skipped_items = 0 # records how many items are skipped when accessing an index. + self.load_times = torch.zeros((256,)) + self.load_ind = 0 + def get_wav_text_pair(self, audiopath_and_text): # separate filename and text audiopath, text = audiopath_and_text[0], audiopath_and_text[1] @@ -148,6 +152,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset): } def __getitem__(self, index): + start = time.time() self.skipped_items += 1 apt = self.load_random_line() try: @@ -185,6 +190,11 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset): aligned_codes = F.pad(aligned_codes, (0, self.max_aligned_codes-aligned_codes.shape[0])) if tseq.shape[0] != self.max_text_len: tseq = F.pad(tseq, (0, self.max_text_len - tseq.shape[0])) + + elapsed = time.time() - start + self.load_times[self.load_ind] = elapsed + self.load_ind = (self.load_ind + 1) % len(self.load_times) + res = { 'real_text': text, 'padded_text': tseq, @@ -195,12 +205,14 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset): 'wav_lengths': torch.tensor(orig_output, dtype=torch.long), 'filenames': path, 'skipped_items': actually_skipped_items, + 'load_time': self.load_times.mean() } if self.load_conditioning: res['conditioning'] = cond res['conditioning_contains_self'] = cond_is_self if self.produce_ctc_metadata: res.update(self.get_ctc_metadata(raw_codes)) + return res def __len__(self): @@ -213,6 +225,7 @@ class FastPairedVoiceDebugger: self.loaded_items = 0 self.self_conditioning_items = 0 self.unique_files = set() + self.load_time = 0 def get_state(self): return {'total_items': self.total_items, @@ -228,6 +241,7 @@ class FastPairedVoiceDebugger: def update(self, batch): self.total_items += batch['wav'].shape[0] self.loaded_items += batch['skipped_items'].sum().item() + self.load_time = batch['load_time'].mean().item() for filename in batch['filenames']: self.unique_files.add(hashlib.sha256(filename.encode('utf-8'))) if 'conditioning' in batch.keys(): @@ -238,12 +252,13 @@ class FastPairedVoiceDebugger: 'total_samples_loaded': self.total_items, 'percent_skipped_samples': (self.loaded_items - self.total_items) / self.loaded_items, 'percent_conditioning_is_self': self.self_conditioning_items / self.loaded_items, - 'unique_files_loaded': len(self.unique_files) + 'unique_files_loaded': len(self.unique_files), + 'load_time': self.load_time, } if __name__ == '__main__': - batch_sz = 16 + batch_sz = 256 params = { 'mode': 'fast_paired_voice_audio', 'path': ['Y:\\libritts\\train-clean-360\\transcribed-w2v.tsv', 'Y:\\clips\\books1\\transcribed-w2v.tsv'], @@ -255,8 +270,8 @@ if __name__ == '__main__': 'sample_rate': 22050, 'load_conditioning': True, 'num_conditioning_candidates': 1, - 'conditioning_length': 44000, - 'use_bpe_tokenizer': False, + 'conditioning_length': 66000, + 'use_bpe_tokenizer': True, 'load_aligned_codes': True, 'produce_ctc_metadata': True, } @@ -275,10 +290,11 @@ if __name__ == '__main__': max_pads, max_repeats = 0, 0 for i, b in tqdm(enumerate(dl)): for ib in range(batch_sz): - max_pads = max(max_pads, b['ctc_pads'].max()) - max_repeats = max(max_repeats, b['ctc_repeats'].max()) - print(f'{i} {ib} {b["real_text"][ib]}') + #max_pads = max(max_pads, b['ctc_pads'].max()) + #max_repeats = max(max_repeats, b['ctc_repeats'].max()) + #print(f'{i} {ib} {b["real_text"][ib]}') #save(b, i, ib, 'wav') + pass #if i > 5: # break print(max_pads, max_repeats)