Rework wavfile dataset to be usable for things other than augments
parent d7f30232c3
commit 93e903af15
@@ -30,13 +30,14 @@ class WavfileDataset(torch.utils.data.Dataset):
         # Parse options
         self.sampling_rate = opt_get(opt, ['sampling_rate'], 24000)
-        self.augment = opt_get(opt, ['do_augmentation'], False)
         self.pad_to = opt_get(opt, ['pad_to_seconds'], None)
         if self.pad_to is not None:
             self.pad_to *= self.sampling_rate
 
-        self.window = 2 * self.sampling_rate
+        self.augment = opt_get(opt, ['do_augmentation'], False)
         if self.augment:
+            # The "window size" for the produced clips: 2 seconds, expressed in samples.
+            self.window = 2 * self.sampling_rate
             self.augmentor = WavAugmentor()
 
     def get_audio_for_index(self, index):
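
The option parsing above is driven by a plain config dict read through opt_get with per-key defaults. As a point of reference, here is a minimal sketch of what such a dict might look like for the new non-augmenting path; the key names come from the hunk above, while the 'path' entry and the constructor signature are assumptions for illustration only:

    # Hypothetical options dict; only the keys shown in the diff are confirmed.
    opt = {
        'path': ['/data/audio/clips'],  # assumed option; the real set depends on the rest of __init__
        'sampling_rate': 24000,         # opt_get falls back to 24000 when this key is absent
        'pad_to_seconds': 5,            # converted to samples in __init__: 5 * 24000 = 120000
        'do_augmentation': False,       # skips the window/WavAugmentor setup entirely
    }
    dataset = WavfileDataset(opt)       # assumes __init__ takes the opt dict, as the parsing code suggests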
@@ -57,15 +58,21 @@ class WavfileDataset(torch.utils.data.Dataset):
         return audio, audiopath
 
     def __getitem__(self, index):
-        clip1, clip2 = None, None
-
-        while clip1 is None and clip2 is None:
-            # Split audio_norm into two tensors of equal size.
-            audio_norm, filename = self.get_audio_for_index(index)
+        success = False
+        # This "success" loop is a hack: this dataset randomly fails for no apparent reason and I don't know why.
+        # The symptom is a complaint about being unable to read a nonsensical, clearly corrupted filename. Memory corruption? I don't know.
+        while not success:
+            try:
+                # Split audio_norm into two tensors of equal size.
+                audio_norm, filename = self.get_audio_for_index(index)
+                success = True
+            except:
+                print(f"Failed to load {index} {self.audiopaths[index]}")
+
+        if self.augment:
             if audio_norm.shape[1] < self.window * 2:
                 # Try next index. This adds a bit of bias and ideally we'd filter the dataset rather than do this.
-                index = (index + 1) % len(self)
-                continue
+                return self[(index + 1) % len(self)]
             j = random.randint(0, audio_norm.shape[1] - self.window)
             clip1 = audio_norm[:, j:j+self.window]
             if self.augment:
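
One caveat with the retry hack above: `while not success` will spin forever if an index is permanently unreadable. Below is a bounded variant, not part of this commit, sketched purely to illustrate how the same recovery could fail loudly instead; the retry limit of 3 is an arbitrary assumption:

    # Illustrative alternative only; the commit itself uses an unbounded loop with a bare except.
    for attempt in range(3):
        try:
            audio_norm, filename = self.get_audio_for_index(index)
            break
        except Exception:  # narrower than a bare except:, which would also swallow KeyboardInterrupt
            print(f"Failed to load {index} {self.audiopaths[index]} (attempt {attempt + 1})")
    else:
        raise RuntimeError(f"Gave up loading index {index} after 3 attempts")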
@@ -83,12 +90,16 @@ class WavfileDataset(torch.utils.data.Dataset):
             #print(f"Warning! Truncating clip {filename} from {audio_norm.shape[-1]} to {self.pad_to}")
             audio_norm = audio_norm[:, :self.pad_to]
 
-        return {
+        output = {
             'clip': audio_norm,
-            'clip1': clip1[0, :].unsqueeze(0),
-            'clip2': clip2[0, :].unsqueeze(0),
             'path': filename,
         }
+        if self.augment:
+            output.update({
+                'clip1': clip1[0, :].unsqueeze(0),
+                'clip2': clip2[0, :].unsqueeze(0),
+            })
+        return output
 
     def __len__(self):
         return len(self.audiopaths)
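
After this hunk the returned dict has two shapes: 'clip' and 'path' are always present, while the windowed 'clip1'/'clip2' sub-clips appear only when do_augmentation is set. A short sketch of consumer code that tolerates both shapes (everything beyond the keys shown in the diff is assumed):

    item = dataset[0]
    full_clip = item['clip']                   # always present: (channels, samples)
    if 'clip1' in item:                        # only when do_augmentation=True
        a, b = item['clip1'], item['clip2']    # two single-channel windowed views for the augmentor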
@@ -104,6 +115,7 @@ if __name__ == '__main__':
         'phase': 'train',
         'n_workers': 0,
         'batch_size': 16,
+        'do_augmentation': False
     }
     from data import create_dataset, create_dataloader, util
 
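
With 'do_augmentation': False added to the test options, the __main__ harness now exercises the non-augmenting path. A rough sketch of how the imported helpers would plausibly be wired up; the create_* signatures are assumptions based on the factory-style imports, not something this diff confirms:

    # Assumed usage of the imported factories; exact signatures may differ.
    ds = create_dataset(opt)           # builds WavfileDataset from the opt dict
    dl = create_dataloader(ds, opt)    # wraps it using batch_size/n_workers from opt
    for batch in dl:
        print(batch['clip'].shape)     # e.g. torch.Size([16, 1, N]) for batch_size 16
        break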