diff --git a/codes/data/audio/paired_voice_audio_dataset.py b/codes/data/audio/paired_voice_audio_dataset.py
index e2ae7f3c..5990cc55 100644
--- a/codes/data/audio/paired_voice_audio_dataset.py
+++ b/codes/data/audio/paired_voice_audio_dataset.py
@@ -96,6 +96,7 @@ class TextWavLoader(torch.utils.data.Dataset):
             self.tokenizer = VoiceBpeTokenizer(opt_get(hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json'))
         else:
             self.tokenizer = CharacterTokenizer()
+        self.skipped_items = 0 # records how many items are skipped when accessing an index.
 
     def get_wav_text_pair(self, audiopath_and_text):
         # separate filename and text
@@ -115,14 +116,19 @@ class TextWavLoader(torch.utils.data.Dataset):
         return tokens
 
     def __getitem__(self, index):
+        self.skipped_items += 1
         try:
             tseq, wav, text, path = self.get_wav_text_pair(self.audiopaths_and_text[index])
-            cond = load_similar_clips(self.audiopaths_and_text[index][0], self.conditioning_length, self.sample_rate,
+            cond, cond_is_self = load_similar_clips(self.audiopaths_and_text[index][0], self.conditioning_length, self.sample_rate,
                                       n=self.conditioning_candidates) if self.load_conditioning else None
         except:
+            if self.skipped_items > 100:
+                raise # Rethrow if we have nested too far.
             if self.debug_failures:
                 print(f"error loading {self.audiopaths_and_text[index][0]} {sys.exc_info()}")
             return self[(index+1) % len(self)]
+        actually_skipped_items = self.skipped_items
+        self.skipped_items = 0
         if wav is None or \
             (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \
             (self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
@@ -144,10 +150,12 @@ class TextWavLoader(torch.utils.data.Dataset):
             'text_lengths': torch.tensor(orig_text_len, dtype=torch.long),
             'wav': wav,
             'wav_lengths': torch.tensor(orig_output, dtype=torch.long),
-            'filenames': path
+            'filenames': path,
+            'skipped_items': actually_skipped_items,
         }
         if self.load_conditioning:
             res['conditioning'] = cond
+            res['conditioning_contains_self'] = cond_is_self
         return res
 
     def __len__(self):
diff --git a/codes/data/audio/unsupervised_audio_dataset.py b/codes/data/audio/unsupervised_audio_dataset.py
index dd0ab5c1..ab65da62 100644
--- a/codes/data/audio/unsupervised_audio_dataset.py
+++ b/codes/data/audio/unsupervised_audio_dataset.py
@@ -49,7 +49,7 @@ def load_audio(audiopath, sampling_rate):
     return audio.unsqueeze(0)
 
 
-def load_similar_clips(path, sample_length, sample_rate, n=3, include_self=True, fallback_to_self=True):
+def load_similar_clips(path, sample_length, sample_rate, n=3, fallback_to_self=True):
     sim_path = os.path.join(os.path.dirname(path), 'similarities.pth')
     candidates = []
     if os.path.exists(sim_path):
@@ -59,6 +59,7 @@ def load_similar_clips(path, sample_length, sample_rate, n=3, include_self=True,
             candidates = [os.path.join(os.path.dirname(path), s) for s in similarities[fname]]
         else:
             print(f'Similarities list found for {path} but {fname} was not in that list.')
+        #candidates.append(path) # Always include self as a possible similar clip.
     if len(candidates) == 0:
         if fallback_to_self:
             candidates = [path]
@@ -66,16 +67,17 @@ def load_similar_clips(path, sample_length, sample_rate, n=3, include_self=True,
             candidates = find_files_of_type('img', os.path.dirname(path), qualifier=is_audio_file)[0]
     assert len(candidates) < 50000 # Sanity check to ensure we aren't loading "related files" that aren't actually related.
-    if not include_self:
-        candidates.remove(path)
     if len(candidates) == 0:
         print(f"No conditioning candidates found for {path}")
         raise NotImplementedError()
     # Sample with replacement. This can get repeats, but more conveniently handles situations where there are not enough candidates.
     related_clips = []
+    contains_self = False
     for k in range(n):
-        rel_clip = load_audio(random.choice(candidates), sample_rate)
+        rel_path = random.choice(candidates)
+        contains_self = contains_self or (rel_path == path)
+        rel_clip = load_audio(rel_path, sample_rate)
         gap = rel_clip.shape[-1] - sample_length
         if gap < 0:
             rel_clip = F.pad(rel_clip, pad=(0, abs(gap)))
@@ -84,9 +86,9 @@ def load_similar_clips(path, sample_length, sample_rate, n=3, include_self=True,
             rel_clip = rel_clip[:, rand_start:rand_start+sample_length]
         related_clips.append(rel_clip)
     if n > 1:
-        return torch.stack(related_clips, dim=0)
+        return torch.stack(related_clips, dim=0), contains_self
     else:
-        return related_clips[0]
+        return related_clips[0], contains_self
 
 
 class UnsupervisedAudioDataset(torch.utils.data.Dataset):
@@ -135,7 +137,7 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
         try:
             # Split audio_norm into two tensors of equal size.
             audio_norm, filename = self.get_audio_for_index(index)
-            alt_files = self.get_related_audio_for_index(index)
+            alt_files, alt_is_self = self.get_related_audio_for_index(index)
         except:
             if self.debug_loading_failures:
                 print(f"Error loading audio for file {self.audiopaths[index]} {sys.exc_info()}")
@@ -167,6 +169,7 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
             output['resampled_clip'] = clips[1]
         if self.extra_samples > 0:
             output['alt_clips'] = alt_files
+            output['alt_contains_self'] = alt_is_self
         return output
 
     def __len__(self):
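
For reference only, not part of the patch: a minimal sketch of how a consumer might read the new per-item fields, assuming `dataset` is an already-configured TextWavLoader instance (the variable name and prints below are illustrative).

    # Illustrative sketch -- `dataset` is assumed to be a configured TextWavLoader.
    item = dataset[0]
    # 'skipped_items' counts the __getitem__ attempts (the successful one plus any
    # failed retries) it took to produce this item.
    print('attempts for this item:', item['skipped_items'])
    # When conditioning is loaded, 'conditioning_contains_self' flags whether the
    # target clip itself was among the sampled conditioning clips.
    if 'conditioning_contains_self' in item:
        print('conditioning includes the target clip:', item['conditioning_contains_self'])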