forked from mrq/DL-Art-School
Update fast_paired_dataset to report how many audio files it is actually using
This commit is contained in:
parent
ed35cfe393
commit
7fef7fb9ff
@@ -106,7 +106,10 @@ def create_dataset(dataset_opt, return_collate=False):
 def get_dataset_debugger(dataset_opt):
     mode = dataset_opt['mode']
-    if mode == 'paired_voice_audio' or mode == 'fast_paired_voice_audio':
+    if mode == 'paired_voice_audio':
         from data.audio.paired_voice_audio_dataset import PairedVoiceDebugger
         return PairedVoiceDebugger()
+    elif mode == 'fast_paired_voice_audio':
+        from data.audio.fast_paired_dataset import FastPairedVoiceDebugger
+        return FastPairedVoiceDebugger()
     return None
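For orientation, a small sketch (not part of the commit) of what the updated selector returns for each mode. The `from data import get_dataset_debugger` import path is an assumption, since this diff view does not name the file that holds create_dataset/get_dataset_debugger; everything else uses only what the hunk above defines.

    # Sketch only: assumes the repo's `data` package is importable and exposes get_dataset_debugger.
    from data import get_dataset_debugger

    for mode in ('paired_voice_audio', 'fast_paired_voice_audio', 'some_other_mode'):
        debugger = get_dataset_debugger({'mode': mode})
        print(mode, '->', type(debugger).__name__ if debugger else None)

    # Expected output:
    #   paired_voice_audio -> PairedVoiceDebugger
    #   fast_paired_voice_audio -> FastPairedVoiceDebugger
    #   some_other_mode -> None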
@@ -1,3 +1,4 @@
+import hashlib
 import os
 import os
 import random
@@ -171,6 +172,42 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
         return self.total_size_bytes // 1000  # 1000 cuts down a TSV file to the actual length pretty well.
 
 
+class FastPairedVoiceDebugger:
+    def __init__(self):
+        self.total_items = 0
+        self.loaded_items = 0
+        self.self_conditioning_items = 0
+        self.unique_files = set()
+
+    def get_state(self):
+        return {'total_items': self.total_items,
+                'loaded_items': self.loaded_items,
+                'self_conditioning_items': self.self_conditioning_items,
+                'unique_files_loaded': self.unique_files}
+
+    def load_state(self, state):
+        if isinstance(state, dict):
+            self.total_items = opt_get(state, ['total_items'], 0)
+            self.loaded_items = opt_get(state, ['loaded_items'], 0)
+            self.self_conditioning_items = opt_get(state, ['self_conditioning_items'], 0)
+
+    def update(self, batch):
+        self.total_items += batch['wav'].shape[0]
+        self.loaded_items += batch['skipped_items'].sum().item()
+        for filename in batch['filenames']:
+            self.unique_files.add(hashlib.sha256(filename.encode('utf-8')))
+        if 'conditioning' in batch.keys():
+            self.self_conditioning_items += batch['conditioning_contains_self'].sum().item()
+
+    def get_debugging_map(self):
+        return {
+            'total_samples_loaded': self.total_items,
+            'percent_skipped_samples': (self.loaded_items - self.total_items) / self.loaded_items,
+            'percent_conditioning_is_self': self.self_conditioning_items / self.loaded_items,
+            'unique_files_loaded': len(self.unique_files)
+        }
+
+
 if __name__ == '__main__':
     batch_sz = 16
     params = {
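Also not part of the commit: a hedged usage sketch that drives the new FastPairedVoiceDebugger with a fabricated batch. The dict keys mirror exactly what update() reads above ('wav', 'skipped_items', 'filenames', 'conditioning', 'conditioning_contains_self'); the shapes and values are invented, and it assumes torch is installed and the repo is on PYTHONPATH.

    # Sketch only: fabricated batch, just to show how the counters move.
    import torch
    from data.audio.fast_paired_dataset import FastPairedVoiceDebugger

    dbg = FastPairedVoiceDebugger()
    batch = {
        'wav': torch.zeros(2, 1, 22050),            # 2 clips in the batch
        'skipped_items': torch.tensor([1, 3]),      # items skipped while filling each sample
        'filenames': ['a.wav', 'b.wav'],
        'conditioning': torch.zeros(2, 1, 22050),
        'conditioning_contains_self': torch.tensor([1, 0]),
    }
    dbg.update(batch)
    print(dbg.get_debugging_map())
    # {'total_samples_loaded': 2, 'percent_skipped_samples': 0.5,
    #  'percent_conditioning_is_self': 0.25, 'unique_files_loaded': 2}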