forked from mrq/DL-Art-School
Update fast_paired_dataset to report how many audio files it is actually using
This commit is contained in:
parent
ed35cfe393
commit
7fef7fb9ff
|
@ -106,7 +106,10 @@ def create_dataset(dataset_opt, return_collate=False):
|
||||||
|
|
||||||
def get_dataset_debugger(dataset_opt):
    """Return the debugger object matching the dataset's configured mode.

    Args:
        dataset_opt: Dataset options mapping; its 'mode' entry selects the debugger.

    Returns:
        A PairedVoiceDebugger for 'paired_voice_audio', a FastPairedVoiceDebugger
        for 'fast_paired_voice_audio', or None for any other mode.

    Raises:
        KeyError: If dataset_opt has no 'mode' entry.
    """
    mode = dataset_opt['mode']
    # Imports are deferred so only the dataset module actually selected is loaded.
    if mode == 'paired_voice_audio':
        from data.audio.paired_voice_audio_dataset import PairedVoiceDebugger
        return PairedVoiceDebugger()
    elif mode == 'fast_paired_voice_audio':
        from data.audio.fast_paired_dataset import FastPairedVoiceDebugger
        return FastPairedVoiceDebugger()
    return None
|
|
@ -1,3 +1,4 @@
|
||||||
|
import hashlib
|
||||||
import os
|
import os
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
@ -171,6 +172,42 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
|
||||||
return self.total_size_bytes // 1000 # 1000 cuts down a TSV file to the actual length pretty well.
|
return self.total_size_bytes // 1000 # 1000 cuts down a TSV file to the actual length pretty well.
|
||||||
|
|
||||||
|
|
||||||
|
class FastPairedVoiceDebugger:
    """Accumulates running statistics over the batches passed to update().

    Tracked values:
      - total_items: sum of batch['wav'].shape[0] over all batches.
      - loaded_items: sum of batch['skipped_items'] over all batches.
        NOTE(review): presumably this counts load attempts including skipped
        ones, which is what makes the 'percent_skipped_samples' formula
        meaningful -- confirm against the dataset that fills 'skipped_items'.
      - self_conditioning_items: sum of batch['conditioning_contains_self'],
        only for batches that carry a 'conditioning' entry.
      - unique_files: SHA-256 hex digests of every filename seen.
    """

    def __init__(self):
        self.total_items = 0
        self.loaded_items = 0
        self.self_conditioning_items = 0
        # Hex digest strings, not hashlib objects: hash objects compare by
        # identity (so they never de-duplicate) and cannot be pickled.
        self.unique_files = set()

    def get_state(self):
        """Return a dict snapshot of all counters and the unique-file digests."""
        return {'total_items': self.total_items,
                'loaded_items': self.loaded_items,
                'self_conditioning_items': self.self_conditioning_items,
                # Copy so later update() calls do not mutate a saved state.
                'unique_files_loaded': set(self.unique_files)}

    def load_state(self, state):
        """Restore counters from a dict produced by get_state().

        Non-dict values are ignored. Missing counters fall back to 0; a missing
        'unique_files_loaded' entry leaves the current set untouched.
        """
        if isinstance(state, dict):
            self.total_items = opt_get(state, ['total_items'], 0)
            self.loaded_items = opt_get(state, ['loaded_items'], 0)
            self.self_conditioning_items = opt_get(state, ['self_conditioning_items'], 0)
            saved_files = opt_get(state, ['unique_files_loaded'], None)
            if saved_files is not None:
                self.unique_files = set(saved_files)

    def update(self, batch):
        """Fold one batch into the running statistics.

        Reads batch['wav'], batch['skipped_items'] and batch['filenames'];
        batch['conditioning_contains_self'] is read only when the batch has a
        'conditioning' key.
        """
        self.total_items += batch['wav'].shape[0]
        self.loaded_items += batch['skipped_items'].sum().item()
        self.unique_files.update(
            hashlib.sha256(filename.encode('utf-8')).hexdigest()
            for filename in batch['filenames']
        )
        if 'conditioning' in batch:
            self.self_conditioning_items += batch['conditioning_contains_self'].sum().item()

    def get_debugging_map(self):
        """Return the summary metrics as a dict.

        Both ratios are reported as 0.0 while loaded_items is still zero
        (e.g. before the first update) instead of raising ZeroDivisionError.
        """
        loaded = self.loaded_items
        if loaded:
            percent_skipped = (loaded - self.total_items) / loaded
            percent_self = self.self_conditioning_items / loaded
        else:
            percent_skipped = 0.0
            percent_self = 0.0
        return {
            'total_samples_loaded': self.total_items,
            'percent_skipped_samples': percent_skipped,
            'percent_conditioning_is_self': percent_self,
            'unique_files_loaded': len(self.unique_files),
        }
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
batch_sz = 16
|
batch_sz = 16
|
||||||
params = {
|
params = {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user