diff --git a/codes/data/audio/unsupervised_audio_dataset.py b/codes/data/audio/unsupervised_audio_dataset.py
index 5b3115e1..d764a83d 100644
--- a/codes/data/audio/unsupervised_audio_dataset.py
+++ b/codes/data/audio/unsupervised_audio_dataset.py
@@ -84,12 +84,12 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
         assert audiopath in related_files
         assert len(related_files) < 50000 # Sanity check to ensure we aren't loading "related files" that aren't actually related.
         if len(related_files) == 0:
-            j = 0
             print(f"No related files for {audiopath}")
         related_files.remove(audiopath)
         related_clips = []
         random.shuffle(related_clips)
-        for j, related_file in enumerate(related_files):
+        i = 0
+        for related_file in related_files:
             rel_clip = load_audio(related_file, self.sampling_rate)
             gap = rel_clip.shape[-1] - self.extra_sample_len
             if gap < 0:
@@ -98,12 +98,13 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
                 rand_start = random.randint(0, gap)
                 rel_clip = rel_clip[:, rand_start:rand_start+self.extra_sample_len]
             related_clips.append(rel_clip)
-            if j >= self.extra_samples:
+            i += 1
+            if i >= self.extra_samples:
                 break
-        actual_extra_samples = j
-        while j < self.extra_samples:
+        actual_extra_samples = i
+        while i < self.extra_samples:
             related_clips.append(torch.zeros(1, self.extra_sample_len))
-            j += 1
+            i += 1
         return torch.stack(related_clips, dim=0), actual_extra_samples
 
     def __getitem__(self, index):
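
For readers who want to check the new control flow outside the dataset class, below is a minimal standalone sketch (not part of the patch) that mirrors the revised counter logic on dummy tensors: the explicit `i` counter replaces the old `enumerate()` index `j`, so `actual_extra_samples` is defined even when no related clips are processed, and the result is always padded with silent clips up to `extra_samples`. The helper name `gather_related_clips`, the dummy clips, and the zero-padding in the `gap < 0` branch (whose body lies outside the hunks above) are assumptions for illustration only; in the real class the clips come from `load_audio` and `extra_samples` / `extra_sample_len` are instance attributes.

import random

import torch
import torch.nn.functional as F


def gather_related_clips(clips, extra_samples, extra_sample_len):
    # Mirrors the patched loop: an explicit counter `i` replaces the old
    # enumerate() index `j`, so the count is defined even when `clips` is empty.
    related_clips = []
    i = 0
    for rel_clip in clips:
        gap = rel_clip.shape[-1] - extra_sample_len
        if gap < 0:
            # Assumed handling for short clips (this branch's body is outside
            # the hunks above): right-pad with zeros to extra_sample_len.
            rel_clip = F.pad(rel_clip, pad=(0, abs(gap)))
        elif gap > 0:
            rand_start = random.randint(0, gap)
            rel_clip = rel_clip[:, rand_start:rand_start + extra_sample_len]
        related_clips.append(rel_clip)
        i += 1
        if i >= extra_samples:
            break
    actual_extra_samples = i
    # Pad the batch with silent clips until exactly `extra_samples` are present.
    while i < extra_samples:
        related_clips.append(torch.zeros(1, extra_sample_len))
        i += 1
    return torch.stack(related_clips, dim=0), actual_extra_samples


# Two real clips, one shorter and one longer than the target length:
clips = [torch.randn(1, 5), torch.randn(1, 12)]
stacked, actual = gather_related_clips(clips, extra_samples=4, extra_sample_len=8)
assert stacked.shape == (4, 1, 8) and actual == 2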