diff --git a/codes/data/audio/unsupervised_audio_dataset.py b/codes/data/audio/unsupervised_audio_dataset.py
index 8a1b5f36..4f9b75ac 100644
--- a/codes/data/audio/unsupervised_audio_dataset.py
+++ b/codes/data/audio/unsupervised_audio_dataset.py
@@ -56,7 +56,14 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
     def __init__(self, opt):
         path = opt['path']
         cache_path = opt['cache_path']  # Will fail when multiple paths specified, must be specified in this case.
-        self.audiopaths = load_paths_from_cache(path, cache_path)
+        # Optional 'exclusions' option: a list of text files with one path per
+        # line; every listed path is passed on to be left out of the dataset.
+        exclusions = []
+        if 'exclusions' in opt.keys():
+            for exc in opt['exclusions']:
+                with open(exc, 'r') as f:
+                    exclusions.extend(f.read().splitlines())
+        self.audiopaths = load_paths_from_cache(path, cache_path, exclusions)
 
         # Parse options
         self.sampling_rate = opt_get(opt, ['sampling_rate'], 22050)
diff --git a/codes/data/util.py b/codes/data/util.py
index 6c370b96..743e9b89 100644
--- a/codes/data/util.py
+++ b/codes/data/util.py
@@ -578,7 +578,18 @@ def imresize_np(img, scale, antialiasing=True):
     return out_2.numpy()
 
 
-def load_paths_from_cache(paths, cache_path):
+def load_paths_from_cache(paths, cache_path, exclusion_list=None):
+    """Collect audio file paths under `paths`, dropping excluded entries.
+
+    paths: a single search path or a list of them.
+    cache_path: location of the cached path list; a freshly built list is
+        written here with torch.save.
+    exclusion_list: optional collection of path strings to remove from a
+        freshly built list. Entries must match the discovered paths exactly.
+        None (the default) or an empty collection disables filtering.
+
+    Returns the list of paths.
+    """
     if not isinstance(paths, list):
         paths = [paths]
     if os.path.exists(cache_path):
@@ -588,6 +599,13 @@ def load_paths_from_cache(paths, cache_path):
         output = []
         for p in paths:
             output.extend(find_files_of_type('img', p, qualifier=is_audio_file)[0])
+        if exclusion_list:
+            print("Removing exclusion lists..")
+            # A set gives O(1) membership tests; the result must be a real list
+            # (not a lazy filter object) so it can be saved, measured and indexed.
+            excluded = set(exclusion_list)
+            output = [p for p in output if p not in excluded]
+            print("Done.")
         torch.save(output, cache_path)
     return output
 