Batch spleeter cleaner using GPU

James Betker 2021-09-09 23:13:40 -06:00
parent 73b930c0f6
commit 742f9b4010
7 changed files with 446 additions and 2 deletions


@@ -0,0 +1,137 @@
import math
import torch
import torch.nn.functional as F
from torch import nn
from torch import istft
from .unet import UNet
from .util import tf2pytorch
def load_ckpt(model, ckpt):
state_dict = model.state_dict()
for k, v in ckpt.items():
if k in state_dict:
target_shape = state_dict[k].shape
            assert target_shape == v.shape, f'Shape mismatch for {k}: expected {target_shape}, got {v.shape}'
state_dict.update({k: torch.from_numpy(v)})
else:
            print('Ignoring unmatched checkpoint key:', k)
model.load_state_dict(state_dict)
return model
def pad_and_partition(tensor, T):
"""
pads zero and partition tensor into segments of length T
Args:
tensor(Tensor): BxCxFxL
Returns:
tensor of size (B*[L/T] x C x F x T)
"""
old_size = tensor.size(3)
new_size = math.ceil(old_size/T) * T
tensor = F.pad(tensor, [0, new_size - old_size])
return torch.cat(torch.split(tensor, T, dim=3), dim=0)
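
# For example: a (1, 2, 1024, 700) input with T=512 is zero-padded to length
# 1024 and split into two time chunks, yielding a (2, 2, 1024, 512) tensor.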
class Estimator(nn.Module):
    def __init__(self, num_instruments, checkpoint_path):
super(Estimator, self).__init__()
# stft config
self.F = 1024
self.T = 512
self.win_length = 4096
self.hop_length = 1024
self.win = torch.hann_window(self.win_length)
        ckpts = tf2pytorch(checkpoint_path, num_instruments)
# filter
self.instruments = nn.ModuleList()
        for i in range(num_instruments):
            print('Loading model for instrument {}'.format(i))
net = UNet(2)
ckpt = ckpts[i]
net = load_ckpt(net, ckpt)
net.eval() # change mode to eval
self.instruments.append(net)
def compute_stft(self, wav):
"""
Computes stft feature from wav
Args:
wav (Tensor): B x L
"""
stft = torch.stft(
wav, self.win_length, hop_length=self.hop_length, window=self.win.to(wav.device))
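        # Note: this relies on the pre-1.8 torch.stft behavior (no return_complex),
        # which returns a real tensor of shape B x (win_length // 2 + 1) x frames x 2.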
# only keep freqs smaller than self.F
stft = stft[:, :self.F, :, :]
real = stft[:, :, :, 0]
im = stft[:, :, :, 1]
mag = torch.sqrt(real ** 2 + im ** 2)
return stft, mag
def inverse_stft(self, stft):
"""Inverses stft to wave form"""
pad = self.win_length // 2 + 1 - stft.size(1)
stft = F.pad(stft, (0, 0, 0, 0, 0, pad))
wav = istft(stft, self.win_length, hop_length=self.hop_length,
window=self.win.to(stft.device))
return wav.detach()
def separate(self, wav):
"""
Separates stereo wav into different tracks corresponding to different instruments
Args:
wav (tensor): B x L
"""
# stft - B X F x L x 2
# stft_mag - B X F x L
stft, stft_mag = self.compute_stft(wav)
L = stft.size(2)
        stft_mag = stft_mag.unsqueeze(1).repeat(1, 2, 1, 1)  # B x 2 x F x L
        stft_mag = pad_and_partition(stft_mag, self.T)  # (B*chunks) x 2 x F x T
        stft_mag = stft_mag.transpose(2, 3)  # (B*chunks) x 2 x T x F
# compute instruments' mask
masks = []
for net in self.instruments:
mask = net(stft_mag)
masks.append(mask)
# compute denominator
mask_sum = sum([m ** 2 for m in masks])
mask_sum += 1e-10
wavs = []
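        # Each instrument's soft mask is its squared magnitude divided by the sum
        # over instruments (a Wiener-filter-style ratio mask); the epsilon guards
        # against division by zero in silent regions.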
for mask in masks:
            mask = (mask ** 2 + 1e-10 / 2) / mask_sum
            mask = mask.transpose(2, 3)  # (B*chunks) x 2 x F x T
            # pad_and_partition stacked the time chunks chunk-major along the batch
            # dim, so reassemble by splitting into groups of the true batch size.
            b = stft.size(0)
            mask = torch.cat(torch.split(mask, b, dim=0), dim=3)  # B x 2 x F x (chunks*T)
            mask = mask[:, 0, :, :L].unsqueeze(-1)  # B x F x L x 1
stft_masked = stft * mask
wav_masked = self.inverse_stft(stft_masked)
wavs.append(wav_masked)
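        # Note: each reconstructed track is B x L', where L' can differ slightly
        # from the input length due to STFT framing.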
return wavs


@@ -0,0 +1,32 @@
import torch
import torch.nn.functional as F
from models.spleeter.estimator import Estimator
class Separator:
def __init__(self, model_path, input_sr=44100, device='cuda'):
self.model = Estimator(2, model_path).to(device)
self.device = device
self.input_sr = input_sr
def separate(self, npwav, normalize=False):
if not isinstance(npwav, torch.Tensor):
assert len(npwav.shape) == 1
wav = torch.tensor(npwav, device=self.device)
wav = wav.view(1,-1)
else:
assert len(npwav.shape) == 2 # Input should be BxL
wav = npwav.to(self.device)
if normalize:
wav = wav / (wav.max() + 1e-8)
# Spleeter expects audio input to be 44.1kHz.
wav = F.interpolate(wav.unsqueeze(1), mode='nearest', scale_factor=44100/self.input_sr).squeeze(1)
res = self.model.separate(wav)
res = [F.interpolate(r.unsqueeze(1), mode='nearest', scale_factor=self.input_sr/44100)[:,0] for r in res]
return {
'vocals': res[0].cpu().numpy(),
'accompaniment': res[1].cpu().numpy()
}
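
# Minimal usage sketch (checkpoint path is illustrative; see the batch script below):
#   separator = Separator('pretrained_models/2stems', input_sr=22050)
#   out = separator.separate(wav)  # wav: 1-D numpy array or B x L tensor
#   vocals, bg = out['vocals'], out['accompaniment']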


@@ -0,0 +1,80 @@
import torch
from torch import nn
def down_block(in_filters, out_filters):
return nn.Conv2d(in_filters, out_filters, kernel_size=5,
stride=2, padding=2,
), nn.Sequential(
nn.BatchNorm2d(out_filters, track_running_stats=True, eps=1e-3, momentum=0.01),
nn.LeakyReLU(0.2)
)
def up_block(in_filters, out_filters, dropout=False):
layers = [
nn.ConvTranspose2d(in_filters, out_filters, kernel_size=5,
stride=2, padding=2, output_padding=1
),
nn.ReLU(),
nn.BatchNorm2d(out_filters, track_running_stats=True, eps=1e-3, momentum=0.01)
]
if dropout:
layers.append(nn.Dropout(0.5))
return nn.Sequential(*layers)
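
# Note: ReLU comes before BatchNorm in up_block (Sequential indices 0=deconv,
# 1=ReLU, 2=BN); the key mapping in util.tf2pytorch ('up{}.0.*', 'up{}.2.*')
# depends on this ordering.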
class UNet(nn.Module):
def __init__(self, in_channels=2):
super(UNet, self).__init__()
self.down1_conv, self.down1_act = down_block(in_channels, 16)
self.down2_conv, self.down2_act = down_block(16, 32)
self.down3_conv, self.down3_act = down_block(32, 64)
self.down4_conv, self.down4_act = down_block(64, 128)
self.down5_conv, self.down5_act = down_block(128, 256)
self.down6_conv, self.down6_act = down_block(256, 512)
self.up1 = up_block(512, 256, dropout=True)
self.up2 = up_block(512, 128, dropout=True)
self.up3 = up_block(256, 64, dropout=True)
self.up4 = up_block(128, 32)
self.up5 = up_block(64, 16)
self.up6 = up_block(32, 1)
self.up7 = nn.Sequential(
nn.Conv2d(1, 2, kernel_size=4, dilation=2, padding=3),
nn.Sigmoid()
)
def forward(self, x):
d1_conv = self.down1_conv(x)
d1 = self.down1_act(d1_conv)
d2_conv = self.down2_conv(d1)
d2 = self.down2_act(d2_conv)
d3_conv = self.down3_conv(d2)
d3 = self.down3_act(d3_conv)
d4_conv = self.down4_conv(d3)
d4 = self.down4_act(d4_conv)
d5_conv = self.down5_conv(d4)
d5 = self.down5_act(d5_conv)
d6_conv = self.down6_conv(d5)
d6 = self.down6_act(d6_conv)
u1 = self.up1(d6)
        u2 = self.up2(torch.cat([d5_conv, u1], dim=1))
        u3 = self.up3(torch.cat([d4_conv, u2], dim=1))
        u4 = self.up4(torch.cat([d3_conv, u3], dim=1))
        u5 = self.up5(torch.cat([d2_conv, u4], dim=1))
        u6 = self.up6(torch.cat([d1_conv, u5], dim=1))
u7 = self.up7(u6)
return u7 * x
if __name__ == '__main__':
    # Smoke test. Input spatial dims must be divisible by 2**6 to survive the
    # six down/up levels, and the final u7 * x product requires 2 input channels.
    net = UNet(2)
    print(net(torch.rand(1, 2, 64, 64)).shape)


@@ -0,0 +1,91 @@
import numpy as np
import tensorflow as tf
from .unet import UNet
def tf2pytorch(checkpoint_path, num_instruments):
tf_vars = {}
init_vars = tf.train.list_variables(checkpoint_path)
# print(init_vars)
for name, shape in init_vars:
try:
# print('Loading TF Weight {} with shape {}'.format(name, shape))
data = tf.train.load_variable(checkpoint_path, name)
tf_vars[name] = data
        except Exception as e:
            print('Failed to load TF variable {}: {}'.format(name, e))
conv_idx = 0
tconv_idx = 0
bn_idx = 0
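    # TF/Keras auto-numbers layer variables globally across the graph
    # (conv2d, conv2d_1, ..., batch_normalization, batch_normalization_1, ...),
    # so these counters keep incrementing across instruments and blocks.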
outputs = []
    for i in range(num_instruments):
output = {}
outputs.append(output)
for j in range(1,7):
if conv_idx == 0:
conv_suffix = ""
else:
conv_suffix = "_" + str(conv_idx)
if bn_idx == 0:
bn_suffix = ""
else:
bn_suffix = "_" + str(bn_idx)
output['down{}_conv.weight'.format(j)] = np.transpose(
tf_vars["conv2d{}/kernel".format(conv_suffix)], (3, 2, 0, 1))
# print('conv dtype: ',output['down{}.0.weight'.format(j)].dtype)
output['down{}_conv.bias'.format(
j)] = tf_vars["conv2d{}/bias".format(conv_suffix)]
output['down{}_act.0.weight'.format(
j)] = tf_vars["batch_normalization{}/gamma".format(bn_suffix)]
output['down{}_act.0.bias'.format(
j)] = tf_vars["batch_normalization{}/beta".format(bn_suffix)]
output['down{}_act.0.running_mean'.format(
j)] = tf_vars['batch_normalization{}/moving_mean'.format(bn_suffix)]
output['down{}_act.0.running_var'.format(
j)] = tf_vars['batch_normalization{}/moving_variance'.format(bn_suffix)]
conv_idx += 1
bn_idx += 1
# up blocks
for j in range(1, 7):
if tconv_idx == 0:
tconv_suffix = ""
else:
tconv_suffix = "_" + str(tconv_idx)
if bn_idx == 0:
bn_suffix = ""
else:
bn_suffix= "_" + str(bn_idx)
        output['up{}.0.weight'.format(j)] = np.transpose(
            tf_vars["conv2d_transpose{}/kernel".format(tconv_suffix)], (3, 2, 0, 1))
output['up{}.0.bias'.format(
j)] = tf_vars["conv2d_transpose{}/bias".format(tconv_suffix)]
output['up{}.2.weight'.format(
j)] = tf_vars["batch_normalization{}/gamma".format(bn_suffix)]
output['up{}.2.bias'.format(
j)] = tf_vars["batch_normalization{}/beta".format(bn_suffix)]
output['up{}.2.running_mean'.format(
j)] = tf_vars['batch_normalization{}/moving_mean'.format(bn_suffix)]
output['up{}.2.running_var'.format(
j)] = tf_vars['batch_normalization{}/moving_variance'.format(bn_suffix)]
tconv_idx += 1
bn_idx += 1
if conv_idx == 0:
suffix = ""
else:
suffix = "_" + str(conv_idx)
output['up7.0.weight'] = np.transpose(
tf_vars['conv2d{}/kernel'.format(suffix)], (3, 2, 0, 1))
output['up7.0.bias'] = tf_vars['conv2d{}/bias'.format(suffix)]
conv_idx += 1
return outputs


@@ -0,0 +1,37 @@
import torch
import torch.nn as nn
from spleeter.audio.adapter import AudioAdapter
from torch.utils.data import Dataset
from data.util import find_audio_files
class SpleeterDataset(Dataset):
def __init__(self, src_dir, sample_rate=22050, max_duration=20):
self.files = find_audio_files(src_dir, include_nonwav=True)
self.audio_loader = AudioAdapter.default()
self.sample_rate = sample_rate
self.max_duration = max_duration
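
    # Clips are zero-padded to max_duration seconds in __getitem__ so that
    # variable-length files can be collated into fixed-size batches; the true
    # length is returned under 'duration' so callers can trim the output.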
def __getitem__(self, item):
file = self.files[item]
try:
wave, sample_rate = self.audio_loader.load(file, sample_rate=self.sample_rate)
assert sample_rate == self.sample_rate
            wave = wave[:, 0]  # keep only the first channel
wave = torch.tensor(wave)
        except Exception as e:
            print(f"Error loading {file}: {e}")
            wave = torch.zeros(self.sample_rate * self.max_duration)
original_duration = wave.shape[0]
padding_needed = self.sample_rate * self.max_duration - original_duration
if padding_needed > 0:
wave = nn.functional.pad(wave, (0, padding_needed))
return {
'path': file,
'wave': wave,
'duration': original_duration,
}
def __len__(self):
return len(self.files)


@@ -0,0 +1,67 @@
# Uses torch spleeter to divide audio clips into one of two bins:
# 1. Audio has little to no background noise, saved to "output_dir"
# 2. Audio has a lot of background noise; the background is split off and saved to "output_dir_bg"
import os

import numpy as np
from scipy.io import wavfile
from torch.utils.data import DataLoader
from tqdm import tqdm
from models.spleeter.separator import Separator
from scripts.audio.preparation.spleeter_dataset import SpleeterDataset
def main():
src_dir = 'F:\\split\\podcast-dump0'
output_dir = 'F:\\tmp\\out'
output_dir_bg = 'F:\\tmp\\bg'
    output_sample_rate = 22050
    batch_size = 24
dl = DataLoader(SpleeterDataset(src_dir, output_sample_rate), batch_size=batch_size, shuffle=False, num_workers=1, pin_memory=True)
separator = Separator('pretrained_models/2stems', input_sr=output_sample_rate)
for e, batch in enumerate(tqdm(dl)):
#if e < 406500:
# continue
waves = batch['wave']
paths = batch['path']
durations = batch['duration']
sep = separator.separate(waves)
for j in range(sep['vocals'].shape[0]):
vocals = sep['vocals'][j]
bg = sep['accompaniment'][j]
            vocal_energy = np.abs(vocals).mean()
            bg_energy = np.abs(bg).mean()
            # Only output to the "good" sample dir if the ratio of vocal energy to
            # background energy is high enough.
            ratio = vocal_energy / (bg_energy + 1e-7)
            if ratio >= 25:  # These thresholds were derived empirically.
od = output_dir
out_sound = waves[j].cpu().numpy()
elif ratio <= 1:
od = output_dir_bg
out_sound = bg
else:
continue
# Strip out channels.
if len(out_sound.shape) > 1:
out_sound = out_sound[:, 0] # Just use the first channel.
# Resize to true duration
out_sound = out_sound[:durations[j]]
# Compile an output path.
path = paths[j]
reld = os.path.relpath(os.path.dirname(path), src_dir)
os.makedirs(os.path.join(od, reld), exist_ok=True)
relp = os.path.relpath(path, src_dir)
output_path = os.path.join(od, relp)
wavfile.write(output_path, output_sample_rate, out_sound)
if __name__ == '__main__':
main()


@@ -25,8 +25,8 @@ if __name__ == '__main__':
separator = Separator('spleeter:2stems')
files = find_audio_files(src_dir, include_nonwav=True)
for e, file in enumerate(tqdm(files)):
if e < 406500:
continue
#if e < 406500:
# continue
file_basis = osp.relpath(file, src_dir)\
.replace('/', '_')\
.replace('\\', '_')\