Work on spleeter filtering script

This commit is contained in:
James Betker 2021-09-29 09:24:56 -06:00
parent 55b58fb67f
commit fc8ae4679a
2 changed files with 114 additions and 0 deletions

View File

@ -0,0 +1,112 @@
from scipy.io import wavfile
import os
import argparse
import numpy as np
from scipy.io import wavfile
from spleeter.separator import Separator
from tqdm import tqdm
from spleeter.audio.adapter import AudioAdapter
from tqdm import tqdm
def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
def is_wav_file(filename):
return filename.endswith('.wav')
def is_audio_file(filename):
AUDIO_EXTENSIONS = ['.wav', '.mp3', '.wma', 'm4b']
return any(filename.endswith(extension) for extension in AUDIO_EXTENSIONS)
def _get_paths_from_images(path, qualifier=is_image_file):
"""get image path list from image folder"""
assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
images = []
for dirpath, _, fnames in sorted(os.walk(path)):
for fname in sorted(fnames):
if qualifier(fname) and 'ref.jpg' not in fname:
img_path = os.path.join(dirpath, fname)
images.append(img_path)
if not images:
print("Warning: {:s} has no valid image file".format(path))
return images
def _get_paths_from_lmdb(dataroot):
"""get image path list from lmdb meta info"""
meta_info = pickle.load(open(os.path.join(dataroot, 'meta_info.pkl'), 'rb'))
paths = meta_info['keys']
sizes = meta_info['resolution']
if len(sizes) == 1:
sizes = sizes * len(paths)
return paths, sizes
def find_audio_files(dataroot, include_nonwav=False):
if include_nonwav:
return find_files_of_type(None, dataroot, qualifier=is_audio_file)[0]
else:
return find_files_of_type(None, dataroot, qualifier=is_wav_file)[0]
def find_files_of_type(data_type, dataroot, weights=[], qualifier=is_image_file):
if isinstance(dataroot, list):
paths = []
for i in range(len(dataroot)):
r = dataroot[i]
extends = 1
# Weights have the effect of repeatedly adding the paths from the given root to the final product.
if weights:
extends = weights[i]
for j in range(extends):
paths.extend(_get_paths_from_images(r, qualifier))
paths = sorted(paths)
sizes = len(paths)
else:
paths = sorted(_get_paths_from_images(dataroot, qualifier))
sizes = len(paths)
return paths, sizes
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--path')
parser.add_argument('--out')
args = parser.parse_args()
src_dir = args.path
out_file = args.out
output_sample_rate=22050
batch_size=16
audio_loader = AudioAdapter.default()
files = find_audio_files(src_dir, include_nonwav=True)
#separator = Separator('pretrained_models/2stems', input_sr=output_sample_rate)
separator = Separator('spleeter:2stems')
unacceptable_files = open(out_file, 'w')
for path in tqdm(files):
print(f"Processing {src_dir}")
spleeter_ld, sr = audio_loader.load(path, sample_rate=output_sample_rate)
sep = separator.separate(spleeter_ld)
vocals = sep['vocals']
bg = sep['accompaniment']
vmax = np.abs(vocals).mean()
bmax = np.abs(bg).mean()
# Only output to the "good" sample dir if the ratio of background noise to vocal noise is high enough.
ratio = vmax / (bmax+.0000001)
if ratio < 25: # These values were derived empirically
unacceptable_files.write(f'{path}\n')
unacceptable_files.flush()
unacceptable_files.close()
if __name__ == '__main__':
main()

View File

@ -10,6 +10,8 @@ from models.spleeter.separator import Separator
from scripts.audio.preparation.spleeter_dataset import SpleeterDataset
# Note: The Pytorch implementation of Spleeter is not working correctly. Fixing this would significantly
# speed up the script since we can separate out dataloading and do batch inference.
def main():
src_dir = 'F:\\split\\joe_rogan'
output_sample_rate=22050