DL-Art-School/codes/scripts/audio/random_mp3_splitter.py

76 lines
2.6 KiB
Python
Raw Normal View History

from scipy.io import wavfile
from spleeter.separator import Separator
from tqdm import tqdm
from data.util import find_audio_files
import os.path as osp
from spleeter.audio.adapter import AudioAdapter
import numpy as np
if __name__ == '__main__':
src_dir = 'O:\\podcast_dumps'
#src_dir = 'E:\\audio\\books'
output_dir = 'D:\\data\\audio\\podcasts-split'
output_dir_lq = 'D:\\data\\audio\\podcasts-split-with-bg'
output_dir_garbage = 'D:\\data\\audio\\podcasts-split-garbage'
#output_dir = 'E:\\audio\\books-clips'
clip_length = 5 # In seconds
sparsity = .1 # Only this proportion of the total clips are extracted as wavs.
output_sample_rate=22050
audio_loader = AudioAdapter.default()
separator = Separator('spleeter:2stems')
files = find_audio_files(src_dir, include_nonwav=True)
for e, file in enumerate(tqdm(files)):
if e < 575:
continue
file_basis = osp.relpath(file, src_dir)\
.replace('/', '_')\
.replace('\\', '_')\
.replace('.', '_')\
.replace(' ', '_')\
.replace('!', '_')\
.replace(',', '_')
if len(file_basis) > 100:
file_basis = file_basis[:100]
try:
wave, sample_rate = audio_loader.load(file, sample_rate=output_sample_rate)
except:
print(f"Error with {file}")
continue
#if len(wave.shape) < 2:
# continue
# Calculate how much data we need to extract for each clip.
clip_sz = sample_rate * clip_length
interval = int(sample_rate * (clip_length / sparsity))
i = 0
while (i+clip_sz) < wave.shape[0]:
clip = wave[i:i+clip_sz]
sep = separator.separate(clip)
vocals = sep['vocals']
bg = sep['accompaniment']
vmax = np.abs(vocals).mean()
bmax = np.abs(bg).mean()
# Only output to the "good" sample dir if the ratio of background noise to vocal noise is high enough.
ratio = vmax / bmax
if ratio >= 25: # These values were derived empirically
od = output_dir
os = clip
elif ratio >= 1:
od = output_dir_lq
os = vocals
else:
od = output_dir_garbage
os = vocals
# Strip out channels.
if len(os.shape) > 1:
os = os[:, 0] # Just use the first channel.
wavfile.write(osp.join(od, f'{e}_{file_basis}_{i}.wav'), output_sample_rate, os)
i = i + interval