From 7fef7fb9ffea89ed6253086c5179a623bb030008 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 20 Jan 2022 21:49:38 -0700
Subject: [PATCH] Update fast_paired_dataset to report how many audio files it is actually using

---
 codes/data/__init__.py                  |  5 +++-
 codes/data/audio/fast_paired_dataset.py | 37 +++++++++++++++++++++++++
 2 files changed, 41 insertions(+), 1 deletion(-)

diff --git a/codes/data/__init__.py b/codes/data/__init__.py
index 7531b87d..3d43e1c0 100644
--- a/codes/data/__init__.py
+++ b/codes/data/__init__.py
@@ -106,7 +106,10 @@ def create_dataset(dataset_opt, return_collate=False):
 
 def get_dataset_debugger(dataset_opt):
     mode = dataset_opt['mode']
-    if mode == 'paired_voice_audio' or mode == 'fast_paired_voice_audio':
+    if mode == 'paired_voice_audio':
         from data.audio.paired_voice_audio_dataset import PairedVoiceDebugger
         return PairedVoiceDebugger()
+    elif mode == 'fast_paired_voice_audio':
+        from data.audio.fast_paired_dataset import FastPairedVoiceDebugger
+        return FastPairedVoiceDebugger()
     return None
\ No newline at end of file
diff --git a/codes/data/audio/fast_paired_dataset.py b/codes/data/audio/fast_paired_dataset.py
index 34a62ebd..f125beef 100644
--- a/codes/data/audio/fast_paired_dataset.py
+++ b/codes/data/audio/fast_paired_dataset.py
@@ -1,3 +1,4 @@
+import hashlib
 import os
 import os
 import random
@@ -171,6 +172,42 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
         return self.total_size_bytes // 1000 # 1000 cuts down a TSV file to the actual length pretty well.
 
 
+class FastPairedVoiceDebugger:
+    def __init__(self):
+        self.total_items = 0
+        self.loaded_items = 0
+        self.self_conditioning_items = 0
+        self.unique_files = set()
+
+    def get_state(self):
+        return {'total_items': self.total_items,
+                'loaded_items': self.loaded_items,
+                'self_conditioning_items': self.self_conditioning_items,
+                'unique_files_loaded': self.unique_files}
+
+    def load_state(self, state):
+        if isinstance(state, dict):
+            self.total_items = opt_get(state, ['total_items'], 0)
+            self.loaded_items = opt_get(state, ['loaded_items'], 0)
+            self.self_conditioning_items = opt_get(state, ['self_conditioning_items'], 0)
+
+    def update(self, batch):
+        self.total_items += batch['wav'].shape[0]
+        self.loaded_items += batch['skipped_items'].sum().item()
+        for filename in batch['filenames']:
+            self.unique_files.add(hashlib.sha256(filename.encode('utf-8')))
+        if 'conditioning' in batch.keys():
+            self.self_conditioning_items += batch['conditioning_contains_self'].sum().item()
+
+    def get_debugging_map(self):
+        return {
+            'total_samples_loaded': self.total_items,
+            'percent_skipped_samples': (self.loaded_items - self.total_items) / self.loaded_items,
+            'percent_conditioning_is_self': self.self_conditioning_items / self.loaded_items,
+            'unique_files_loaded': len(self.unique_files)
+        }
+
+
 if __name__ == '__main__':
     batch_sz = 16
     params = {
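
Below is a minimal, self-contained sketch (not part of the patch) showing how the new FastPairedVoiceDebugger-style bookkeeping could be exercised with a synthetic batch. It only relies on the batch keys the patch itself reads ('wav', 'skipped_items', 'filenames', and optionally 'conditioning' / 'conditioning_contains_self') and uses torch to build dummy tensors; the stand-in class and its names are hypothetical. Unlike the patch, it stores sha256 hexdigest strings, so repeated filenames collapse to a single set entry (hashlib digest objects appear to compare by identity, which would count every occurrence as unique).

# Illustrative only -- not part of the patch. Assumes torch is installed and that
# batches carry the keys the debugger reads in the patch above.
import hashlib
import torch


class DebuggerSketch:
    """Hypothetical stand-in mirroring the debugger's bookkeeping, but keeping
    hexdigest strings so duplicate filenames deduplicate by content."""
    def __init__(self):
        self.total_items = 0
        self.loaded_items = 0
        self.self_conditioning_items = 0
        self.unique_files = set()

    def update(self, batch):
        # Count every sample in the batch.
        self.total_items += batch['wav'].shape[0]
        # Accumulate how many items the dataset reported skipping.
        self.loaded_items += batch['skipped_items'].sum().item()
        for filename in batch['filenames']:
            # hexdigest() yields a plain string, so the set deduplicates by content.
            self.unique_files.add(hashlib.sha256(filename.encode('utf-8')).hexdigest())
        if 'conditioning' in batch:
            self.self_conditioning_items += batch['conditioning_contains_self'].sum().item()

    def get_debugging_map(self):
        return {
            'total_samples_loaded': self.total_items,
            'unique_files_loaded': len(self.unique_files),
        }


if __name__ == '__main__':
    dbg = DebuggerSketch()
    # Two synthetic batches that repeat one filename.
    for names in (['a.wav', 'b.wav'], ['a.wav', 'c.wav']):
        dbg.update({
            'wav': torch.zeros(len(names), 1, 16000),
            'skipped_items': torch.zeros(len(names), dtype=torch.long),
            'filenames': names,
        })
    print(dbg.get_debugging_map())

Running the sketch prints a debugging map with total_samples_loaded == 4 and unique_files_loaded == 3, since 'a.wav' appears in both synthetic batches.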