Add exclusion_lists to unsupervised_audio_dataset
This commit is contained in:
parent
9b693b0a54
commit
18b1de9b2c
|
@ -56,7 +56,12 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
|
|||
def __init__(self, opt):
|
||||
path = opt['path']
|
||||
cache_path = opt['cache_path'] # Will fail when multiple paths specified, must be specified in this case.
|
||||
self.audiopaths = load_paths_from_cache(path, cache_path)
|
||||
exclusions = []
|
||||
if 'exclusions' in opt.keys():
|
||||
for exc in opt['exclusions']:
|
||||
with open(exc, 'r') as f:
|
||||
exclusions.extend(f.read().splitlines())
|
||||
self.audiopaths = load_paths_from_cache(path, cache_path, exclusions)
|
||||
|
||||
# Parse options
|
||||
self.sampling_rate = opt_get(opt, ['sampling_rate'], 22050)
|
||||
|
|
|
@ -578,7 +578,7 @@ def imresize_np(img, scale, antialiasing=True):
|
|||
return out_2.numpy()
|
||||
|
||||
|
||||
def load_paths_from_cache(paths, cache_path):
|
||||
def load_paths_from_cache(paths, cache_path, exclusion_list=[]):
|
||||
if not isinstance(paths, list):
|
||||
paths = [paths]
|
||||
if os.path.exists(cache_path):
|
||||
|
@ -588,6 +588,10 @@ def load_paths_from_cache(paths, cache_path):
|
|||
output = []
|
||||
for p in paths:
|
||||
output.extend(find_files_of_type('img', p, qualifier=is_audio_file)[0])
|
||||
if exclusion_list is not None and len(exclusion_list) > 0:
|
||||
print(f"Removing exclusion lists..")
|
||||
output = filter(lambda p: p not in exclusion_list, output)
|
||||
print("Done.")
|
||||
torch.save(output, cache_path)
|
||||
return output
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user