dataset fixes
This commit is contained in:
parent
3ca51e80b2
commit
d6007c6de1
|
@ -53,7 +53,7 @@ class WavAugmentor:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
sample, _ = load_wav_to_torch('obama1.wav')
|
sample, _ = load_wav_to_torch('obama1.wav')
|
||||||
sample = sample.permute(1,0) / 32768.0
|
sample = sample / 32768.0
|
||||||
aug = WavAugmentor()
|
aug = WavAugmentor()
|
||||||
for j in range(10):
|
for j in range(10):
|
||||||
out = aug.augment(sample, 24000)
|
out = aug.augment(sample, 24000)
|
||||||
|
|
|
@ -40,7 +40,7 @@ class WavfileDataset(torch.utils.data.Dataset):
|
||||||
if sampling_rate != self.sampling_rate:
|
if sampling_rate != self.sampling_rate:
|
||||||
raise ValueError(f"Input sampling rate does not match specified rate {self.sampling_rate}")
|
raise ValueError(f"Input sampling rate does not match specified rate {self.sampling_rate}")
|
||||||
audio_norm = audio / self.max_wav_value
|
audio_norm = audio / self.max_wav_value
|
||||||
audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
|
audio_norm = audio_norm.unsqueeze(0)
|
||||||
return audio_norm, audiopath
|
return audio_norm, audiopath
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
|
@ -49,22 +49,22 @@ class WavfileDataset(torch.utils.data.Dataset):
|
||||||
while clip1 is None and clip2 is None:
|
while clip1 is None and clip2 is None:
|
||||||
# Split audio_norm into two tensors of equal size.
|
# Split audio_norm into two tensors of equal size.
|
||||||
audio_norm, filename = self.get_audio_for_index(index)
|
audio_norm, filename = self.get_audio_for_index(index)
|
||||||
if audio_norm.shape[0] < self.window * 2:
|
if audio_norm.shape[1] < self.window * 2:
|
||||||
# Try next index. This adds a bit of bias and ideally we'd filter the dataset rather than do this.
|
# Try next index. This adds a bit of bias and ideally we'd filter the dataset rather than do this.
|
||||||
index = (index + 1) % len(self)
|
index = (index + 1) % len(self)
|
||||||
continue
|
continue
|
||||||
j = random.randint(0, audio_norm.shape[0] - self.window)
|
j = random.randint(0, audio_norm.shape[1] - self.window)
|
||||||
clip1 = audio_norm[j:j+self.window]
|
clip1 = audio_norm[:, j:j+self.window]
|
||||||
if self.augment:
|
if self.augment:
|
||||||
clip1 = self.augmentor.augment(clip1, self.sampling_rate)
|
clip1 = self.augmentor.augment(clip1, self.sampling_rate)
|
||||||
j = random.randint(0, audio_norm.shape[0]-self.window)
|
j = random.randint(0, audio_norm.shape[1]-self.window)
|
||||||
clip2 = audio_norm[j:j+self.window]
|
clip2 = audio_norm[:, j:j+self.window]
|
||||||
if self.augment:
|
if self.augment:
|
||||||
clip2 = self.augmentor.augment(clip2, self.sampling_rate)
|
clip2 = self.augmentor.augment(clip2, self.sampling_rate)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'clip1': clip1.unsqueeze(0),
|
'clip1': clip1,
|
||||||
'clip2': clip2.unsqueeze(0),
|
'clip2': clip2,
|
||||||
'path': filename,
|
'path': filename,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user