From d6007c6de1aa6150069831891fd36cec11d7e234 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 5 Aug 2021 23:12:59 -0600 Subject: [PATCH] dataset fixes --- codes/data/audio/wav_aug.py | 2 +- codes/data/audio/wavfile_dataset.py | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/codes/data/audio/wav_aug.py b/codes/data/audio/wav_aug.py index a6f4b7b2..58e0f2f7 100644 --- a/codes/data/audio/wav_aug.py +++ b/codes/data/audio/wav_aug.py @@ -53,7 +53,7 @@ class WavAugmentor: if __name__ == '__main__': sample, _ = load_wav_to_torch('obama1.wav') - sample = sample.permute(1,0) / 32768.0 + sample = sample / 32768.0 aug = WavAugmentor() for j in range(10): out = aug.augment(sample, 24000) diff --git a/codes/data/audio/wavfile_dataset.py b/codes/data/audio/wavfile_dataset.py index 2388876a..9917f83f 100644 --- a/codes/data/audio/wavfile_dataset.py +++ b/codes/data/audio/wavfile_dataset.py @@ -40,7 +40,7 @@ class WavfileDataset(torch.utils.data.Dataset): if sampling_rate != self.sampling_rate: raise ValueError(f"Input sampling rate does not match specified rate {self.sampling_rate}") audio_norm = audio / self.max_wav_value - audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False) + audio_norm = audio_norm.unsqueeze(0) return audio_norm, audiopath def __getitem__(self, index): @@ -49,22 +49,22 @@ class WavfileDataset(torch.utils.data.Dataset): while clip1 is None and clip2 is None: # Split audio_norm into two tensors of equal size. audio_norm, filename = self.get_audio_for_index(index) - if audio_norm.shape[0] < self.window * 2: + if audio_norm.shape[1] < self.window * 2: # Try next index. This adds a bit of bias and ideally we'd filter the dataset rather than do this. index = (index + 1) % len(self) continue - j = random.randint(0, audio_norm.shape[0] - self.window) - clip1 = audio_norm[j:j+self.window] + j = random.randint(0, audio_norm.shape[1] - self.window) + clip1 = audio_norm[:, j:j+self.window] if self.augment: clip1 = self.augmentor.augment(clip1, self.sampling_rate) - j = random.randint(0, audio_norm.shape[0]-self.window) - clip2 = audio_norm[j:j+self.window] + j = random.randint(0, audio_norm.shape[1]-self.window) + clip2 = audio_norm[:, j:j+self.window] if self.augment: clip2 = self.augmentor.augment(clip2, self.sampling_rate) return { - 'clip1': clip1.unsqueeze(0), - 'clip2': clip2.unsqueeze(0), + 'clip1': clip1, + 'clip2': clip2, 'path': filename, }