forked from mrq/DL-Art-School

Batch spleeter cleaner using GPU

parent 73b930c0f6
commit 742f9b4010
codes/models/spleeter/estimator.py (new file, 137 lines)
@@ -0,0 +1,137 @@
import math

import torch
import torch.nn.functional as F
from torch import nn
from torch import istft

from .unet import UNet
from .util import tf2pytorch


def load_ckpt(model, ckpt):
    state_dict = model.state_dict()
    for k, v in ckpt.items():
        if k in state_dict:
            target_shape = state_dict[k].shape
            assert target_shape == v.shape
            state_dict.update({k: torch.from_numpy(v)})
        else:
            print('Ignoring ', k)

    model.load_state_dict(state_dict)
    return model


def pad_and_partition(tensor, T):
    """
    Pads the last dimension with zeros and partitions the tensor into segments of length T.

    Args:
        tensor(Tensor): B x C x F x L

    Returns:
        tensor of size (B*ceil(L/T) x C x F x T)
    """
    old_size = tensor.size(3)
    new_size = math.ceil(old_size / T) * T
    tensor = F.pad(tensor, [0, new_size - old_size])
    return torch.cat(torch.split(tensor, T, dim=3), dim=0)


class Estimator(nn.Module):
    def __init__(self, num_instruments, checkpoint_path):
        super(Estimator, self).__init__()

        # stft config
        self.F = 1024
        self.T = 512
        self.win_length = 4096
        self.hop_length = 1024
        self.win = torch.hann_window(self.win_length)

        ckpts = tf2pytorch(checkpoint_path, num_instruments)

        # filter
        self.instruments = nn.ModuleList()
        for i in range(num_instruments):
            print('Loading model for instrument {}'.format(i))
            net = UNet(2)
            ckpt = ckpts[i]
            net = load_ckpt(net, ckpt)
            net.eval()  # change mode to eval
            self.instruments.append(net)

    def compute_stft(self, wav):
        """
        Computes the STFT feature from a wav.

        Args:
            wav (Tensor): B x L
        """
        stft = torch.stft(
            wav, self.win_length, hop_length=self.hop_length, window=self.win.to(wav.device))

        # only keep freqs smaller than self.F
        stft = stft[:, :self.F, :, :]
        real = stft[:, :, :, 0]
        im = stft[:, :, :, 1]
        mag = torch.sqrt(real ** 2 + im ** 2)

        return stft, mag

    def inverse_stft(self, stft):
        """Inverts the STFT back to a waveform."""
        pad = self.win_length // 2 + 1 - stft.size(1)
        stft = F.pad(stft, (0, 0, 0, 0, 0, pad))
        wav = istft(stft, self.win_length, hop_length=self.hop_length,
                    window=self.win.to(stft.device))
        return wav.detach()

    def separate(self, wav):
        """
        Separates a stereo wav into different tracks corresponding to different instruments.

        Args:
            wav (tensor): B x L
        """
        # stft - B x F x L x 2
        # stft_mag - B x F x L
        stft, stft_mag = self.compute_stft(wav)

        L = stft.size(2)

        stft_mag = stft_mag.unsqueeze(1).repeat(1, 2, 1, 1)  # B x 2 x F x T
        stft_mag = pad_and_partition(stft_mag, self.T)  # B x 2 x F x T
        stft_mag = stft_mag.transpose(2, 3)  # B x 2 x T x F

        # compute instruments' masks
        masks = []
        for net in self.instruments:
            mask = net(stft_mag)
            masks.append(mask)

        # compute denominator
        mask_sum = sum([m ** 2 for m in masks])
        mask_sum += 1e-10

        wavs = []
        for mask in masks:
            mask = (mask ** 2 + 1e-10 / 2) / mask_sum
            mask = mask.transpose(2, 3)  # B x 2 x F x T

            mask = torch.cat(
                torch.split(mask, 1, dim=0), dim=3)

            mask = mask[:, 0, :, :L].unsqueeze(-1)  # 2 x F x L x 1
            stft_masked = stft * mask
            wav_masked = self.inverse_stft(stft_masked)

            wavs.append(wav_masked)

        return wavs
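A quick shape check for pad_and_partition (an illustrative sketch, not part of the commit; the import path assumes codes/ is on the path, as in the other scripts here): with T=2, a length-5 spectrogram pads to length 6 and splits into 3 chunks stacked along the batch dimension.

import torch
from models.spleeter.estimator import pad_and_partition

x = torch.zeros(1, 2, 3, 5)           # B x C x F x L
print(pad_and_partition(x, 2).shape)  # torch.Size([3, 2, 3, 2]): B*ceil(L/T) x C x F x T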
codes/models/spleeter/separator.py (new file, 32 lines)
@@ -0,0 +1,32 @@
import torch
import torch.nn.functional as F

from models.spleeter.estimator import Estimator


class Separator:
    def __init__(self, model_path, input_sr=44100, device='cuda'):
        self.model = Estimator(2, model_path).to(device)
        self.device = device
        self.input_sr = input_sr

    def separate(self, npwav, normalize=False):
        if not isinstance(npwav, torch.Tensor):
            assert len(npwav.shape) == 1
            wav = torch.tensor(npwav, device=self.device)
            wav = wav.view(1, -1)
        else:
            assert len(npwav.shape) == 2  # Input should be BxL
            wav = npwav.to(self.device)

        if normalize:
            wav = wav / (wav.max() + 1e-8)

        # Spleeter expects audio input to be 44.1kHz.
        wav = F.interpolate(wav.unsqueeze(1), mode='nearest', scale_factor=44100 / self.input_sr).squeeze(1)
        res = self.model.separate(wav)
        res = [F.interpolate(r.unsqueeze(1), mode='nearest', scale_factor=self.input_sr / 44100)[:, 0] for r in res]
        return {
            'vocals': res[0].cpu().numpy(),
            'accompaniment': res[1].cpu().numpy()
        }
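A minimal usage sketch for Separator (illustrative only; assumes the 2stems checkpoint directory used by the batch script below, a CUDA device, and a mono 22.05kHz clip):

import numpy as np
from models.spleeter.separator import Separator

separator = Separator('pretrained_models/2stems', input_sr=22050)
clip = np.random.randn(22050 * 5).astype(np.float32)  # stand-in for 5 seconds of real audio
stems = separator.separate(clip, normalize=True)
print(stems['vocals'].shape, stems['accompaniment'].shape)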
codes/models/spleeter/unet.py (new file, 80 lines)
@@ -0,0 +1,80 @@
import torch
from torch import nn


def down_block(in_filters, out_filters):
    return nn.Conv2d(in_filters, out_filters, kernel_size=5,
                     stride=2, padding=2,
                     ), nn.Sequential(
        nn.BatchNorm2d(out_filters, track_running_stats=True, eps=1e-3, momentum=0.01),
        nn.LeakyReLU(0.2)
    )


def up_block(in_filters, out_filters, dropout=False):
    layers = [
        nn.ConvTranspose2d(in_filters, out_filters, kernel_size=5,
                           stride=2, padding=2, output_padding=1
                           ),
        nn.ReLU(),
        nn.BatchNorm2d(out_filters, track_running_stats=True, eps=1e-3, momentum=0.01)
    ]
    if dropout:
        layers.append(nn.Dropout(0.5))

    return nn.Sequential(*layers)


class UNet(nn.Module):
    def __init__(self, in_channels=2):
        super(UNet, self).__init__()
        self.down1_conv, self.down1_act = down_block(in_channels, 16)
        self.down2_conv, self.down2_act = down_block(16, 32)
        self.down3_conv, self.down3_act = down_block(32, 64)
        self.down4_conv, self.down4_act = down_block(64, 128)
        self.down5_conv, self.down5_act = down_block(128, 256)
        self.down6_conv, self.down6_act = down_block(256, 512)

        self.up1 = up_block(512, 256, dropout=True)
        self.up2 = up_block(512, 128, dropout=True)
        self.up3 = up_block(256, 64, dropout=True)
        self.up4 = up_block(128, 32)
        self.up5 = up_block(64, 16)
        self.up6 = up_block(32, 1)
        self.up7 = nn.Sequential(
            nn.Conv2d(1, 2, kernel_size=4, dilation=2, padding=3),
            nn.Sigmoid()
        )

    def forward(self, x):
        d1_conv = self.down1_conv(x)
        d1 = self.down1_act(d1_conv)

        d2_conv = self.down2_conv(d1)
        d2 = self.down2_act(d2_conv)

        d3_conv = self.down3_conv(d2)
        d3 = self.down3_act(d3_conv)

        d4_conv = self.down4_conv(d3)
        d4 = self.down4_act(d4_conv)

        d5_conv = self.down5_conv(d4)
        d5 = self.down5_act(d5_conv)

        d6_conv = self.down6_conv(d5)
        d6 = self.down6_act(d6_conv)

        u1 = self.up1(d6)
        u2 = self.up2(torch.cat([d5_conv, u1], dim=1))
        u3 = self.up3(torch.cat([d4_conv, u2], dim=1))
        u4 = self.up4(torch.cat([d3_conv, u3], dim=1))
        u5 = self.up5(torch.cat([d2_conv, u4], dim=1))
        u6 = self.up6(torch.cat([d1_conv, u5], dim=1))
        u7 = self.up7(u6)
        return u7 * x


if __name__ == '__main__':
    net = UNet(14)
    print(net(torch.rand(1, 14, 20, 48)).shape)
codes/models/spleeter/util.py (new file, 91 lines)
@@ -0,0 +1,91 @@
import numpy as np
import tensorflow as tf

from .unet import UNet


def tf2pytorch(checkpoint_path, num_instruments):
    tf_vars = {}
    init_vars = tf.train.list_variables(checkpoint_path)
    # print(init_vars)
    for name, shape in init_vars:
        try:
            # print('Loading TF Weight {} with shape {}'.format(name, shape))
            data = tf.train.load_variable(checkpoint_path, name)
            tf_vars[name] = data
        except Exception as e:
            print('Error loading {}: {}'.format(name, e))
    conv_idx = 0
    tconv_idx = 0
    bn_idx = 0
    outputs = []
    for i in range(num_instruments):
        output = {}
        outputs.append(output)

        for j in range(1, 7):
            if conv_idx == 0:
                conv_suffix = ""
            else:
                conv_suffix = "_" + str(conv_idx)

            if bn_idx == 0:
                bn_suffix = ""
            else:
                bn_suffix = "_" + str(bn_idx)

            output['down{}_conv.weight'.format(j)] = np.transpose(
                tf_vars["conv2d{}/kernel".format(conv_suffix)], (3, 2, 0, 1))
            # print('conv dtype: ', output['down{}.0.weight'.format(j)].dtype)
            output['down{}_conv.bias'.format(
                j)] = tf_vars["conv2d{}/bias".format(conv_suffix)]

            output['down{}_act.0.weight'.format(
                j)] = tf_vars["batch_normalization{}/gamma".format(bn_suffix)]
            output['down{}_act.0.bias'.format(
                j)] = tf_vars["batch_normalization{}/beta".format(bn_suffix)]
            output['down{}_act.0.running_mean'.format(
                j)] = tf_vars['batch_normalization{}/moving_mean'.format(bn_suffix)]
            output['down{}_act.0.running_var'.format(
                j)] = tf_vars['batch_normalization{}/moving_variance'.format(bn_suffix)]

            conv_idx += 1
            bn_idx += 1

        # up blocks
        for j in range(1, 7):
            if tconv_idx == 0:
                tconv_suffix = ""
            else:
                tconv_suffix = "_" + str(tconv_idx)

            if bn_idx == 0:
                bn_suffix = ""
            else:
                bn_suffix = "_" + str(bn_idx)

            output['up{}.0.weight'.format(j)] = np.transpose(
                tf_vars["conv2d_transpose{}/kernel".format(tconv_suffix)], (3, 2, 0, 1))
            output['up{}.0.bias'.format(
                j)] = tf_vars["conv2d_transpose{}/bias".format(tconv_suffix)]
            output['up{}.2.weight'.format(
                j)] = tf_vars["batch_normalization{}/gamma".format(bn_suffix)]
            output['up{}.2.bias'.format(
                j)] = tf_vars["batch_normalization{}/beta".format(bn_suffix)]
            output['up{}.2.running_mean'.format(
                j)] = tf_vars['batch_normalization{}/moving_mean'.format(bn_suffix)]
            output['up{}.2.running_var'.format(
                j)] = tf_vars['batch_normalization{}/moving_variance'.format(bn_suffix)]
            tconv_idx += 1
            bn_idx += 1

        if conv_idx == 0:
            suffix = ""
        else:
            suffix = "_" + str(conv_idx)
        output['up7.0.weight'] = np.transpose(
            tf_vars['conv2d{}/kernel'.format(suffix)], (3, 2, 0, 1))
        output['up7.0.bias'] = tf_vars['conv2d{}/bias'.format(suffix)]
        conv_idx += 1

    return outputs
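Because tf2pytorch needs TensorFlow installed at conversion time, one option is to run it once and cache the converted weights; a minimal sketch (the cache path is hypothetical, and the checkpoint path mirrors the batch script below):

import torch
from models.spleeter.util import tf2pytorch

ckpts = tf2pytorch('pretrained_models/2stems', 2)         # one dict of numpy arrays per instrument
torch.save(ckpts, 'pretrained_models/2stems_torch.pth')   # hypothetical cache path
# Later, on a machine without tensorflow:
# ckpts = torch.load('pretrained_models/2stems_torch.pth')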
codes/scripts/audio/preparation/spleeter_dataset.py (new file, 37 lines)
@@ -0,0 +1,37 @@
import torch
import torch.nn as nn
from spleeter.audio.adapter import AudioAdapter
from torch.utils.data import Dataset

from data.util import find_audio_files


class SpleeterDataset(Dataset):
    def __init__(self, src_dir, sample_rate=22050, max_duration=20):
        self.files = find_audio_files(src_dir, include_nonwav=True)
        self.audio_loader = AudioAdapter.default()
        self.sample_rate = sample_rate
        self.max_duration = max_duration

    def __getitem__(self, item):
        file = self.files[item]
        try:
            wave, sample_rate = self.audio_loader.load(file, sample_rate=self.sample_rate)
            assert sample_rate == self.sample_rate
            wave = wave[:, 0]  # strip off extra channels
            wave = torch.tensor(wave)
        except Exception:
            print(f"Error with {file}")
            wave = torch.zeros(self.sample_rate * self.max_duration)
        original_duration = wave.shape[0]
        # Pad everything to a fixed max_duration so the default collate can batch clips.
        padding_needed = self.sample_rate * self.max_duration - original_duration
        if padding_needed > 0:
            wave = nn.functional.pad(wave, (0, padding_needed))
        return {
            'path': file,
            'wave': wave,
            'duration': original_duration,
        }

    def __len__(self):
        return len(self.files)
@@ -0,0 +1,67 @@
import os

import numpy as np
from scipy.io import wavfile
from torch.utils.data import DataLoader
from tqdm import tqdm

from models.spleeter.separator import Separator
from scripts.audio.preparation.spleeter_dataset import SpleeterDataset


# Uses torch spleeter to divide audio clips into one of two bins:
# 1. Audio has little to no background noise, saved to "output_dir"
# 2. Audio has a lot of background noise; the background noise is split off and saved to "output_dir_bg"
def main():
    src_dir = 'F:\\split\\podcast-dump0'
    output_dir = 'F:\\tmp\\out'
    output_dir_bg = 'F:\\tmp\\bg'
    output_sample_rate = 22050
    batch_size = 24

    dl = DataLoader(SpleeterDataset(src_dir, output_sample_rate), batch_size=batch_size, shuffle=False, num_workers=1, pin_memory=True)
    separator = Separator('pretrained_models/2stems', input_sr=output_sample_rate)
    for e, batch in enumerate(tqdm(dl)):
        #if e < 406500:
        # continue
        waves = batch['wave']
        paths = batch['path']
        durations = batch['duration']

        sep = separator.separate(waves)
        for j in range(sep['vocals'].shape[0]):
            vocals = sep['vocals'][j]
            bg = sep['accompaniment'][j]
            vmax = np.abs(vocals).mean()
            bmax = np.abs(bg).mean()

            # Only output to the "good" sample dir if the ratio of vocal energy to background energy is high enough.
            ratio = vmax / (bmax + 1e-7)
            if ratio >= 25:  # These values were derived empirically
                od = output_dir
                out_sound = waves[j].cpu().numpy()
            elif ratio <= 1:
                od = output_dir_bg
                out_sound = bg
            else:
                continue

            # Strip out channels.
            if len(out_sound.shape) > 1:
                out_sound = out_sound[:, 0]  # Just use the first channel.
            # Resize to the true (unpadded) duration.
            out_sound = out_sound[:durations[j]]

            # Compile an output path.
            path = paths[j]
            reld = os.path.relpath(os.path.dirname(path), src_dir)
            os.makedirs(os.path.join(od, reld), exist_ok=True)
            relp = os.path.relpath(path, src_dir)
            output_path = os.path.join(od, relp)

            wavfile.write(output_path, output_sample_rate, out_sound)


if __name__ == '__main__':
    main()
@@ -25,8 +25,8 @@ if __name__ == '__main__':
     separator = Separator('spleeter:2stems')
     files = find_audio_files(src_dir, include_nonwav=True)
     for e, file in enumerate(tqdm(files)):
-        if e < 406500:
-            continue
+        #if e < 406500:
+        # continue
         file_basis = osp.relpath(file, src_dir)\
             .replace('/', '_')\
             .replace('\\', '_')\