Add exclusion_lists to unsupervised_audio_dataset

This commit is contained in:
James Betker 2021-11-07 18:46:47 -07:00
parent 9b693b0a54
commit 18b1de9b2c
2 changed files with 11 additions and 2 deletions

View File

@ -56,7 +56,12 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
def __init__(self, opt):
path = opt['path']
cache_path = opt['cache_path'] # Will fail when multiple paths specified, must be specified in this case.
self.audiopaths = load_paths_from_cache(path, cache_path)
# Collect excluded entries: every file named in opt['exclusions'] is read
# whole and contributes one entry per line.
excluded_entries = []
exclusion_files = opt['exclusions'] if 'exclusions' in opt.keys() else []
for list_file in exclusion_files:
    with open(list_file, 'r') as handle:
        excluded_entries += handle.read().splitlines()
# Hand the combined entries to the cache loader.
self.audiopaths = load_paths_from_cache(path, cache_path, excluded_entries)
# Parse options
self.sampling_rate = opt_get(opt, ['sampling_rate'], 22050)

View File

@ -578,7 +578,7 @@ def imresize_np(img, scale, antialiasing=True):
return out_2.numpy()
def load_paths_from_cache(paths, cache_path):
def load_paths_from_cache(paths, cache_path, exclusion_list=[]):
if not isinstance(paths, list):
paths = [paths]
if os.path.exists(cache_path):
@ -588,6 +588,10 @@ def load_paths_from_cache(paths, cache_path):
output = []
for p in paths:
output.extend(find_files_of_type('img', p, qualifier=is_audio_file)[0])
if exclusion_list is not None and len(exclusion_list) > 0:
    print("Removing exclusion lists..")
    # Build the lookup set once so each membership test is O(1) instead of
    # a linear scan over the whole exclusion list.
    excluded = set(exclusion_list)
    # Materialize a concrete list: filter() yields a lazy iterator in
    # Python 3, which is not a list and (wrapping a lambda) cannot be
    # pickled when the result is saved below.
    output = [p for p in output if p not in excluded]
    print("Done.")
torch.save(output, cache_path)
return output