forked from mrq/DL-Art-School
support not_ew too
This commit is contained in:
parent
5188866bd5
commit
00e133afa9
|
@ -98,8 +98,11 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
|
||||||
for exc in opt['exclusions']:
|
for exc in opt['exclusions']:
|
||||||
with open(exc, 'r') as f:
|
with open(exc, 'r') as f:
|
||||||
exclusions.extend(f.read().splitlines())
|
exclusions.extend(f.read().splitlines())
|
||||||
ew = opt_get(opt, ['endswith'])
|
ew = opt_get(opt, ['endswith'], [])
|
||||||
self.audiopaths = load_paths_from_cache(path, cache_path, exclusions, ew)
|
assert isinstance(ew, list)
|
||||||
|
not_ew = opt_get(opt, ['not_endswith'], [])
|
||||||
|
assert isinstance(not_ew, list)
|
||||||
|
self.audiopaths = load_paths_from_cache(path, cache_path, exclusions, endswith=ew, not_endswith=not_ew)
|
||||||
|
|
||||||
# Parse options
|
# Parse options
|
||||||
self.sampling_rate = opt_get(opt, ['sampling_rate'], 22050)
|
self.sampling_rate = opt_get(opt, ['sampling_rate'], 22050)
|
||||||
|
|
|
@ -578,7 +578,7 @@ def imresize_np(img, scale, antialiasing=True):
|
||||||
return out_2.numpy()
|
return out_2.numpy()
|
||||||
|
|
||||||
|
|
||||||
def load_paths_from_cache(paths, cache_path, exclusion_list=[], endswith=None):
|
def load_paths_from_cache(paths, cache_path, exclusion_list=[], endswith=[], not_endswith=[]):
|
||||||
if not isinstance(paths, list):
|
if not isinstance(paths, list):
|
||||||
paths = [paths]
|
paths = [paths]
|
||||||
if os.path.exists(cache_path):
|
if os.path.exists(cache_path):
|
||||||
|
@ -597,8 +597,16 @@ def load_paths_from_cache(paths, cache_path, exclusion_list=[], endswith=None):
|
||||||
print(f"Excluded {before-len(output)} files.")
|
print(f"Excluded {before-len(output)} files.")
|
||||||
if endswith is not None:
|
if endswith is not None:
|
||||||
before = len(output)
|
before = len(output)
|
||||||
output = list(filter(lambda p: p.endswith(endswith), output))
|
def filter_fn(p):
|
||||||
print(f"Excluded {before-len(output)} files with endswith mask. For total of {len(output)} files")
|
for e in endswith:
|
||||||
|
if not p.endswith(endswith):
|
||||||
|
return False
|
||||||
|
for e in not_endswith:
|
||||||
|
if p.endswith(e):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
output = list(filter(filter_fn, output))
|
||||||
|
print(f"!!Excluded {before-len(output)} files with endswith mask. For total of {len(output)} files")
|
||||||
print("Done.")
|
print("Done.")
|
||||||
torch.save(output, cache_path)
|
torch.save(output, cache_path)
|
||||||
return output
|
return output
|
||||||
|
|
Loading…
Reference in New Issue
Block a user