# DL-Art-School/codes/scripts/audio/preparation/spleeter_dataset.py
from typing import Optional

import librosa
import numpy as np
import torch
from scipy.signal.windows import hann
from spleeter.audio.adapter import AudioAdapter
from torch.utils.data import Dataset

from data.util import find_audio_files


def spleeter_stft(
        data: np.ndarray, inverse: bool = False, length: Optional[int] = None
) -> np.ndarray:
"""
Single entrypoint for both stft and istft. This computes stft and
istft with librosa on stereo data. The two channels are processed
separately and are concatenated together in the result. The
expected input formats are: (n_samples, 2) for stft and (T, F, 2)
for istft.
Parameters:
data (numpy.array):
Array with either the waveform or the complex spectrogram
depending on the parameter inverse
inverse (bool):
(Optional) Should a stft or an istft be computed.
length (Optional[int]):
Returns:
numpy.ndarray:
Stereo data as numpy array for the transform. The channels
are stored in the last dimension.
"""
    assert not (inverse and length is None)
    data = np.asfortranarray(data)
    N = 4096  # FFT frame size
    H = 1024  # hop length
    win = hann(N, sym=False)
    fstft = librosa.core.istft if inverse else librosa.core.stft
    win_len_arg = {"win_length": None, "length": None} if inverse else {"n_fft": N}
    n_channels = data.shape[-1]
    out = []
    for c in range(n_channels):
        # Forward: zero-pad each channel by one frame on both sides.
        # Inverse: transpose the spectrogram to librosa's (F, T) layout.
        d = (
            np.concatenate((np.zeros((N,)), data[:, c], np.zeros((N,))))
            if not inverse
            else data[:, :, c].T
        )
        s = fstft(d, hop_length=H, window=win, center=False, **win_len_arg)
        if inverse:
            # Trim the zero padding added by the forward transform.
            s = s[N: N + length]
        s = np.expand_dims(s.T, 2 - inverse)
        out.append(s)
    if len(out) == 1:
        return out[0]
    return np.concatenate(out, axis=2 - inverse)
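

# Round-trip sketch (illustration only; not part of the original file). The
# forward transform yields a (T, F, 2) complex spectrogram; passing `length`
# to the inverse restores the original (n_samples, 2) waveform shape:
#
#     wave = np.random.randn(22050, 2).astype(np.float32)  # fake stereo clip
#     spec = spleeter_stft(wave)
#     recon = spleeter_stft(spec, inverse=True, length=wave.shape[0])
#     assert recon.shape == wave.shape
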
class SpleeterDataset(Dataset):
    def __init__(self, src_dir, sample_rate=22050, max_duration=20, skip=0):
        self.files = find_audio_files(src_dir, include_nonwav=True)
        if skip > 0:
            self.files = self.files[skip:]
        self.audio_loader = AudioAdapter.default()
        self.sample_rate = sample_rate
        self.max_duration = max_duration

    def __getitem__(self, item):
        file = self.files[item]
        wave, sample_rate = self.audio_loader.load(file, sample_rate=self.sample_rate)
        assert sample_rate == self.sample_rate
        stft = torch.tensor(spleeter_stft(wave))
        # TODO: pad this up so it can be batched.
        return {
            'path': file,
            'wave': wave,
            'stft': stft,
            #'duration': original_duration,
        }

    def __len__(self):
        return len(self.files)
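

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original file). './audio' is a
    # hypothetical directory of clips; batch_size=1 is required until the
    # padding TODO in __getitem__ is addressed, since variable-length clips
    # cannot yet be collated into a single batch.
    from torch.utils.data import DataLoader

    dataset = SpleeterDataset('./audio')
    loader = DataLoader(dataset, batch_size=1, shuffle=False)
    for batch in loader:
        print(batch['path'], batch['stft'].shape)
        break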