Update fast_paired_dataset to report how many audio files it is actually using

This commit is contained in:
James Betker 2022-01-20 21:49:38 -07:00
parent ed35cfe393
commit 7fef7fb9ff
2 changed files with 41 additions and 1 deletions

View File

@ -106,7 +106,10 @@ def create_dataset(dataset_opt, return_collate=False):
def get_dataset_debugger(dataset_opt):
mode = dataset_opt['mode']
if mode == 'paired_voice_audio' or mode == 'fast_paired_voice_audio':
if mode == 'paired_voice_audio':
from data.audio.paired_voice_audio_dataset import PairedVoiceDebugger
return PairedVoiceDebugger()
elif mode == 'fast_paired_voice_audio':
from data.audio.fast_paired_dataset import FastPairedVoiceDebugger
return FastPairedVoiceDebugger()
return None

View File

@ -1,3 +1,4 @@
import hashlib
import os
import os
import random
@ -171,6 +172,42 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
return self.total_size_bytes // 1000 # 1000 cuts down a TSV file to the actual length pretty well.
class FastPairedVoiceDebugger:
def __init__(self):
self.total_items = 0
self.loaded_items = 0
self.self_conditioning_items = 0
self.unique_files = set()
def get_state(self):
return {'total_items': self.total_items,
'loaded_items': self.loaded_items,
'self_conditioning_items': self.self_conditioning_items,
'unique_files_loaded': self.unique_files}
def load_state(self, state):
if isinstance(state, dict):
self.total_items = opt_get(state, ['total_items'], 0)
self.loaded_items = opt_get(state, ['loaded_items'], 0)
self.self_conditioning_items = opt_get(state, ['self_conditioning_items'], 0)
def update(self, batch):
self.total_items += batch['wav'].shape[0]
self.loaded_items += batch['skipped_items'].sum().item()
for filename in batch['filenames']:
self.unique_files.add(hashlib.sha256(filename.encode('utf-8')))
if 'conditioning' in batch.keys():
self.self_conditioning_items += batch['conditioning_contains_self'].sum().item()
def get_debugging_map(self):
return {
'total_samples_loaded': self.total_items,
'percent_skipped_samples': (self.loaded_items - self.total_items) / self.loaded_items,
'percent_conditioning_is_self': self.self_conditioning_items / self.loaded_items,
'unique_files_loaded': len(self.unique_files)
}
if __name__ == '__main__':
batch_sz = 16
params = {