Revise audio datasets to include interesting statistics in batch

Stats include:
- How many indices were skipped to retrieve a given index
- Whether or not a conditioning input was actually the file itself
This commit is contained in:
James Betker 2022-01-06 11:15:16 -07:00
parent 06c1093090
commit f3cab45658
2 changed files with 20 additions and 9 deletions

View File

@ -96,6 +96,7 @@ class TextWavLoader(torch.utils.data.Dataset):
self.tokenizer = VoiceBpeTokenizer(opt_get(hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json'))
else:
self.tokenizer = CharacterTokenizer()
self.skipped_items = 0 # records how many items are skipped when accessing an index.
def get_wav_text_pair(self, audiopath_and_text):
# separate filename and text
@ -115,14 +116,19 @@ class TextWavLoader(torch.utils.data.Dataset):
return tokens
def __getitem__(self, index):
self.skipped_items += 1
try:
tseq, wav, text, path = self.get_wav_text_pair(self.audiopaths_and_text[index])
cond = load_similar_clips(self.audiopaths_and_text[index][0], self.conditioning_length, self.sample_rate,
cond, cond_is_self = load_similar_clips(self.audiopaths_and_text[index][0], self.conditioning_length, self.sample_rate,
n=self.conditioning_candidates) if self.load_conditioning else None
except:
if self.skipped_items > 100:
raise # Rethrow if we have nested too far.
if self.debug_failures:
print(f"error loading {self.audiopaths_and_text[index][0]} {sys.exc_info()}")
return self[(index+1) % len(self)]
actually_skipped_items = self.skipped_items
self.skipped_items = 0
if wav is None or \
(self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \
(self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
@ -144,10 +150,12 @@ class TextWavLoader(torch.utils.data.Dataset):
'text_lengths': torch.tensor(orig_text_len, dtype=torch.long),
'wav': wav,
'wav_lengths': torch.tensor(orig_output, dtype=torch.long),
'filenames': path
'filenames': path,
'skipped_items': actually_skipped_items,
}
if self.load_conditioning:
res['conditioning'] = cond
res['conditioning_contains_self'] = cond_is_self
return res
def __len__(self):

View File

@ -49,7 +49,7 @@ def load_audio(audiopath, sampling_rate):
return audio.unsqueeze(0)
def load_similar_clips(path, sample_length, sample_rate, n=3, include_self=True, fallback_to_self=True):
def load_similar_clips(path, sample_length, sample_rate, n=3, fallback_to_self=True):
sim_path = os.path.join(os.path.dirname(path), 'similarities.pth')
candidates = []
if os.path.exists(sim_path):
@ -59,6 +59,7 @@ def load_similar_clips(path, sample_length, sample_rate, n=3, include_self=True,
candidates = [os.path.join(os.path.dirname(path), s) for s in similarities[fname]]
else:
print(f'Similarities list found for {path} but {fname} was not in that list.')
#candidates.append(path) # Always include self as a possible similar clip.
if len(candidates) == 0:
if fallback_to_self:
candidates = [path]
@ -66,16 +67,17 @@ def load_similar_clips(path, sample_length, sample_rate, n=3, include_self=True,
candidates = find_files_of_type('img', os.path.dirname(path), qualifier=is_audio_file)[0]
assert len(candidates) < 50000 # Sanity check to ensure we aren't loading "related files" that aren't actually related.
if not include_self:
candidates.remove(path)
if len(candidates) == 0:
print(f"No conditioning candidates found for {path}")
raise NotImplementedError()
# Sample with replacement. This can get repeats, but more conveniently handles situations where there are not enough candidates.
related_clips = []
contains_self = False
for k in range(n):
rel_clip = load_audio(random.choice(candidates), sample_rate)
rel_path = random.choice(candidates)
contains_self = contains_self or (rel_path == path)
rel_clip = load_audio(rel_path, sample_rate)
gap = rel_clip.shape[-1] - sample_length
if gap < 0:
rel_clip = F.pad(rel_clip, pad=(0, abs(gap)))
@ -84,9 +86,9 @@ def load_similar_clips(path, sample_length, sample_rate, n=3, include_self=True,
rel_clip = rel_clip[:, rand_start:rand_start+sample_length]
related_clips.append(rel_clip)
if n > 1:
return torch.stack(related_clips, dim=0)
return torch.stack(related_clips, dim=0), contains_self
else:
return related_clips[0]
return related_clips[0], contains_self
class UnsupervisedAudioDataset(torch.utils.data.Dataset):
@ -135,7 +137,7 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
try:
# Split audio_norm into two tensors of equal size.
audio_norm, filename = self.get_audio_for_index(index)
alt_files = self.get_related_audio_for_index(index)
alt_files, alt_is_self = self.get_related_audio_for_index(index)
except:
if self.debug_loading_failures:
print(f"Error loading audio for file {self.audiopaths[index]} {sys.exc_info()}")
@ -167,6 +169,7 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
output['resampled_clip'] = clips[1]
if self.extra_samples > 0:
output['alt_clips'] = alt_files
output['alt_contains_self'] = alt_is_self
return output
def __len__(self):