forked from mrq/DL-Art-School
Restore spleeter_splitter
The mods don't help: in TF mode, everything is done on the GPU anyway. Something else will have to be done to fix this.
This commit is contained in:
parent 32ba496632
commit c861054218
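In other words: the reverted mod tried to keep the CPU busy by precomputing STFTs outside TensorFlow, but with the TF backend Spleeter already runs the whole transform on the GPU, so the extra plumbing gained nothing. For orientation, a minimal sketch of the restored stock call path (placeholder waveform; the 2-stems model and 22050 Hz sample rate follow the diffs below):

```python
import numpy as np
from spleeter.separator import Separator

# Stock Spleeter API as restored below: pass the raw waveform and let the
# library compute the STFT itself (on the GPU when the TF backend is used).
separator = Separator('spleeter:2stems')

# Placeholder input: 10 seconds of silence at 22050 Hz, shape (n_samples, channels).
wav = np.zeros((22050 * 10, 1), dtype=np.float32)

sep = separator.separate(wav)
vocals, bg = sep['vocals'], sep['accompaniment']  # separated stems, same length as input
```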
--- a/scripts/audio/preparation/spleeter_splitter.py
+++ b/scripts/audio/preparation/spleeter_splitter.py
@@ -1,14 +1,11 @@
-import multiprocessing
-
 import argparse
-
+import numpy as np
+from spleeter.separator import Separator
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
-from scripts.audio.preparation.spleeter_utils.filter_noisy_clips_collector import invert_spectrogram_and_save
 from scripts.audio.preparation.spleeter_utils.spleeter_dataset import SpleeterDataset
-from scripts.audio.preparation.spleeter_utils.spleeter_separator_mod import Separator
 
 
 
 def main():
@@ -21,24 +18,34 @@ def main():
     args = parser.parse_args()
 
     src_dir = args.path
+    out_file = args.out
     output_sample_rate=22050
     resume_file = args.resume
 
-    worker_queue = multiprocessing.Queue()
-    worker = multiprocessing.Process(target=invert_spectrogram_and_save, args=(args, worker_queue))
-    worker.start()
-
     loader = DataLoader(SpleeterDataset(src_dir, batch_sz=16, sample_rate=output_sample_rate,
                                         max_duration=10, partition=args.partition, partition_size=args.partition_size,
                                         resume=resume_file), batch_size=1, num_workers=1)
 
-    separator = Separator('spleeter:2stems', multiprocess=False)
+    separator = Separator('spleeter:2stems')
+    unacceptable_files = open(out_file, 'a')
     for batch in tqdm(loader):
-        audio, files, ends, stft = batch['audio'], batch['files'], batch['ends'], batch['stft']
-        sep = separator.separate_spectrogram(stft.squeeze(0).numpy())
-        worker_queue.put((sep['vocals'], sep['accompaniment'], audio.shape[1], files, ends))
-    worker_queue.put(None)
-    worker.join()
+        audio, files, ends = batch['audio'], batch['files'], batch['ends']
+        sep = separator.separate(audio.squeeze(0).numpy())
+        vocals = sep['vocals']
+        bg = sep['accompaniment']
+        start = 0
+        for path, end in zip(files, ends):
+            vmax = np.abs(vocals[start:end]).mean()
+            bmax = np.abs(bg[start:end]).mean()
+            start = end
+
+            # Only output to the "good" sample dir if the ratio of background noise to vocal noise is high enough.
+            ratio = vmax / (bmax+.0000001)
+            if ratio < 18: # These values were derived empirically
+                unacceptable_files.write(f'{path[0]}\n')
+                unacceptable_files.flush()
+
+    unacceptable_files.close()
 
 
 if __name__ == '__main__':
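The filtering logic that moves back inline here is worth spelling out: the dataset concatenates up to 16 clips into one waveform and records each clip's cumulative end offset, so the vocal-to-background energy ratio is computed per clip by walking those offsets. A self-contained sketch of just that bookkeeping (dummy arrays in place of Spleeter output; the threshold of 18 is the empirical value from the diff):

```python
import numpy as np

# Dummy stand-ins for separator output over a batch of three concatenated clips.
vocals = np.random.randn(3000)
bg = np.random.randn(3000) * 0.01
files = ['a.wav', 'b.wav', 'c.wav']
ends = [1000, 2500, 3000]                     # cumulative end offset of each clip

start = 0
unacceptable = []
for path, end in zip(files, ends):
    vmax = np.abs(vocals[start:end]).mean()   # mean vocal magnitude for this clip
    bmax = np.abs(bg[start:end]).mean()       # mean background magnitude
    start = end
    ratio = vmax / (bmax + 1e-7)
    if ratio < 18:                            # empirical threshold from the script
        unacceptable.append(path)             # too much background noise
```

In the script itself each `path` comes out of the DataLoader as a one-element list (batch_size=1 wraps every string), which is why the real code writes `path[0]`.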
--- a/scripts/audio/preparation/spleeter_utils/filter_noisy_clips_collector.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from scripts.audio.preparation.spleeter_utils.spleeter_separator_mod import Separator
-import numpy as np
-
-def invert_spectrogram_and_save(args, queue):
-    separator = Separator('spleeter:2stems', multiprocess=False, load_tf=False)
-    out_file = args.out
-    unacceptable_files = open(out_file, 'a')
-
-    while True:
-        combo = queue.get()
-        if combo is None:
-            break
-        vocals, bg, wavlen, files, ends = combo
-        vocals = separator.stft(vocals, inverse=True, length=wavlen)
-        bg = separator.stft(bg, inverse=True, length=wavlen)
-        start = 0
-        for path, end in zip(files, ends):
-            vmax = np.abs(vocals[start:end]).mean()
-            bmax = np.abs(bg[start:end]).mean()
-            start = end
-
-            # Only output to the "good" sample dir if the ratio of background noise to vocal noise is high enough.
-            ratio = vmax / (bmax+.0000001)
-            if ratio < 18: # These values were derived empirically
-                unacceptable_files.write(f'{path[0]}\n')
-                unacceptable_files.flush()
-
-    unacceptable_files.close()
--- a/scripts/audio/preparation/spleeter_utils/spleeter_dataset.py
+++ b/scripts/audio/preparation/spleeter_utils/spleeter_dataset.py
@@ -6,7 +6,6 @@ from spleeter.audio.adapter import AudioAdapter
 from torch.utils.data import Dataset
 
 from data.util import find_audio_files
-from scripts.audio.preparation.spleeter_utils.spleeter_separator_mod import Separator
 
 
 class SpleeterDataset(Dataset):
@@ -15,7 +14,6 @@ class SpleeterDataset(Dataset):
         self.max_duration = max_duration
         self.files = find_audio_files(src_dir, include_nonwav=True)
         self.sample_rate = sample_rate
-        self.separator = Separator('spleeter:2stems', multiprocess=False, load_tf=False)
 
         # Partition files if needed.
         if partition_size is not None:
@@ -45,25 +43,23 @@
             if ind >= len(self.files):
                 break
 
-            try:
+            #try:
             wav, sr = self.loader.load(self.files[ind], sample_rate=self.sample_rate)
             assert sr == 22050
             # Get rid of all channels except one.
             if wav.shape[1] > 1:
                 wav = wav[:, 0]
 
             if wavs is None:
                 wavs = wav
             else:
                 wavs = np.concatenate([wavs, wav])
             ends.append(wavs.shape[0])
             files.append(self.files[ind])
-            except:
-                print(f'Error loading {self.files[ind]}')
-        stft = self.separator.stft(wavs)
+            #except:
+            #    print(f'Error loading {self.files[ind]}')
         return {
             'audio': wavs,
             'files': files,
-            'ends': ends,
-            'stft': stft
+            'ends': ends
         }
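The net effect on the dataset's contract: it still does its own batching (up to batch_sz=16 clips concatenated into one waveform per item, with batch_size=1 in the outer DataLoader), but the item no longer carries a precomputed spectrogram. A sketch of one item's shape after this commit, with hypothetical clip lengths:

```python
import numpy as np

# Hypothetical item returned by SpleeterDataset.__getitem__ after this commit:
# three clips concatenated end to end, one channel, 22050 Hz.
clips = [np.zeros(22050), np.zeros(44100), np.zeros(11025)]
item = {
    'audio': np.concatenate(clips),                    # shape (77175,)
    'files': ['clip1.wav', 'clip2.wav', 'clip3.wav'],
    'ends': [22050, 66150, 77175],                     # cumulative clip boundaries
}
# The 'stft' key is gone: the separator now receives raw audio and computes
# the transform itself inside the TF graph.
```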
--- a/scripts/audio/preparation/spleeter_utils/spleeter_separator_mod.py
+++ /dev/null
@@ -1,501 +0,0 @@
-#!/usr/bin/env python
-# coding: utf8
-
-"""
-Module that provides a class wrapper for source separation.
-
-Modified to support directly feeding in spectrograms.
-
-Examples:
-
-```python
->>> from spleeter.separator import Separator
->>> separator = Separator('spleeter:2stems')
->>> separator.separate(waveform, lambda instrument, data: ...)
->>> separator.separate_to_file(...)
-```
-"""
-
-import atexit
-import os
-from multiprocessing import Pool
-from os.path import basename, dirname, join, splitext
-from typing import Dict, Generator, Optional
-
-# pyright: reportMissingImports=false
-# pylint: disable=import-error
-import numpy as np
-import tensorflow as tf
-from librosa.core import istft, stft
-from scipy.signal.windows import hann
-
-from spleeter.model.provider import ModelProvider
-
-from spleeter import SpleeterError
-from spleeter.audio import Codec, STFTBackend
-from spleeter.audio.adapter import AudioAdapter
-from spleeter.audio.convertor import to_stereo
-from spleeter.model import EstimatorSpecBuilder, InputProviderFactory, model_fn
-from spleeter.model.provider import ModelProvider
-from spleeter.types import AudioDescriptor
-from spleeter.utils.configuration import load_configuration
-
-# pylint: enable=import-error
-
-__email__ = "spleeter@deezer.com"
-__author__ = "Deezer Research"
-__license__ = "MIT License"
-
-
-class DataGenerator(object):
-    """
-    Generator object that store a sample and generate it once while called.
-    Used to feed a tensorflow estimator without knowing the whole data at
-    build time.
-    """
-
-    def __init__(self) -> None:
-        """ Default constructor. """
-        self._current_data = None
-
-    def update_data(self, data) -> None:
-        """ Replace internal data. """
-        self._current_data = data
-
-    def __call__(self) -> Generator:
-        """ Generation process. """
-        buffer = self._current_data
-        while buffer:
-            yield buffer
-            buffer = self._current_data
-
-
-def create_estimator(params, MWF):
-    """
-    Initialize tensorflow estimator that will perform separation
-
-    Params:
-    - params: a dictionary of parameters for building the model
-
-    Returns:
-        a tensorflow estimator
-    """
-    # Load model.
-    provider: ModelProvider = ModelProvider.default()
-    params["model_dir"] = provider.get(params["model_dir"])
-    params["MWF"] = MWF
-    # Setup config
-    session_config = tf.compat.v1.ConfigProto()
-    session_config.gpu_options.per_process_gpu_memory_fraction = 0.7
-    config = tf.estimator.RunConfig(session_config=session_config)
-    # Setup estimator
-    estimator = tf.estimator.Estimator(
-        model_fn=model_fn, model_dir=params["model_dir"], params=params, config=config
-    )
-    return estimator
-
-
-class Separator(object):
-    """ A wrapper class for performing separation. """
-
-    def __init__(
-        self,
-        params_descriptor: str,
-        MWF: bool = False,
-        stft_backend: STFTBackend = STFTBackend.AUTO,
-        multiprocess: bool = True,
-        load_tf: bool = True
-    ) -> None:
-        """
-        Default constructor.
-
-        Parameters:
-            params_descriptor (str):
-                Descriptor for TF params to be used.
-            MWF (bool):
-                (Optional) `True` if MWF should be used, `False` otherwise.
-        """
-        self._params = load_configuration(params_descriptor)
-        self._sample_rate = self._params["sample_rate"]
-        self._MWF = MWF
-        if load_tf:
-            self._tf_graph = tf.Graph()
-        else:
-            self._tf_graph = None
-        self._prediction_generator = None
-        self._input_provider = None
-        self._builder = None
-        self._features = None
-        self._session = None
-        if multiprocess:
-            self._pool = Pool()
-            atexit.register(self._pool.close)
-        else:
-            self._pool = None
-        self._tasks = []
-        self._params["stft_backend"] = STFTBackend.resolve(stft_backend)
-        self._data_generator = DataGenerator()
-
-    def _get_prediction_generator(self) -> Generator:
-        """
-        Lazy loading access method for internal prediction generator
-        returned by the predict method of a tensorflow estimator.
-
-        Returns:
-            Generator:
-                Generator of prediction.
-        """
-        if self._prediction_generator is None:
-            estimator = create_estimator(self._params, self._MWF)
-
-            def get_dataset():
-                return tf.data.Dataset.from_generator(
-                    self._data_generator,
-                    output_types={"waveform": tf.float32, "audio_id": tf.string},
-                    output_shapes={"waveform": (None, 2), "audio_id": ()},
-                )
-
-            self._prediction_generator = estimator.predict(
-                get_dataset, yield_single_examples=False
-            )
-        return self._prediction_generator
-
-    def join(self, timeout: int = 200) -> None:
-        """
-        Wait for all pending tasks to be finished.
-
-        Parameters:
-            timeout (int):
-                (Optional) task waiting timeout.
-        """
-        while len(self._tasks) > 0:
-            task = self._tasks.pop()
-            task.get()
-            task.wait(timeout=timeout)
-
-    def stft(
-        self, data: np.ndarray, inverse: bool = False, length: Optional[int] = None
-    ) -> np.ndarray:
-        """
-        Single entrypoint for both stft and istft. This computes stft and
-        istft with librosa on stereo data. The two channels are processed
-        separately and are concatenated together in the result. The
-        expected input formats are: (n_samples, 2) for stft and (T, F, 2)
-        for istft.
-
-        Parameters:
-            data (numpy.array):
-                Array with either the waveform or the complex spectrogram
-                depending on the parameter inverse
-            inverse (bool):
-                (Optional) Should a stft or an istft be computed.
-            length (Optional[int]):
-
-        Returns:
-            numpy.ndarray:
-                Stereo data as numpy array for the transform. The channels
-                are stored in the last dimension.
-        """
-        assert not (inverse and length is None)
-        data = np.asfortranarray(data)
-        N = self._params["frame_length"]
-        H = self._params["frame_step"]
-        win = hann(N, sym=False)
-        fstft = istft if inverse else stft
-        win_len_arg = {"win_length": None, "length": None} if inverse else {"n_fft": N}
-        n_channels = data.shape[-1]
-        out = []
-        for c in range(n_channels):
-            d = (
-                np.concatenate((np.zeros((N,)), data[:, c], np.zeros((N,))))
-                if not inverse
-                else data[:, :, c].T
-            )
-            s = fstft(d, hop_length=H, window=win, center=False, **win_len_arg)
-            if inverse:
-                s = s[N : N + length]
-            s = np.expand_dims(s.T, 2 - inverse)
-            out.append(s)
-        if len(out) == 1:
-            return out[0]
-        return np.concatenate(out, axis=2 - inverse)
-
-    def _get_input_provider(self):
-        if self._input_provider is None:
-            self._input_provider = InputProviderFactory.get(self._params)
-        return self._input_provider
-
-    def _get_features(self):
-        if self._features is None:
-            provider = self._get_input_provider()
-            self._features = provider.get_input_dict_placeholders()
-        return self._features
-
-    def _get_builder(self):
-        if self._builder is None:
-            self._builder = EstimatorSpecBuilder(self._get_features(), self._params)
-        return self._builder
-
-    def _get_session(self):
-        if self._session is None:
-            saver = tf.compat.v1.train.Saver()
-            provider = ModelProvider.default()
-            model_directory: str = provider.get(self._params["model_dir"])
-            latest_checkpoint = tf.train.latest_checkpoint(model_directory)
-            self._session = tf.compat.v1.Session()
-            saver.restore(self._session, latest_checkpoint)
-        return self._session
-
-    def _separate_raw_spec(self, stft: np.ndarray, audio_descriptor: AudioDescriptor) -> Dict:
-        with self._tf_graph.as_default():
-            out = {}
-            features = self._get_features()
-            outputs = self._get_builder().outputs
-            if stft.shape[-1] == 1:
-                stft = np.concatenate([stft, stft], axis=-1)
-            elif stft.shape[-1] > 2:
-                stft = stft[:, :2]
-            sess = self._get_session()
-            outputs = sess.run(
-                outputs,
-                feed_dict=self._get_input_provider().get_feed_dict(
-                    features, stft, audio_descriptor
-                ),
-            )
-            for inst in self._get_builder().instruments:
-                out[inst] = outputs[inst]
-            return out
-
-    def _separate_librosa(
-        self, waveform: np.ndarray, audio_descriptor: AudioDescriptor
-    ) -> Dict:
-        """
-        Performs separation with librosa backend for STFT.
-
-        Parameters:
-            waveform (numpy.ndarray):
-                Waveform to be separated (as a numpy array)
-            audio_descriptor (AudioDescriptor):
-        """
-        with self._tf_graph.as_default():
-            out = {}
-            features = self._get_features()
-            # TODO: fix the logic, build sometimes return,
-            # sometimes set attribute.
-            outputs = self._get_builder().outputs
-            stft = self.stft(waveform)
-            if stft.shape[-1] == 1:
-                stft = np.concatenate([stft, stft], axis=-1)
-            elif stft.shape[-1] > 2:
-                stft = stft[:, :2]
-            sess = self._get_session()
-            outputs = sess.run(
-                outputs,
-                feed_dict=self._get_input_provider().get_feed_dict(
-                    features, stft, audio_descriptor
-                ),
-            )
-            for inst in self._get_builder().instruments:
-                out[inst] = self.stft(
-                    outputs[inst], inverse=True, length=waveform.shape[0]
-                )
-            return out
-
-    def _separate_tensorflow(
-        self, waveform: np.ndarray, audio_descriptor: AudioDescriptor
-    ) -> Dict:
-        """
-        Performs source separation over the given waveform with tensorflow
-        backend.
-
-        Parameters:
-            waveform (numpy.ndarray):
-                Waveform to be separated (as a numpy array)
-            audio_descriptor (AudioDescriptor):
-
-        Returns:
-            Separated waveforms.
-        """
-        if not waveform.shape[-1] == 2:
-            waveform = to_stereo(waveform)
-        prediction_generator = self._get_prediction_generator()
-        # NOTE: update data in generator before performing separation.
-        self._data_generator.update_data(
-            {"waveform": waveform, "audio_id": np.array(audio_descriptor)}
-        )
-        # NOTE: perform separation.
-        prediction = next(prediction_generator)
-        prediction.pop("audio_id")
-        return prediction
-
-    def separate(
-        self, waveform: np.ndarray, audio_descriptor: Optional[str] = ""
-    ) -> None:
-        """
-        Performs separation on a waveform.
-
-        Parameters:
-            waveform (numpy.ndarray):
-                Waveform to be separated (as a numpy array)
-            audio_descriptor (str):
-                (Optional) string describing the waveform (e.g. filename).
-        """
-        backend: str = self._params["stft_backend"]
-        if backend == STFTBackend.TENSORFLOW:
-            return self._separate_tensorflow(waveform, audio_descriptor)
-        elif backend == STFTBackend.LIBROSA:
-            return self._separate_librosa(waveform, audio_descriptor)
-        raise ValueError(f"Unsupported STFT backend {backend}")
-
-    def separate_spectrogram(
-        self, stft: np.ndarray, audio_descriptor: Optional[str] = ""
-    ) -> None:
-        """
-        Performs separation on a spectrogram.
-
-        Parameters:
-            stft (numpy.ndarray):
-                Spectrogram to be separated (as a numpy array)
-            audio_descriptor (str):
-                (Optional) string describing the waveform (e.g. filename).
-        """
-        return self._separate_raw_spec(stft, audio_descriptor)
-
-    def separate_to_file(
-        self,
-        audio_descriptor: AudioDescriptor,
-        destination: str,
-        audio_adapter: Optional[AudioAdapter] = None,
-        offset: int = 0,
-        duration: float = 600.0,
-        codec: Codec = Codec.WAV,
-        bitrate: str = "128k",
-        filename_format: str = "{filename}/{instrument}.{codec}",
-        synchronous: bool = True,
-    ) -> None:
-        """
-        Performs source separation and export result to file using
-        given audio adapter.
-
-        Filename format should be a Python formattable string that could
-        use following parameters :
-
-        - {instrument}
-        - {filename}
-        - {foldername}
-        - {codec}.
-
-        Parameters:
-            audio_descriptor (AudioDescriptor):
-                Describe song to separate, used by audio adapter to
-                retrieve and load audio data, in case of file based
-                audio adapter, such descriptor would be a file path.
-            destination (str):
-                Target directory to write output to.
-            audio_adapter (Optional[AudioAdapter]):
-                (Optional) Audio adapter to use for I/O.
-            offset (int):
-                (Optional) Offset of loaded song.
-            duration (float):
-                (Optional) Duration of loaded song (default: 600s).
-            codec (Codec):
-                (Optional) Export codec.
-            bitrate (str):
-                (Optional) Export bitrate.
-            filename_format (str):
-                (Optional) Filename format.
-            synchronous (bool):
-                (Optional) True is should by synchronous.
-        """
-        if audio_adapter is None:
-            audio_adapter = AudioAdapter.default()
-        waveform, _ = audio_adapter.load(
-            audio_descriptor,
-            offset=offset,
-            duration=duration,
-            sample_rate=self._sample_rate,
-        )
-        sources = self.separate(waveform, audio_descriptor)
-        self.save_to_file(
-            sources,
-            audio_descriptor,
-            destination,
-            filename_format,
-            codec,
-            audio_adapter,
-            bitrate,
-            synchronous,
-        )
-
-    def save_to_file(
-        self,
-        sources: Dict,
-        audio_descriptor: AudioDescriptor,
-        destination: str,
-        filename_format: str = "{filename}/{instrument}.{codec}",
-        codec: Codec = Codec.WAV,
-        audio_adapter: Optional[AudioAdapter] = None,
-        bitrate: str = "128k",
-        synchronous: bool = True,
-    ) -> None:
-        """
-        Export dictionary of sources to files.
-
-        Parameters:
-            sources (Dict):
-                Dictionary of sources to be exported. The keys are the name
-                of the instruments, and the values are `N x 2` numpy arrays
-                containing the corresponding intrument waveform, as
-                returned by the separate method
-            audio_descriptor (AudioDescriptor):
-                Describe song to separate, used by audio adapter to
-                retrieve and load audio data, in case of file based audio
-                adapter, such descriptor would be a file path.
-            destination (str):
-                Target directory to write output to.
-            filename_format (str):
-                (Optional) Filename format.
-            codec (Codec):
-                (Optional) Export codec.
-            audio_adapter (Optional[AudioAdapter]):
-                (Optional) Audio adapter to use for I/O.
-            bitrate (str):
-                (Optional) Export bitrate.
-            synchronous (bool):
-                (Optional) True is should by synchronous.
-        """
-        if audio_adapter is None:
-            audio_adapter = AudioAdapter.default()
-        foldername = basename(dirname(audio_descriptor))
-        filename = splitext(basename(audio_descriptor))[0]
-        generated = []
-        for instrument, data in sources.items():
-            path = join(
-                destination,
-                filename_format.format(
-                    filename=filename,
-                    instrument=instrument,
-                    foldername=foldername,
-                    codec=codec,
-                ),
-            )
-            directory = os.path.dirname(path)
-            if not os.path.exists(directory):
-                os.makedirs(directory)
-            if path in generated:
-                raise SpleeterError(
-                    (
-                        f"Separated source path conflict : {path},"
-                        "please check your filename format"
-                    )
-                )
-            generated.append(path)
-            if self._pool:
-                task = self._pool.apply_async(
-                    audio_adapter.save, (path, data, self._sample_rate, codec, bitrate)
-                )
-                self._tasks.append(task)
-            else:
-                audio_adapter.save(path, data, self._sample_rate, codec, bitrate)
-        if synchronous and self._pool:
-            self.join()
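For reference, the three-stage pipeline the deleted module enabled, pieced together from the old spleeter_splitter.py and collector diffs above (a sketch of removed code, so it no longer runs after this commit; `wavs` and `wavlen` stand in for the dataset's concatenated audio and its sample length):

```python
from scripts.audio.preparation.spleeter_utils.spleeter_separator_mod import Separator

# CPU side (dataset): forward STFT via the librosa-based helper, no TF graph loaded.
cpu_sep = Separator('spleeter:2stems', multiprocess=False, load_tf=False)
stft = cpu_sep.stft(wavs)                        # complex spectrogram, (T, F, 2)

# GPU side (main process): run only the mask network on the raw spectrogram.
gpu_sep = Separator('spleeter:2stems', multiprocess=False)
masked = gpu_sep.separate_spectrogram(stft)      # {'vocals': ..., 'accompaniment': ...}

# CPU side (collector worker): invert the masked spectrograms back to audio.
vocals = cpu_sep.stft(masked['vocals'], inverse=True, length=wavlen)
bg = cpu_sep.stft(masked['accompaniment'], inverse=True, length=wavlen)
```

Per the commit message, this split bought nothing, since in TF mode everything runs on the GPU anyway.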