forked from mrq/DL-Art-School
Revise audio datasets to include interesting statistics in batch
Stats include: - How many indices were skipped to retrieve a given index - Whether or not a conditioning input was actually the file itself
This commit is contained in:
parent
06c1093090
commit
f3cab45658
|
@ -96,6 +96,7 @@ class TextWavLoader(torch.utils.data.Dataset):
|
|||
self.tokenizer = VoiceBpeTokenizer(opt_get(hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json'))
|
||||
else:
|
||||
self.tokenizer = CharacterTokenizer()
|
||||
self.skipped_items = 0 # records how many items are skipped when accessing an index.
|
||||
|
||||
def get_wav_text_pair(self, audiopath_and_text):
|
||||
# separate filename and text
|
||||
|
@ -115,14 +116,19 @@ class TextWavLoader(torch.utils.data.Dataset):
|
|||
return tokens
|
||||
|
||||
def __getitem__(self, index):
|
||||
self.skipped_items += 1
|
||||
try:
|
||||
tseq, wav, text, path = self.get_wav_text_pair(self.audiopaths_and_text[index])
|
||||
cond = load_similar_clips(self.audiopaths_and_text[index][0], self.conditioning_length, self.sample_rate,
|
||||
cond, cond_is_self = load_similar_clips(self.audiopaths_and_text[index][0], self.conditioning_length, self.sample_rate,
|
||||
n=self.conditioning_candidates) if self.load_conditioning else None
|
||||
except:
|
||||
if self.skipped_items > 100:
|
||||
raise # Rethrow if we have nested too far.
|
||||
if self.debug_failures:
|
||||
print(f"error loading {self.audiopaths_and_text[index][0]} {sys.exc_info()}")
|
||||
return self[(index+1) % len(self)]
|
||||
actually_skipped_items = self.skipped_items
|
||||
self.skipped_items = 0
|
||||
if wav is None or \
|
||||
(self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \
|
||||
(self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
|
||||
|
@ -144,10 +150,12 @@ class TextWavLoader(torch.utils.data.Dataset):
|
|||
'text_lengths': torch.tensor(orig_text_len, dtype=torch.long),
|
||||
'wav': wav,
|
||||
'wav_lengths': torch.tensor(orig_output, dtype=torch.long),
|
||||
'filenames': path
|
||||
'filenames': path,
|
||||
'skipped_items': actually_skipped_items,
|
||||
}
|
||||
if self.load_conditioning:
|
||||
res['conditioning'] = cond
|
||||
res['conditioning_contains_self'] = cond_is_self
|
||||
return res
|
||||
|
||||
def __len__(self):
|
||||
|
|
|
@ -49,7 +49,7 @@ def load_audio(audiopath, sampling_rate):
|
|||
return audio.unsqueeze(0)
|
||||
|
||||
|
||||
def load_similar_clips(path, sample_length, sample_rate, n=3, include_self=True, fallback_to_self=True):
|
||||
def load_similar_clips(path, sample_length, sample_rate, n=3, fallback_to_self=True):
|
||||
sim_path = os.path.join(os.path.dirname(path), 'similarities.pth')
|
||||
candidates = []
|
||||
if os.path.exists(sim_path):
|
||||
|
@ -59,6 +59,7 @@ def load_similar_clips(path, sample_length, sample_rate, n=3, include_self=True,
|
|||
candidates = [os.path.join(os.path.dirname(path), s) for s in similarities[fname]]
|
||||
else:
|
||||
print(f'Similarities list found for {path} but {fname} was not in that list.')
|
||||
#candidates.append(path) # Always include self as a possible similar clip.
|
||||
if len(candidates) == 0:
|
||||
if fallback_to_self:
|
||||
candidates = [path]
|
||||
|
@ -66,16 +67,17 @@ def load_similar_clips(path, sample_length, sample_rate, n=3, include_self=True,
|
|||
candidates = find_files_of_type('img', os.path.dirname(path), qualifier=is_audio_file)[0]
|
||||
|
||||
assert len(candidates) < 50000 # Sanity check to ensure we aren't loading "related files" that aren't actually related.
|
||||
if not include_self:
|
||||
candidates.remove(path)
|
||||
if len(candidates) == 0:
|
||||
print(f"No conditioning candidates found for {path}")
|
||||
raise NotImplementedError()
|
||||
|
||||
# Sample with replacement. This can get repeats, but more conveniently handles situations where there are not enough candidates.
|
||||
related_clips = []
|
||||
contains_self = False
|
||||
for k in range(n):
|
||||
rel_clip = load_audio(random.choice(candidates), sample_rate)
|
||||
rel_path = random.choice(candidates)
|
||||
contains_self = contains_self or (rel_path == path)
|
||||
rel_clip = load_audio(rel_path, sample_rate)
|
||||
gap = rel_clip.shape[-1] - sample_length
|
||||
if gap < 0:
|
||||
rel_clip = F.pad(rel_clip, pad=(0, abs(gap)))
|
||||
|
@ -84,9 +86,9 @@ def load_similar_clips(path, sample_length, sample_rate, n=3, include_self=True,
|
|||
rel_clip = rel_clip[:, rand_start:rand_start+sample_length]
|
||||
related_clips.append(rel_clip)
|
||||
if n > 1:
|
||||
return torch.stack(related_clips, dim=0)
|
||||
return torch.stack(related_clips, dim=0), contains_self
|
||||
else:
|
||||
return related_clips[0]
|
||||
return related_clips[0], contains_self
|
||||
|
||||
|
||||
class UnsupervisedAudioDataset(torch.utils.data.Dataset):
|
||||
|
@ -135,7 +137,7 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
|
|||
try:
|
||||
# Split audio_norm into two tensors of equal size.
|
||||
audio_norm, filename = self.get_audio_for_index(index)
|
||||
alt_files = self.get_related_audio_for_index(index)
|
||||
alt_files, alt_is_self = self.get_related_audio_for_index(index)
|
||||
except:
|
||||
if self.debug_loading_failures:
|
||||
print(f"Error loading audio for file {self.audiopaths[index]} {sys.exc_info()}")
|
||||
|
@ -167,6 +169,7 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
|
|||
output['resampled_clip'] = clips[1]
|
||||
if self.extra_samples > 0:
|
||||
output['alt_clips'] = alt_files
|
||||
output['alt_contains_self'] = alt_is_self
|
||||
return output
|
||||
|
||||
def __len__(self):
|
||||
|
|
Loading…
Reference in New Issue
Block a user