From d6007c6de1aa6150069831891fd36cec11d7e234 Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Thu, 5 Aug 2021 23:12:59 -0600
Subject: [PATCH] dataset fixes

---
 codes/data/audio/wav_aug.py         |  2 +-
 codes/data/audio/wavfile_dataset.py | 16 ++++++++--------
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/codes/data/audio/wav_aug.py b/codes/data/audio/wav_aug.py
index a6f4b7b2..58e0f2f7 100644
--- a/codes/data/audio/wav_aug.py
+++ b/codes/data/audio/wav_aug.py
@@ -53,7 +53,7 @@ class WavAugmentor:
 
 if __name__ == '__main__':
     sample, _ = load_wav_to_torch('obama1.wav')
-    sample = sample.permute(1,0) / 32768.0
+    sample = sample / 32768.0
     aug = WavAugmentor()
     for j in range(10):
         out = aug.augment(sample, 24000)
diff --git a/codes/data/audio/wavfile_dataset.py b/codes/data/audio/wavfile_dataset.py
index 2388876a..9917f83f 100644
--- a/codes/data/audio/wavfile_dataset.py
+++ b/codes/data/audio/wavfile_dataset.py
@@ -40,7 +40,7 @@ class WavfileDataset(torch.utils.data.Dataset):
         if sampling_rate != self.sampling_rate:
             raise ValueError(f"Input sampling rate does not match specified rate {self.sampling_rate}")
         audio_norm = audio / self.max_wav_value
-        audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
+        audio_norm = audio_norm.unsqueeze(0)
         return audio_norm, audiopath
 
     def __getitem__(self, index):
@@ -49,22 +49,22 @@ class WavfileDataset(torch.utils.data.Dataset):
         while clip1 is None and clip2 is None:
             # Split audio_norm into two tensors of equal size.
             audio_norm, filename = self.get_audio_for_index(index)
-            if audio_norm.shape[0] < self.window * 2:
+            if audio_norm.shape[1] < self.window * 2:
                 # Try next index. This adds a bit of bias and ideally we'd filter the dataset rather than do this.
                 index = (index + 1) % len(self)
                 continue
-            j = random.randint(0, audio_norm.shape[0] - self.window)
-            clip1 = audio_norm[j:j+self.window]
+            j = random.randint(0, audio_norm.shape[1] - self.window)
+            clip1 = audio_norm[:, j:j+self.window]
             if self.augment:
                 clip1 = self.augmentor.augment(clip1, self.sampling_rate)
-            j = random.randint(0, audio_norm.shape[0]-self.window)
-            clip2 = audio_norm[j:j+self.window]
+            j = random.randint(0, audio_norm.shape[1]-self.window)
+            clip2 = audio_norm[:, j:j+self.window]
             if self.augment:
                 clip2 = self.augmentor.augment(clip2, self.sampling_rate)
 
         return {
-            'clip1': clip1.unsqueeze(0),
-            'clip2': clip2.unsqueeze(0),
+            'clip1': clip1,
+            'clip2': clip2,
             'path': filename,
         }