diff --git a/codes/data/audio/wavfile_dataset.py b/codes/data/audio/wavfile_dataset.py
index 17abe266..1c7e8199 100644
--- a/codes/data/audio/wavfile_dataset.py
+++ b/codes/data/audio/wavfile_dataset.py
@@ -33,7 +33,7 @@ class WavfileDataset(torch.utils.data.Dataset):
         self.pad_to = opt_get(opt, ['pad_to_seconds'], None)
         if self.pad_to is not None:
             self.pad_to *= self.sampling_rate
-        self.min_sz = opt_get(opt, ['minimum_samples'], 0)
+        self.pad_to = opt_get(opt, ['pad_to_samples'], self.pad_to)
 
         self.augment = opt_get(opt, ['do_augmentation'], False)
         if self.augment:
@@ -88,8 +88,9 @@ class WavfileDataset(torch.utils.data.Dataset):
             if audio_norm.shape[-1] <= self.pad_to:
                 audio_norm = torch.nn.functional.pad(audio_norm, (0, self.pad_to - audio_norm.shape[-1]))
             else:
-                #print(f"Warning! Truncating clip (unknown) from {audio_norm.shape[-1]} to {self.pad_to}")
-                audio_norm = audio_norm[:, :self.pad_to]
+                gap = audio_norm.shape[-1] - self.pad_to
+                start = random.randint(0, gap-1)
+                audio_norm = audio_norm[:, start:start+self.pad_to]
 
         # Bail and try the next clip if there is not enough data.
         if audio_norm.shape[-1] < self.min_sz: