forked from mrq/DL-Art-School
Rework wavfile dataset to be usable for things other than augments
This commit is contained in:
parent
d7f30232c3
commit
93e903af15
|
@ -30,13 +30,14 @@ class WavfileDataset(torch.utils.data.Dataset):
|
|||
|
||||
# Parse options
|
||||
self.sampling_rate = opt_get(opt, ['sampling_rate'], 24000)
|
||||
self.augment = opt_get(opt, ['do_augmentation'], False)
|
||||
self.pad_to = opt_get(opt, ['pad_to_seconds'], None)
|
||||
if self.pad_to is not None:
|
||||
self.pad_to *= self.sampling_rate
|
||||
|
||||
self.window = 2 * self.sampling_rate
|
||||
self.augment = opt_get(opt, ['do_augmentation'], False)
|
||||
if self.augment:
|
||||
# The "window size" for the clips produced, in samples (2 seconds of audio at the configured sampling rate).
|
||||
self.window = 2 * self.sampling_rate
|
||||
self.augmentor = WavAugmentor()
|
||||
|
||||
def get_audio_for_index(self, index):
|
||||
|
@ -57,15 +58,21 @@ class WavfileDataset(torch.utils.data.Dataset):
|
|||
return audio, audiopath
|
||||
|
||||
def __getitem__(self, index):
|
||||
clip1, clip2 = None, None
|
||||
success = False
|
||||
# HACK: the "success" flag retries the load, because this dataset randomly fails for no apparent reason and the cause is unknown.
|
||||
# The symptom is a complaint about being unable to read a nonsensical filename that is clearly corrupted. Possibly memory corruption; the cause is unknown.
|
||||
while not success:
|
||||
try:
|
||||
# Split audio_norm into two tensors of equal size.
|
||||
audio_norm, filename = self.get_audio_for_index(index)
|
||||
success = True
|
||||
except:
|
||||
print(f"Failed to load {index} {self.audiopaths[index]}")
|
||||
|
||||
while clip1 is None and clip2 is None:
|
||||
# Split audio_norm into two tensors of equal size.
|
||||
audio_norm, filename = self.get_audio_for_index(index)
|
||||
if self.augment:
|
||||
if audio_norm.shape[1] < self.window * 2:
|
||||
# Try next index. This adds a bit of bias and ideally we'd filter the dataset rather than do this.
|
||||
index = (index + 1) % len(self)
|
||||
continue
|
||||
return self[(index + 1) % len(self)]
|
||||
j = random.randint(0, audio_norm.shape[1] - self.window)
|
||||
clip1 = audio_norm[:, j:j+self.window]
|
||||
if self.augment:
|
||||
|
@ -83,12 +90,16 @@ class WavfileDataset(torch.utils.data.Dataset):
|
|||
#print(f"Warning! Truncating clip {filename} from {audio_norm.shape[-1]} to {self.pad_to}")
|
||||
audio_norm = audio_norm[:, :self.pad_to]
|
||||
|
||||
return {
|
||||
output = {
|
||||
'clip': audio_norm,
|
||||
'clip1': clip1[0, :].unsqueeze(0),
|
||||
'clip2': clip2[0, :].unsqueeze(0),
|
||||
'path': filename,
|
||||
}
|
||||
if self.augment:
|
||||
output.update({
|
||||
'clip1': clip1[0, :].unsqueeze(0),
|
||||
'clip2': clip2[0, :].unsqueeze(0),
|
||||
})
|
||||
return output
|
||||
|
||||
def __len__(self):
|
||||
return len(self.audiopaths)
|
||||
|
@ -104,6 +115,7 @@ if __name__ == '__main__':
|
|||
'phase': 'train',
|
||||
'n_workers': 0,
|
||||
'batch_size': 16,
|
||||
'do_augmentation': False
|
||||
}
|
||||
from data import create_dataset, create_dataloader, util
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user