Record load times in fast_paired_dataset

James Betker 2022-02-07 15:45:38 -07:00
parent 65a546c4d7
commit c24682c668

@@ -2,6 +2,7 @@ import hashlib
 import os
 import random
 import sys
+import time
 from itertools import groupby
 
 import torch
@@ -69,6 +70,9 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
         self.tokenizer = CharacterTokenizer()
         self.skipped_items = 0 # records how many items are skipped when accessing an index.
+        self.load_times = torch.zeros((256,))
+        self.load_ind = 0
+
     def get_wav_text_pair(self, audiopath_and_text):
         # separate filename and text
         audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
@@ -148,6 +152,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
         }
 
     def __getitem__(self, index):
+        start = time.time()
         self.skipped_items += 1
         apt = self.load_random_line()
         try:
@@ -185,6 +190,11 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
             aligned_codes = F.pad(aligned_codes, (0, self.max_aligned_codes-aligned_codes.shape[0]))
         if tseq.shape[0] != self.max_text_len:
             tseq = F.pad(tseq, (0, self.max_text_len - tseq.shape[0]))
+
+        elapsed = time.time() - start
+        self.load_times[self.load_ind] = elapsed
+        self.load_ind = (self.load_ind + 1) % len(self.load_times)
+
         res = {
             'real_text': text,
             'padded_text': tseq,
@@ -195,12 +205,14 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
             'wav_lengths': torch.tensor(orig_output, dtype=torch.long),
             'filenames': path,
             'skipped_items': actually_skipped_items,
+            'load_time': self.load_times.mean()
         }
         if self.load_conditioning:
             res['conditioning'] = cond
             res['conditioning_contains_self'] = cond_is_self
         if self.produce_ctc_metadata:
             res.update(self.get_ctc_metadata(raw_codes))
         return res
 
     def __len__(self):
@@ -213,6 +225,7 @@ class FastPairedVoiceDebugger:
         self.loaded_items = 0
         self.self_conditioning_items = 0
         self.unique_files = set()
+        self.load_time = 0
 
     def get_state(self):
         return {'total_items': self.total_items,
@@ -228,6 +241,7 @@ class FastPairedVoiceDebugger:
     def update(self, batch):
         self.total_items += batch['wav'].shape[0]
         self.loaded_items += batch['skipped_items'].sum().item()
+        self.load_time = batch['load_time'].mean().item()
         for filename in batch['filenames']:
             self.unique_files.add(hashlib.sha256(filename.encode('utf-8')))
         if 'conditioning' in batch.keys():
@@ -238,12 +252,13 @@ class FastPairedVoiceDebugger:
             'total_samples_loaded': self.total_items,
             'percent_skipped_samples': (self.loaded_items - self.total_items) / self.loaded_items,
             'percent_conditioning_is_self': self.self_conditioning_items / self.loaded_items,
-            'unique_files_loaded': len(self.unique_files)
+            'unique_files_loaded': len(self.unique_files),
+            'load_time': self.load_time,
         }
 
 
 if __name__ == '__main__':
-    batch_sz = 16
+    batch_sz = 256
     params = {
         'mode': 'fast_paired_voice_audio',
         'path': ['Y:\\libritts\\train-clean-360\\transcribed-w2v.tsv', 'Y:\\clips\\books1\\transcribed-w2v.tsv'],
@@ -255,8 +270,8 @@ if __name__ == '__main__':
         'sample_rate': 22050,
         'load_conditioning': True,
         'num_conditioning_candidates': 1,
-        'conditioning_length': 44000,
-        'use_bpe_tokenizer': False,
+        'conditioning_length': 66000,
+        'use_bpe_tokenizer': True,
         'load_aligned_codes': True,
         'produce_ctc_metadata': True,
     }
@@ -275,10 +290,11 @@ if __name__ == '__main__':
     max_pads, max_repeats = 0, 0
     for i, b in tqdm(enumerate(dl)):
         for ib in range(batch_sz):
-            max_pads = max(max_pads, b['ctc_pads'].max())
-            max_repeats = max(max_repeats, b['ctc_repeats'].max())
-            print(f'{i} {ib} {b["real_text"][ib]}')
+            #max_pads = max(max_pads, b['ctc_pads'].max())
+            #max_repeats = max(max_repeats, b['ctc_repeats'].max())
+            #print(f'{i} {ib} {b["real_text"][ib]}')
             #save(b, i, ib, 'wav')
+            pass
         #if i > 5:
         #    break
     print(max_pads, max_repeats)
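
The core of this change is a fixed-size ring buffer of per-item load times whose rolling mean is attached to every sample returned by __getitem__. A minimal standalone sketch of the same pattern, assuming a torch-style Dataset (the TimedDataset name and the dummy item are illustrative, not from the commit):

import time
import torch
from torch.utils.data import Dataset, DataLoader

class TimedDataset(Dataset):
    """Illustrative stand-in for the timing logic added to FastPairedVoiceDataset."""
    def __init__(self, length=1024, window=256):
        self.length = length
        self.load_times = torch.zeros((window,))  # ring buffer of recent per-item load times
        self.load_ind = 0                         # next write position in the buffer

    def __getitem__(self, index):
        start = time.time()
        item = {'data': torch.randn(4)}           # stand-in for the real load/pad work
        self.load_times[self.load_ind] = time.time() - start
        self.load_ind = (self.load_ind + 1) % len(self.load_times)  # wrap around
        item['load_time'] = self.load_times.mean()  # rolling average over the window
        return item

    def __len__(self):
        return self.length

if __name__ == '__main__':
    dl = DataLoader(TimedDataset(), batch_size=16)
    batch = next(iter(dl))
    # Default collation stacks the per-item scalars into a (batch,) tensor, which is
    # why the debugger above reads it back as batch['load_time'].mean().item().
    print(batch['load_time'].mean().item())

Two properties of this design worth noting: the mean is taken over the whole buffer, so until it has wrapped once the zero-initialized slots drag the reported average down; and with num_workers > 0 each DataLoader worker process holds its own copy of the buffer, so the figure reflects per-worker load times rather than a global average.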