From 93e903af15d507b5ebd78923c4fe13268df992ee Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 16 Aug 2021 22:52:35 -0600 Subject: [PATCH] Rework wavfile dataset to be usable for things other than augments --- codes/data/audio/wavfile_dataset.py | 34 +++++++++++++++++++---------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/codes/data/audio/wavfile_dataset.py b/codes/data/audio/wavfile_dataset.py index 4310494c..707ef136 100644 --- a/codes/data/audio/wavfile_dataset.py +++ b/codes/data/audio/wavfile_dataset.py @@ -30,13 +30,14 @@ class WavfileDataset(torch.utils.data.Dataset): # Parse options self.sampling_rate = opt_get(opt, ['sampling_rate'], 24000) - self.augment = opt_get(opt, ['do_augmentation'], False) self.pad_to = opt_get(opt, ['pad_to_seconds'], None) if self.pad_to is not None: self.pad_to *= self.sampling_rate - self.window = 2 * self.sampling_rate + self.augment = opt_get(opt, ['do_augmentation'], False) if self.augment: + # The "window size" for the clips produced in seconds. + self.window = 2 * self.sampling_rate self.augmentor = WavAugmentor() def get_audio_for_index(self, index): @@ -57,15 +58,21 @@ class WavfileDataset(torch.utils.data.Dataset): return audio, audiopath def __getitem__(self, index): - clip1, clip2 = None, None + success = False + # This "success" thing is a hack: This dataset is randomly failing for no apparent good reason and I don't know why. + # Symptoms are it complaining about being unable to read a nonsensical filename that is clearly corrupted. Memory corruption? I don't know.. + while not success: + try: + # Split audio_norm into two tensors of equal size. + audio_norm, filename = self.get_audio_for_index(index) + success = True + except: + print(f"Failed to load {index} {self.audiopaths[index]}") - while clip1 is None and clip2 is None: - # Split audio_norm into two tensors of equal size. - audio_norm, filename = self.get_audio_for_index(index) + if self.augment: if audio_norm.shape[1] < self.window * 2: # Try next index. This adds a bit of bias and ideally we'd filter the dataset rather than do this. - index = (index + 1) % len(self) - continue + return self[(index + 1) % len(self)] j = random.randint(0, audio_norm.shape[1] - self.window) clip1 = audio_norm[:, j:j+self.window] if self.augment: @@ -83,12 +90,16 @@ class WavfileDataset(torch.utils.data.Dataset): #print(f"Warning! Truncating clip {filename} from {audio_norm.shape[-1]} to {self.pad_to}") audio_norm = audio_norm[:, :self.pad_to] - return { + output = { 'clip': audio_norm, - 'clip1': clip1[0, :].unsqueeze(0), - 'clip2': clip2[0, :].unsqueeze(0), 'path': filename, } + if self.augment: + output.update({ + 'clip1': clip1[0, :].unsqueeze(0), + 'clip2': clip2[0, :].unsqueeze(0), + }) + return output def __len__(self): return len(self.audiopaths) @@ -104,6 +115,7 @@ if __name__ == '__main__': 'phase': 'train', 'n_workers': 0, 'batch_size': 16, + 'do_augmentation': False } from data import create_dataset, create_dataloader, util