Rework wavfile dataset to be usable for things other than augments

This commit is contained in:
James Betker 2021-08-16 22:52:35 -06:00
parent d7f30232c3
commit 93e903af15

View File

@ -30,13 +30,14 @@ class WavfileDataset(torch.utils.data.Dataset):
# Parse options
self.sampling_rate = opt_get(opt, ['sampling_rate'], 24000)
self.augment = opt_get(opt, ['do_augmentation'], False)
self.pad_to = opt_get(opt, ['pad_to_seconds'], None)
if self.pad_to is not None:
self.pad_to *= self.sampling_rate
self.window = 2 * self.sampling_rate
self.augment = opt_get(opt, ['do_augmentation'], False)
if self.augment:
# The "window size" for the clips produced: 2 seconds of audio, expressed in samples.
self.window = 2 * self.sampling_rate
self.augmentor = WavAugmentor()
def get_audio_for_index(self, index):
@ -57,15 +58,21 @@ class WavfileDataset(torch.utils.data.Dataset):
return audio, audiopath
def __getitem__(self, index):
clip1, clip2 = None, None
success = False
# This "success" thing is a hack: This dataset is randomly failing for no apparent good reason and I don't know why.
# Symptoms are it complaining about being unable to read a nonsensical filename that is clearly corrupted. Memory corruption? I don't know..
while not success:
try:
# Split audio_norm into two tensors of equal size.
audio_norm, filename = self.get_audio_for_index(index)
success = True
except:
print(f"Failed to load {index} {self.audiopaths[index]}")
while clip1 is None and clip2 is None:
# Split audio_norm into two tensors of equal size.
audio_norm, filename = self.get_audio_for_index(index)
if self.augment:
if audio_norm.shape[1] < self.window * 2:
# Try next index. This adds a bit of bias and ideally we'd filter the dataset rather than do this.
index = (index + 1) % len(self)
continue
return self[(index + 1) % len(self)]
j = random.randint(0, audio_norm.shape[1] - self.window)
clip1 = audio_norm[:, j:j+self.window]
if self.augment:
@ -83,12 +90,16 @@ class WavfileDataset(torch.utils.data.Dataset):
#print(f"Warning! Truncating clip {filename} from {audio_norm.shape[-1]} to {self.pad_to}")
audio_norm = audio_norm[:, :self.pad_to]
return {
output = {
'clip': audio_norm,
'clip1': clip1[0, :].unsqueeze(0),
'clip2': clip2[0, :].unsqueeze(0),
'path': filename,
}
if self.augment:
output.update({
'clip1': clip1[0, :].unsqueeze(0),
'clip2': clip2[0, :].unsqueeze(0),
})
return output
def __len__(self):
    """Return the number of entries in ``self.audiopaths``."""
    return len(self.audiopaths)
@ -104,6 +115,7 @@ if __name__ == '__main__':
'phase': 'train',
'n_workers': 0,
'batch_size': 16,
'do_augmentation': False
}
from data import create_dataset, create_dataloader, util