diff --git a/codes/data/audio/unsupervised_audio_dataset.py b/codes/data/audio/unsupervised_audio_dataset.py
index 6fd9fee9..3277e1ef 100644
--- a/codes/data/audio/unsupervised_audio_dataset.py
+++ b/codes/data/audio/unsupervised_audio_dataset.py
@@ -98,8 +98,11 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
             for exc in opt['exclusions']:
                 with open(exc, 'r') as f:
                     exclusions.extend(f.read().splitlines())
-        ew = opt_get(opt, ['endswith'])
-        self.audiopaths = load_paths_from_cache(path, cache_path, exclusions, ew)
+        ew = opt_get(opt, ['endswith'], [])
+        assert isinstance(ew, list)
+        not_ew = opt_get(opt, ['not_endswith'], [])
+        assert isinstance(not_ew, list)
+        self.audiopaths = load_paths_from_cache(path, cache_path, exclusions, endswith=ew, not_endswith=not_ew)
 
         # Parse options
         self.sampling_rate = opt_get(opt, ['sampling_rate'], 22050)
diff --git a/codes/data/util.py b/codes/data/util.py
index b3ce32b5..5ea94575 100644
--- a/codes/data/util.py
+++ b/codes/data/util.py
@@ -578,7 +578,7 @@ def imresize_np(img, scale, antialiasing=True):
     return out_2.numpy()
 
 
-def load_paths_from_cache(paths, cache_path, exclusion_list=[], endswith=None):
+def load_paths_from_cache(paths, cache_path, exclusion_list=[], endswith=[], not_endswith=[]):
     if not isinstance(paths, list):
         paths = [paths]
     if os.path.exists(cache_path):
@@ -597,8 +597,17 @@ def load_paths_from_cache(paths, cache_path, exclusion_list=[], endswith=None):
             print(f"Excluded {before-len(output)} files.")
         if endswith is not None:
             before = len(output)
-            output = list(filter(lambda p: p.endswith(endswith), output))
-            print(f"Excluded {before-len(output)} files with endswith mask. For total of {len(output)} files")
+            def filter_fn(p):
+                # Keep p only if it ends with every `endswith` suffix and with none of the `not_endswith` suffixes.
+                for e in endswith:
+                    if not p.endswith(e):
+                        return False
+                for e in not_endswith:
+                    if p.endswith(e):
+                        return False
+                return True
+            output = list(filter(filter_fn, output))
+            print(f"!!Excluded {before-len(output)} files with endswith mask. For total of {len(output)} files")
         print("Done.")
         torch.save(output, cache_path)
     return output