forked from mrq/DL-Art-School
More fixes
This commit is contained in:
parent
9a9c90660f
commit
8d9857f33d
|
@ -84,12 +84,12 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
|
||||||
assert audiopath in related_files
|
assert audiopath in related_files
|
||||||
assert len(related_files) < 50000 # Sanity check to ensure we aren't loading "related files" that aren't actually related.
|
assert len(related_files) < 50000 # Sanity check to ensure we aren't loading "related files" that aren't actually related.
|
||||||
if len(related_files) == 0:
|
if len(related_files) == 0:
|
||||||
j = 0
|
|
||||||
print(f"No related files for {audiopath}")
|
print(f"No related files for {audiopath}")
|
||||||
related_files.remove(audiopath)
|
related_files.remove(audiopath)
|
||||||
related_clips = []
|
related_clips = []
|
||||||
random.shuffle(related_clips)
|
random.shuffle(related_clips)
|
||||||
for j, related_file in enumerate(related_files):
|
i = 0
|
||||||
|
for related_file in related_files:
|
||||||
rel_clip = load_audio(related_file, self.sampling_rate)
|
rel_clip = load_audio(related_file, self.sampling_rate)
|
||||||
gap = rel_clip.shape[-1] - self.extra_sample_len
|
gap = rel_clip.shape[-1] - self.extra_sample_len
|
||||||
if gap < 0:
|
if gap < 0:
|
||||||
|
@ -98,12 +98,13 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
|
||||||
rand_start = random.randint(0, gap)
|
rand_start = random.randint(0, gap)
|
||||||
rel_clip = rel_clip[:, rand_start:rand_start+self.extra_sample_len]
|
rel_clip = rel_clip[:, rand_start:rand_start+self.extra_sample_len]
|
||||||
related_clips.append(rel_clip)
|
related_clips.append(rel_clip)
|
||||||
if j >= self.extra_samples:
|
i += 1
|
||||||
|
if i >= self.extra_samples:
|
||||||
break
|
break
|
||||||
actual_extra_samples = j
|
actual_extra_samples = i
|
||||||
while j < self.extra_samples:
|
while i < self.extra_samples:
|
||||||
related_clips.append(torch.zeros(1, self.extra_sample_len))
|
related_clips.append(torch.zeros(1, self.extra_sample_len))
|
||||||
j += 1
|
i += 1
|
||||||
return torch.stack(related_clips, dim=0), actual_extra_samples
|
return torch.stack(related_clips, dim=0), actual_extra_samples
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user