Batch spleeter cleaner using GPU
This commit is contained in:
parent
73b930c0f6
commit
742f9b4010
137
codes/models/spleeter/estimator.py
Normal file
137
codes/models/spleeter/estimator.py
Normal file
|
@ -0,0 +1,137 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch import istft
|
||||
|
||||
from .unet import UNet
|
||||
from .util import tf2pytorch
|
||||
|
||||
|
||||
def load_ckpt(model, ckpt):
|
||||
state_dict = model.state_dict()
|
||||
for k, v in ckpt.items():
|
||||
if k in state_dict:
|
||||
target_shape = state_dict[k].shape
|
||||
assert target_shape == v.shape
|
||||
state_dict.update({k: torch.from_numpy(v)})
|
||||
else:
|
||||
print('Ignore ', k)
|
||||
|
||||
model.load_state_dict(state_dict)
|
||||
return model
|
||||
|
||||
|
||||
def pad_and_partition(tensor, T):
|
||||
"""
|
||||
pads zero and partition tensor into segments of length T
|
||||
|
||||
Args:
|
||||
tensor(Tensor): BxCxFxL
|
||||
|
||||
Returns:
|
||||
tensor of size (B*[L/T] x C x F x T)
|
||||
"""
|
||||
old_size = tensor.size(3)
|
||||
new_size = math.ceil(old_size/T) * T
|
||||
tensor = F.pad(tensor, [0, new_size - old_size])
|
||||
[b, c, t, f] = tensor.shape
|
||||
split = new_size // T
|
||||
return torch.cat(torch.split(tensor, T, dim=3), dim=0)
|
||||
|
||||
|
||||
class Estimator(nn.Module):
|
||||
def __init__(self, num_instrumments, checkpoint_path):
|
||||
super(Estimator, self).__init__()
|
||||
|
||||
# stft config
|
||||
self.F = 1024
|
||||
self.T = 512
|
||||
self.win_length = 4096
|
||||
self.hop_length = 1024
|
||||
self.win = torch.hann_window(self.win_length)
|
||||
|
||||
ckpts = tf2pytorch(checkpoint_path, num_instrumments)
|
||||
|
||||
# filter
|
||||
self.instruments = nn.ModuleList()
|
||||
for i in range(num_instrumments):
|
||||
print('Loading model for instrumment {}'.format(i))
|
||||
net = UNet(2)
|
||||
ckpt = ckpts[i]
|
||||
net = load_ckpt(net, ckpt)
|
||||
net.eval() # change mode to eval
|
||||
self.instruments.append(net)
|
||||
|
||||
def compute_stft(self, wav):
|
||||
"""
|
||||
Computes stft feature from wav
|
||||
|
||||
Args:
|
||||
wav (Tensor): B x L
|
||||
"""
|
||||
|
||||
stft = torch.stft(
|
||||
wav, self.win_length, hop_length=self.hop_length, window=self.win.to(wav.device))
|
||||
|
||||
# only keep freqs smaller than self.F
|
||||
stft = stft[:, :self.F, :, :]
|
||||
real = stft[:, :, :, 0]
|
||||
im = stft[:, :, :, 1]
|
||||
mag = torch.sqrt(real ** 2 + im ** 2)
|
||||
|
||||
return stft, mag
|
||||
|
||||
def inverse_stft(self, stft):
|
||||
"""Inverses stft to wave form"""
|
||||
|
||||
pad = self.win_length // 2 + 1 - stft.size(1)
|
||||
stft = F.pad(stft, (0, 0, 0, 0, 0, pad))
|
||||
wav = istft(stft, self.win_length, hop_length=self.hop_length,
|
||||
window=self.win.to(stft.device))
|
||||
return wav.detach()
|
||||
|
||||
def separate(self, wav):
|
||||
"""
|
||||
Separates stereo wav into different tracks corresponding to different instruments
|
||||
|
||||
Args:
|
||||
wav (tensor): B x L
|
||||
"""
|
||||
|
||||
# stft - B X F x L x 2
|
||||
# stft_mag - B X F x L
|
||||
stft, stft_mag = self.compute_stft(wav)
|
||||
|
||||
L = stft.size(2)
|
||||
|
||||
stft_mag = stft_mag.unsqueeze(1).repeat(1,2,1,1) # B x 2 x F x T
|
||||
stft_mag = pad_and_partition(stft_mag, self.T) # B x 2 x F x T
|
||||
stft_mag = stft_mag.transpose(2, 3) # B x 2 x T x F
|
||||
|
||||
# compute instruments' mask
|
||||
masks = []
|
||||
for net in self.instruments:
|
||||
mask = net(stft_mag)
|
||||
masks.append(mask)
|
||||
|
||||
# compute denominator
|
||||
mask_sum = sum([m ** 2 for m in masks])
|
||||
mask_sum += 1e-10
|
||||
|
||||
wavs = []
|
||||
for mask in masks:
|
||||
mask = (mask ** 2 + 1e-10/2)/(mask_sum)
|
||||
mask = mask.transpose(2, 3) # B x 2 X F x T
|
||||
|
||||
mask = torch.cat(
|
||||
torch.split(mask, 1, dim=0), dim=3)
|
||||
|
||||
mask = mask[:,0,:,:L].unsqueeze(-1) # 2 x F x L x 1
|
||||
stft_masked = stft * mask
|
||||
wav_masked = self.inverse_stft(stft_masked)
|
||||
|
||||
wavs.append(wav_masked)
|
||||
|
||||
return wavs
|
32
codes/models/spleeter/separator.py
Normal file
32
codes/models/spleeter/separator.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from models.spleeter.estimator import Estimator
|
||||
|
||||
|
||||
class Separator:
|
||||
def __init__(self, model_path, input_sr=44100, device='cuda'):
|
||||
self.model = Estimator(2, model_path).to(device)
|
||||
self.device = device
|
||||
self.input_sr = input_sr
|
||||
|
||||
def separate(self, npwav, normalize=False):
|
||||
if not isinstance(npwav, torch.Tensor):
|
||||
assert len(npwav.shape) == 1
|
||||
wav = torch.tensor(npwav, device=self.device)
|
||||
wav = wav.view(1,-1)
|
||||
else:
|
||||
assert len(npwav.shape) == 2 # Input should be BxL
|
||||
wav = npwav.to(self.device)
|
||||
|
||||
if normalize:
|
||||
wav = wav / (wav.max() + 1e-8)
|
||||
|
||||
# Spleeter expects audio input to be 44.1kHz.
|
||||
wav = F.interpolate(wav.unsqueeze(1), mode='nearest', scale_factor=44100/self.input_sr).squeeze(1)
|
||||
res = self.model.separate(wav)
|
||||
res = [F.interpolate(r.unsqueeze(1), mode='nearest', scale_factor=self.input_sr/44100)[:,0] for r in res]
|
||||
return {
|
||||
'vocals': res[0].cpu().numpy(),
|
||||
'accompaniment': res[1].cpu().numpy()
|
||||
}
|
80
codes/models/spleeter/unet.py
Normal file
80
codes/models/spleeter/unet.py
Normal file
|
@ -0,0 +1,80 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
def down_block(in_filters, out_filters):
|
||||
return nn.Conv2d(in_filters, out_filters, kernel_size=5,
|
||||
stride=2, padding=2,
|
||||
), nn.Sequential(
|
||||
nn.BatchNorm2d(out_filters, track_running_stats=True, eps=1e-3, momentum=0.01),
|
||||
nn.LeakyReLU(0.2)
|
||||
)
|
||||
|
||||
|
||||
def up_block(in_filters, out_filters, dropout=False):
|
||||
layers = [
|
||||
nn.ConvTranspose2d(in_filters, out_filters, kernel_size=5,
|
||||
stride=2, padding=2, output_padding=1
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.BatchNorm2d(out_filters, track_running_stats=True, eps=1e-3, momentum=0.01)
|
||||
]
|
||||
if dropout:
|
||||
layers.append(nn.Dropout(0.5))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
class UNet(nn.Module):
|
||||
def __init__(self, in_channels=2):
|
||||
super(UNet, self).__init__()
|
||||
self.down1_conv, self.down1_act = down_block(in_channels, 16)
|
||||
self.down2_conv, self.down2_act = down_block(16, 32)
|
||||
self.down3_conv, self.down3_act = down_block(32, 64)
|
||||
self.down4_conv, self.down4_act = down_block(64, 128)
|
||||
self.down5_conv, self.down5_act = down_block(128, 256)
|
||||
self.down6_conv, self.down6_act = down_block(256, 512)
|
||||
|
||||
self.up1 = up_block(512, 256, dropout=True)
|
||||
self.up2 = up_block(512, 128, dropout=True)
|
||||
self.up3 = up_block(256, 64, dropout=True)
|
||||
self.up4 = up_block(128, 32)
|
||||
self.up5 = up_block(64, 16)
|
||||
self.up6 = up_block(32, 1)
|
||||
self.up7 = nn.Sequential(
|
||||
nn.Conv2d(1, 2, kernel_size=4, dilation=2, padding=3),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
d1_conv = self.down1_conv(x)
|
||||
d1 = self.down1_act(d1_conv)
|
||||
|
||||
d2_conv = self.down2_conv(d1)
|
||||
d2 = self.down2_act(d2_conv)
|
||||
|
||||
d3_conv = self.down3_conv(d2)
|
||||
d3 = self.down3_act(d3_conv)
|
||||
|
||||
d4_conv = self.down4_conv(d3)
|
||||
d4 = self.down4_act(d4_conv)
|
||||
|
||||
d5_conv = self.down5_conv(d4)
|
||||
d5 = self.down5_act(d5_conv)
|
||||
|
||||
d6_conv = self.down6_conv(d5)
|
||||
d6 = self.down6_act(d6_conv)
|
||||
|
||||
u1 = self.up1(d6)
|
||||
u2 = self.up2(torch.cat([d5_conv, u1], axis=1))
|
||||
u3 = self.up3(torch.cat([d4_conv, u2], axis=1))
|
||||
u4 = self.up4(torch.cat([d3_conv, u3], axis=1))
|
||||
u5 = self.up5(torch.cat([d2_conv, u4], axis=1))
|
||||
u6 = self.up6(torch.cat([d1_conv, u5], axis=1))
|
||||
u7 = self.up7(u6)
|
||||
return u7 * x
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
net = UNet(14)
|
||||
print(net(torch.rand(1, 14, 20, 48)).shape)
|
91
codes/models/spleeter/util.py
Normal file
91
codes/models/spleeter/util.py
Normal file
|
@ -0,0 +1,91 @@
|
|||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from .unet import UNet
|
||||
|
||||
|
||||
def tf2pytorch(checkpoint_path, num_instrumments):
|
||||
tf_vars = {}
|
||||
init_vars = tf.train.list_variables(checkpoint_path)
|
||||
# print(init_vars)
|
||||
for name, shape in init_vars:
|
||||
try:
|
||||
# print('Loading TF Weight {} with shape {}'.format(name, shape))
|
||||
data = tf.train.load_variable(checkpoint_path, name)
|
||||
tf_vars[name] = data
|
||||
except Exception as e:
|
||||
print('Load error')
|
||||
conv_idx = 0
|
||||
tconv_idx = 0
|
||||
bn_idx = 0
|
||||
outputs = []
|
||||
for i in range(num_instrumments):
|
||||
output = {}
|
||||
outputs.append(output)
|
||||
|
||||
for j in range(1,7):
|
||||
if conv_idx == 0:
|
||||
conv_suffix = ""
|
||||
else:
|
||||
conv_suffix = "_" + str(conv_idx)
|
||||
|
||||
if bn_idx == 0:
|
||||
bn_suffix = ""
|
||||
else:
|
||||
bn_suffix = "_" + str(bn_idx)
|
||||
|
||||
output['down{}_conv.weight'.format(j)] = np.transpose(
|
||||
tf_vars["conv2d{}/kernel".format(conv_suffix)], (3, 2, 0, 1))
|
||||
# print('conv dtype: ',output['down{}.0.weight'.format(j)].dtype)
|
||||
output['down{}_conv.bias'.format(
|
||||
j)] = tf_vars["conv2d{}/bias".format(conv_suffix)]
|
||||
|
||||
output['down{}_act.0.weight'.format(
|
||||
j)] = tf_vars["batch_normalization{}/gamma".format(bn_suffix)]
|
||||
output['down{}_act.0.bias'.format(
|
||||
j)] = tf_vars["batch_normalization{}/beta".format(bn_suffix)]
|
||||
output['down{}_act.0.running_mean'.format(
|
||||
j)] = tf_vars['batch_normalization{}/moving_mean'.format(bn_suffix)]
|
||||
output['down{}_act.0.running_var'.format(
|
||||
j)] = tf_vars['batch_normalization{}/moving_variance'.format(bn_suffix)]
|
||||
|
||||
conv_idx += 1
|
||||
bn_idx += 1
|
||||
|
||||
# up blocks
|
||||
for j in range(1, 7):
|
||||
if tconv_idx == 0:
|
||||
tconv_suffix = ""
|
||||
else:
|
||||
tconv_suffix = "_" + str(tconv_idx)
|
||||
|
||||
if bn_idx == 0:
|
||||
bn_suffix = ""
|
||||
else:
|
||||
bn_suffix= "_" + str(bn_idx)
|
||||
|
||||
output['up{}.0.weight'.format(j)] = np.transpose(
|
||||
tf_vars["conv2d_transpose{}/kernel".format(tconv_suffix)], (3,2,0, 1))
|
||||
output['up{}.0.bias'.format(
|
||||
j)] = tf_vars["conv2d_transpose{}/bias".format(tconv_suffix)]
|
||||
output['up{}.2.weight'.format(
|
||||
j)] = tf_vars["batch_normalization{}/gamma".format(bn_suffix)]
|
||||
output['up{}.2.bias'.format(
|
||||
j)] = tf_vars["batch_normalization{}/beta".format(bn_suffix)]
|
||||
output['up{}.2.running_mean'.format(
|
||||
j)] = tf_vars['batch_normalization{}/moving_mean'.format(bn_suffix)]
|
||||
output['up{}.2.running_var'.format(
|
||||
j)] = tf_vars['batch_normalization{}/moving_variance'.format(bn_suffix)]
|
||||
tconv_idx += 1
|
||||
bn_idx += 1
|
||||
|
||||
if conv_idx == 0:
|
||||
suffix = ""
|
||||
else:
|
||||
suffix = "_" + str(conv_idx)
|
||||
output['up7.0.weight'] = np.transpose(
|
||||
tf_vars['conv2d{}/kernel'.format(suffix)], (3, 2, 0, 1))
|
||||
output['up7.0.bias'] = tf_vars['conv2d{}/bias'.format(suffix)]
|
||||
conv_idx += 1
|
||||
|
||||
return outputs
|
37
codes/scripts/audio/preparation/spleeter_dataset.py
Normal file
37
codes/scripts/audio/preparation/spleeter_dataset.py
Normal file
|
@ -0,0 +1,37 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from spleeter.audio.adapter import AudioAdapter
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from data.util import find_audio_files
|
||||
|
||||
|
||||
class SpleeterDataset(Dataset):
|
||||
def __init__(self, src_dir, sample_rate=22050, max_duration=20):
|
||||
self.files = find_audio_files(src_dir, include_nonwav=True)
|
||||
self.audio_loader = AudioAdapter.default()
|
||||
self.sample_rate = sample_rate
|
||||
self.max_duration = max_duration
|
||||
|
||||
def __getitem__(self, item):
|
||||
file = self.files[item]
|
||||
try:
|
||||
wave, sample_rate = self.audio_loader.load(file, sample_rate=self.sample_rate)
|
||||
assert sample_rate == self.sample_rate
|
||||
wave = wave[:,0] # strip off channels
|
||||
wave = torch.tensor(wave)
|
||||
except:
|
||||
wave = torch.zeros(self.sample_rate * self.max_duration)
|
||||
print(f"Error with {file}")
|
||||
original_duration = wave.shape[0]
|
||||
padding_needed = self.sample_rate * self.max_duration - original_duration
|
||||
if padding_needed > 0:
|
||||
wave = nn.functional.pad(wave, (0, padding_needed))
|
||||
return {
|
||||
'path': file,
|
||||
'wave': wave,
|
||||
'duration': original_duration,
|
||||
}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.files)
|
|
@ -0,0 +1,67 @@
|
|||
from scipy.io import wavfile
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from scipy.io import wavfile
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from models.spleeter.separator import Separator
|
||||
from scripts.audio.preparation.spleeter_dataset import SpleeterDataset
|
||||
|
||||
|
||||
def main():
|
||||
src_dir = 'F:\\split\\podcast-dump0'
|
||||
output_dir = 'F:\\tmp\\out'
|
||||
output_dir_bg = 'F:\\tmp\\bg'
|
||||
output_sample_rate=22050
|
||||
batch_size=24
|
||||
|
||||
dl = DataLoader(SpleeterDataset(src_dir, output_sample_rate), batch_size=batch_size, shuffle=False, num_workers=1, pin_memory=True)
|
||||
separator = Separator('pretrained_models/2stems', input_sr=output_sample_rate)
|
||||
for e, batch in enumerate(tqdm(dl)):
|
||||
#if e < 406500:
|
||||
# continue
|
||||
waves = batch['wave']
|
||||
paths = batch['path']
|
||||
durations = batch['duration']
|
||||
|
||||
sep = separator.separate(waves)
|
||||
for j in range(sep['vocals'].shape[0]):
|
||||
vocals = sep['vocals'][j]
|
||||
bg = sep['accompaniment'][j]
|
||||
vmax = np.abs(vocals).mean()
|
||||
bmax = np.abs(bg).mean()
|
||||
|
||||
# Only output to the "good" sample dir if the ratio of background noise to vocal noise is high enough.
|
||||
ratio = vmax / (bmax+.0000001)
|
||||
if ratio >= 25: # These values were derived empirically
|
||||
od = output_dir
|
||||
out_sound = waves[j].cpu().numpy()
|
||||
elif ratio <= 1:
|
||||
od = output_dir_bg
|
||||
out_sound = bg
|
||||
else:
|
||||
continue
|
||||
|
||||
# Strip out channels.
|
||||
if len(out_sound.shape) > 1:
|
||||
out_sound = out_sound[:, 0] # Just use the first channel.
|
||||
# Resize to true duration
|
||||
out_sound = out_sound[:durations[j]]
|
||||
|
||||
# Compile an output path.
|
||||
path = paths[j]
|
||||
reld = os.path.relpath(os.path.dirname(path), src_dir)
|
||||
os.makedirs(os.path.join(od, reld), exist_ok=True)
|
||||
relp = os.path.relpath(path, src_dir)
|
||||
output_path = os.path.join(od, relp)
|
||||
|
||||
wavfile.write(output_path, output_sample_rate, out_sound)
|
||||
|
||||
|
||||
# Uses torch spleeter to divide audio clips into one of two bins:
|
||||
# 1. Audio has little to no background noise, saved to "output_dir"
|
||||
# 2. Audio has a lot of background noise, bg noise split off and saved to "output_dir_bg"
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -25,8 +25,8 @@ if __name__ == '__main__':
|
|||
separator = Separator('spleeter:2stems')
|
||||
files = find_audio_files(src_dir, include_nonwav=True)
|
||||
for e, file in enumerate(tqdm(files)):
|
||||
if e < 406500:
|
||||
continue
|
||||
#if e < 406500:
|
||||
# continue
|
||||
file_basis = osp.relpath(file, src_dir)\
|
||||
.replace('/', '_')\
|
||||
.replace('\\', '_')\
|
||||
|
|
Loading…
Reference in New Issue
Block a user