Record load times in fast_paired_dataset
This commit is contained in:
parent
65a546c4d7
commit
c24682c668
|
@ -2,6 +2,7 @@ import hashlib
|
|||
import os
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
from itertools import groupby
|
||||
|
||||
import torch
|
||||
|
@ -69,6 +70,9 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
|
|||
self.tokenizer = CharacterTokenizer()
|
||||
self.skipped_items = 0 # records how many items are skipped when accessing an index.
|
||||
|
||||
self.load_times = torch.zeros((256,))
|
||||
self.load_ind = 0
|
||||
|
||||
def get_wav_text_pair(self, audiopath_and_text):
|
||||
# separate filename and text
|
||||
audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
|
||||
|
@ -148,6 +152,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
|
|||
}
|
||||
|
||||
def __getitem__(self, index):
|
||||
start = time.time()
|
||||
self.skipped_items += 1
|
||||
apt = self.load_random_line()
|
||||
try:
|
||||
|
@ -185,6 +190,11 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
|
|||
aligned_codes = F.pad(aligned_codes, (0, self.max_aligned_codes-aligned_codes.shape[0]))
|
||||
if tseq.shape[0] != self.max_text_len:
|
||||
tseq = F.pad(tseq, (0, self.max_text_len - tseq.shape[0]))
|
||||
|
||||
elapsed = time.time() - start
|
||||
self.load_times[self.load_ind] = elapsed
|
||||
self.load_ind = (self.load_ind + 1) % len(self.load_times)
|
||||
|
||||
res = {
|
||||
'real_text': text,
|
||||
'padded_text': tseq,
|
||||
|
@ -195,12 +205,14 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
|
|||
'wav_lengths': torch.tensor(orig_output, dtype=torch.long),
|
||||
'filenames': path,
|
||||
'skipped_items': actually_skipped_items,
|
||||
'load_time': self.load_times.mean()
|
||||
}
|
||||
if self.load_conditioning:
|
||||
res['conditioning'] = cond
|
||||
res['conditioning_contains_self'] = cond_is_self
|
||||
if self.produce_ctc_metadata:
|
||||
res.update(self.get_ctc_metadata(raw_codes))
|
||||
|
||||
return res
|
||||
|
||||
def __len__(self):
|
||||
|
@ -213,6 +225,7 @@ class FastPairedVoiceDebugger:
|
|||
self.loaded_items = 0
|
||||
self.self_conditioning_items = 0
|
||||
self.unique_files = set()
|
||||
self.load_time = 0
|
||||
|
||||
def get_state(self):
|
||||
return {'total_items': self.total_items,
|
||||
|
@ -228,6 +241,7 @@ class FastPairedVoiceDebugger:
|
|||
def update(self, batch):
|
||||
self.total_items += batch['wav'].shape[0]
|
||||
self.loaded_items += batch['skipped_items'].sum().item()
|
||||
self.load_time = batch['load_time'].mean().item()
|
||||
for filename in batch['filenames']:
|
||||
self.unique_files.add(hashlib.sha256(filename.encode('utf-8')))
|
||||
if 'conditioning' in batch.keys():
|
||||
|
@ -238,12 +252,13 @@ class FastPairedVoiceDebugger:
|
|||
'total_samples_loaded': self.total_items,
|
||||
'percent_skipped_samples': (self.loaded_items - self.total_items) / self.loaded_items,
|
||||
'percent_conditioning_is_self': self.self_conditioning_items / self.loaded_items,
|
||||
'unique_files_loaded': len(self.unique_files)
|
||||
'unique_files_loaded': len(self.unique_files),
|
||||
'load_time': self.load_time,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
batch_sz = 16
|
||||
batch_sz = 256
|
||||
params = {
|
||||
'mode': 'fast_paired_voice_audio',
|
||||
'path': ['Y:\\libritts\\train-clean-360\\transcribed-w2v.tsv', 'Y:\\clips\\books1\\transcribed-w2v.tsv'],
|
||||
|
@ -255,8 +270,8 @@ if __name__ == '__main__':
|
|||
'sample_rate': 22050,
|
||||
'load_conditioning': True,
|
||||
'num_conditioning_candidates': 1,
|
||||
'conditioning_length': 44000,
|
||||
'use_bpe_tokenizer': False,
|
||||
'conditioning_length': 66000,
|
||||
'use_bpe_tokenizer': True,
|
||||
'load_aligned_codes': True,
|
||||
'produce_ctc_metadata': True,
|
||||
}
|
||||
|
@ -275,10 +290,11 @@ if __name__ == '__main__':
|
|||
max_pads, max_repeats = 0, 0
|
||||
for i, b in tqdm(enumerate(dl)):
|
||||
for ib in range(batch_sz):
|
||||
max_pads = max(max_pads, b['ctc_pads'].max())
|
||||
max_repeats = max(max_repeats, b['ctc_repeats'].max())
|
||||
print(f'{i} {ib} {b["real_text"][ib]}')
|
||||
#max_pads = max(max_pads, b['ctc_pads'].max())
|
||||
#max_repeats = max(max_repeats, b['ctc_repeats'].max())
|
||||
#print(f'{i} {ib} {b["real_text"][ib]}')
|
||||
#save(b, i, ib, 'wav')
|
||||
pass
|
||||
#if i > 5:
|
||||
# break
|
||||
print(max_pads, max_repeats)
|
||||
|
|
Loading…
Reference in New Issue
Block a user