dataset fixes

James Betker 2021-08-05 23:12:59 -06:00
parent 3ca51e80b2
commit d6007c6de1
2 changed files with 9 additions and 9 deletions


@@ -53,7 +53,7 @@ class WavAugmentor:
 if __name__ == '__main__':
     sample, _ = load_wav_to_torch('obama1.wav')
-    sample = sample.permute(1,0) / 32768.0
+    sample = sample / 32768.0
     aug = WavAugmentor()
     for j in range(10):
         out = aug.augment(sample, 24000)
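
Note: the dropped permute suggests load_wav_to_torch now returns audio in a layout that no longer needs transposing before scaling; dividing by 32768.0 maps 16-bit PCM into [-1.0, 1.0). A minimal sketch of that normalization, assuming a mono 16-bit WAV and emulating the repo's load_wav_to_torch with scipy (an assumption, not this repo's code):

import torch
from scipy.io import wavfile

def load_wav_to_torch_sketch(path):
    # scipy returns (rate, data); data is an int16 ndarray of shape (samples,) for mono
    sampling_rate, data = wavfile.read(path)
    return torch.FloatTensor(data.astype('float32')), sampling_rate

sample, sr = load_wav_to_torch_sketch('obama1.wav')
sample = sample / 32768.0   # int16 range [-32768, 32767] -> floats in [-1.0, 1.0)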


@@ -40,7 +40,7 @@ class WavfileDataset(torch.utils.data.Dataset):
         if sampling_rate != self.sampling_rate:
             raise ValueError(f"Input sampling rate does not match specified rate {self.sampling_rate}")
         audio_norm = audio / self.max_wav_value
-        audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
+        audio_norm = audio_norm.unsqueeze(0)
         return audio_norm, audiopath

     def __getitem__(self, index):
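
Note: torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4, so the replacement loses nothing; unsqueeze(0) instead gives the audio the [channels, samples] layout the crop logic below indexes. A quick shape check (hypothetical values, not repo code):

import torch

audio_norm = torch.randn(24000)       # 1 s of mono audio at 24 kHz, shape [24000]
audio_norm = audio_norm.unsqueeze(0)  # -> [1, 24000], channels-first
assert audio_norm.shape == (1, 24000)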
@@ -49,22 +49,22 @@ class WavfileDataset(torch.utils.data.Dataset):
         while clip1 is None and clip2 is None:
             # Split audio_norm into two tensors of equal size.
             audio_norm, filename = self.get_audio_for_index(index)
-            if audio_norm.shape[0] < self.window * 2:
+            if audio_norm.shape[1] < self.window * 2:
                 # Try next index. This adds a bit of bias and ideally we'd filter the dataset rather than do this.
                 index = (index + 1) % len(self)
                 continue
-            j = random.randint(0, audio_norm.shape[0] - self.window)
-            clip1 = audio_norm[j:j+self.window]
+            j = random.randint(0, audio_norm.shape[1] - self.window)
+            clip1 = audio_norm[:, j:j+self.window]
             if self.augment:
                 clip1 = self.augmentor.augment(clip1, self.sampling_rate)
-            j = random.randint(0, audio_norm.shape[0]-self.window)
-            clip2 = audio_norm[j:j+self.window]
+            j = random.randint(0, audio_norm.shape[1]-self.window)
+            clip2 = audio_norm[:, j:j+self.window]
             if self.augment:
                 clip2 = self.augmentor.augment(clip2, self.sampling_rate)
         return {
-            'clip1': clip1.unsqueeze(0),
-            'clip2': clip2.unsqueeze(0),
+            'clip1': clip1,
+            'clip2': clip2,
             'path': filename,
         }
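
Note: with audio_norm shaped [1, samples], the sample axis is dim 1, so the length check and both slices index that axis, and each clip already carries its channel dimension, which is why the trailing unsqueeze(0) calls in the returned dict are dropped. A condensed sketch of the fixed crop logic (window size and tensor values hypothetical):

import random
import torch

window = 8192
audio_norm = torch.randn(1, 48000)           # [channels, samples]
assert audio_norm.shape[1] >= window * 2     # enough samples for two clips
j = random.randint(0, audio_norm.shape[1] - window)
clip1 = audio_norm[:, j:j + window]          # slice the sample axis; keeps [1, window]
j = random.randint(0, audio_norm.shape[1] - window)
clip2 = audio_norm[:, j:j + window]
assert clip1.shape == (1, window) and clip2.shape == (1, window)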