Record load times in fast_paired_dataset
This commit is contained in:
parent
65a546c4d7
commit
c24682c668
|
@ -2,6 +2,7 @@ import hashlib
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
|
import time
|
||||||
from itertools import groupby
|
from itertools import groupby
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -69,6 +70,9 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
|
||||||
self.tokenizer = CharacterTokenizer()
|
self.tokenizer = CharacterTokenizer()
|
||||||
self.skipped_items = 0 # records how many items are skipped when accessing an index.
|
self.skipped_items = 0 # records how many items are skipped when accessing an index.
|
||||||
|
|
||||||
|
self.load_times = torch.zeros((256,))
|
||||||
|
self.load_ind = 0
|
||||||
|
|
||||||
def get_wav_text_pair(self, audiopath_and_text):
|
def get_wav_text_pair(self, audiopath_and_text):
|
||||||
# separate filename and text
|
# separate filename and text
|
||||||
audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
|
audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
|
||||||
|
@ -148,6 +152,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
|
||||||
}
|
}
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
|
start = time.time()
|
||||||
self.skipped_items += 1
|
self.skipped_items += 1
|
||||||
apt = self.load_random_line()
|
apt = self.load_random_line()
|
||||||
try:
|
try:
|
||||||
|
@ -185,6 +190,11 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
|
||||||
aligned_codes = F.pad(aligned_codes, (0, self.max_aligned_codes-aligned_codes.shape[0]))
|
aligned_codes = F.pad(aligned_codes, (0, self.max_aligned_codes-aligned_codes.shape[0]))
|
||||||
if tseq.shape[0] != self.max_text_len:
|
if tseq.shape[0] != self.max_text_len:
|
||||||
tseq = F.pad(tseq, (0, self.max_text_len - tseq.shape[0]))
|
tseq = F.pad(tseq, (0, self.max_text_len - tseq.shape[0]))
|
||||||
|
|
||||||
|
elapsed = time.time() - start
|
||||||
|
self.load_times[self.load_ind] = elapsed
|
||||||
|
self.load_ind = (self.load_ind + 1) % len(self.load_times)
|
||||||
|
|
||||||
res = {
|
res = {
|
||||||
'real_text': text,
|
'real_text': text,
|
||||||
'padded_text': tseq,
|
'padded_text': tseq,
|
||||||
|
@ -195,12 +205,14 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
|
||||||
'wav_lengths': torch.tensor(orig_output, dtype=torch.long),
|
'wav_lengths': torch.tensor(orig_output, dtype=torch.long),
|
||||||
'filenames': path,
|
'filenames': path,
|
||||||
'skipped_items': actually_skipped_items,
|
'skipped_items': actually_skipped_items,
|
||||||
|
'load_time': self.load_times.mean()
|
||||||
}
|
}
|
||||||
if self.load_conditioning:
|
if self.load_conditioning:
|
||||||
res['conditioning'] = cond
|
res['conditioning'] = cond
|
||||||
res['conditioning_contains_self'] = cond_is_self
|
res['conditioning_contains_self'] = cond_is_self
|
||||||
if self.produce_ctc_metadata:
|
if self.produce_ctc_metadata:
|
||||||
res.update(self.get_ctc_metadata(raw_codes))
|
res.update(self.get_ctc_metadata(raw_codes))
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
|
@ -213,6 +225,7 @@ class FastPairedVoiceDebugger:
|
||||||
self.loaded_items = 0
|
self.loaded_items = 0
|
||||||
self.self_conditioning_items = 0
|
self.self_conditioning_items = 0
|
||||||
self.unique_files = set()
|
self.unique_files = set()
|
||||||
|
self.load_time = 0
|
||||||
|
|
||||||
def get_state(self):
|
def get_state(self):
|
||||||
return {'total_items': self.total_items,
|
return {'total_items': self.total_items,
|
||||||
|
@ -228,6 +241,7 @@ class FastPairedVoiceDebugger:
|
||||||
def update(self, batch):
|
def update(self, batch):
|
||||||
self.total_items += batch['wav'].shape[0]
|
self.total_items += batch['wav'].shape[0]
|
||||||
self.loaded_items += batch['skipped_items'].sum().item()
|
self.loaded_items += batch['skipped_items'].sum().item()
|
||||||
|
self.load_time = batch['load_time'].mean().item()
|
||||||
for filename in batch['filenames']:
|
for filename in batch['filenames']:
|
||||||
self.unique_files.add(hashlib.sha256(filename.encode('utf-8')))
|
self.unique_files.add(hashlib.sha256(filename.encode('utf-8')))
|
||||||
if 'conditioning' in batch.keys():
|
if 'conditioning' in batch.keys():
|
||||||
|
@ -238,12 +252,13 @@ class FastPairedVoiceDebugger:
|
||||||
'total_samples_loaded': self.total_items,
|
'total_samples_loaded': self.total_items,
|
||||||
'percent_skipped_samples': (self.loaded_items - self.total_items) / self.loaded_items,
|
'percent_skipped_samples': (self.loaded_items - self.total_items) / self.loaded_items,
|
||||||
'percent_conditioning_is_self': self.self_conditioning_items / self.loaded_items,
|
'percent_conditioning_is_self': self.self_conditioning_items / self.loaded_items,
|
||||||
'unique_files_loaded': len(self.unique_files)
|
'unique_files_loaded': len(self.unique_files),
|
||||||
|
'load_time': self.load_time,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
batch_sz = 16
|
batch_sz = 256
|
||||||
params = {
|
params = {
|
||||||
'mode': 'fast_paired_voice_audio',
|
'mode': 'fast_paired_voice_audio',
|
||||||
'path': ['Y:\\libritts\\train-clean-360\\transcribed-w2v.tsv', 'Y:\\clips\\books1\\transcribed-w2v.tsv'],
|
'path': ['Y:\\libritts\\train-clean-360\\transcribed-w2v.tsv', 'Y:\\clips\\books1\\transcribed-w2v.tsv'],
|
||||||
|
@ -255,8 +270,8 @@ if __name__ == '__main__':
|
||||||
'sample_rate': 22050,
|
'sample_rate': 22050,
|
||||||
'load_conditioning': True,
|
'load_conditioning': True,
|
||||||
'num_conditioning_candidates': 1,
|
'num_conditioning_candidates': 1,
|
||||||
'conditioning_length': 44000,
|
'conditioning_length': 66000,
|
||||||
'use_bpe_tokenizer': False,
|
'use_bpe_tokenizer': True,
|
||||||
'load_aligned_codes': True,
|
'load_aligned_codes': True,
|
||||||
'produce_ctc_metadata': True,
|
'produce_ctc_metadata': True,
|
||||||
}
|
}
|
||||||
|
@ -275,10 +290,11 @@ if __name__ == '__main__':
|
||||||
max_pads, max_repeats = 0, 0
|
max_pads, max_repeats = 0, 0
|
||||||
for i, b in tqdm(enumerate(dl)):
|
for i, b in tqdm(enumerate(dl)):
|
||||||
for ib in range(batch_sz):
|
for ib in range(batch_sz):
|
||||||
max_pads = max(max_pads, b['ctc_pads'].max())
|
#max_pads = max(max_pads, b['ctc_pads'].max())
|
||||||
max_repeats = max(max_repeats, b['ctc_repeats'].max())
|
#max_repeats = max(max_repeats, b['ctc_repeats'].max())
|
||||||
print(f'{i} {ib} {b["real_text"][ib]}')
|
#print(f'{i} {ib} {b["real_text"][ib]}')
|
||||||
#save(b, i, ib, 'wav')
|
#save(b, i, ib, 'wav')
|
||||||
|
pass
|
||||||
#if i > 5:
|
#if i > 5:
|
||||||
# break
|
# break
|
||||||
print(max_pads, max_repeats)
|
print(max_pads, max_repeats)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user