support not_ew too

This commit is contained in:
James Betker 2022-05-25 08:58:23 -06:00
parent 5188866bd5
commit 00e133afa9
2 changed files with 16 additions and 5 deletions

View File

@ -98,8 +98,11 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
for exc in opt['exclusions']:
with open(exc, 'r') as f:
exclusions.extend(f.read().splitlines())
ew = opt_get(opt, ['endswith'])
self.audiopaths = load_paths_from_cache(path, cache_path, exclusions, ew)
ew = opt_get(opt, ['endswith'], [])
assert isinstance(ew, list)
not_ew = opt_get(opt, ['not_endswith'], [])
assert isinstance(not_ew, list)
self.audiopaths = load_paths_from_cache(path, cache_path, exclusions, endswith=ew, not_endswith=not_ew)
# Parse options
self.sampling_rate = opt_get(opt, ['sampling_rate'], 22050)

View File

@ -578,7 +578,7 @@ def imresize_np(img, scale, antialiasing=True):
return out_2.numpy()
def load_paths_from_cache(paths, cache_path, exclusion_list=[], endswith=None):
def load_paths_from_cache(paths, cache_path, exclusion_list=[], endswith=[], not_endswith=[]):
if not isinstance(paths, list):
paths = [paths]
if os.path.exists(cache_path):
@ -597,8 +597,16 @@ def load_paths_from_cache(paths, cache_path, exclusion_list=[], endswith=None):
print(f"Excluded {before-len(output)} files.")
if endswith is not None:
before = len(output)
output = list(filter(lambda p: p.endswith(endswith), output))
print(f"Excluded {before-len(output)} files with endswith mask. For total of {len(output)} files")
def filter_fn(p):
for e in endswith:
if not p.endswith(endswith):
return False
for e in not_endswith:
if p.endswith(e):
return False
return True
output = list(filter(filter_fn, output))
print(f"!!Excluded {before-len(output)} files with endswith mask. For total of {len(output)} files")
print("Done.")
torch.save(output, cache_path)
return output