From 742f9b4010ce3068f0e587268944ceaeeae1cce5 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 9 Sep 2021 23:13:40 -0600
Subject: [PATCH] Batch spleeter cleaner using GPU

---
 codes/models/spleeter/estimator.py            | 137 ++++++++++++++++++
 codes/models/spleeter/separator.py            |  32 ++++
 codes/models/spleeter/unet.py                 |  80 ++++++++++
 codes/models/spleeter/util.py                 |  91 ++++++++++++
 .../audio/preparation/spleeter_dataset.py     |  37 +++++
 .../spleeter_split_voice_and_background_2.py  |  67 +++++++++
 .../spleeter_split_voice_and_background.py    |   4 +-
 7 files changed, 446 insertions(+), 2 deletions(-)
 create mode 100644 codes/models/spleeter/estimator.py
 create mode 100644 codes/models/spleeter/separator.py
 create mode 100644 codes/models/spleeter/unet.py
 create mode 100644 codes/models/spleeter/util.py
 create mode 100644 codes/scripts/audio/preparation/spleeter_dataset.py
 create mode 100644 codes/scripts/audio/preparation/spleeter_split_voice_and_background_2.py

diff --git a/codes/models/spleeter/estimator.py b/codes/models/spleeter/estimator.py
new file mode 100644
index 00000000..207c4989
--- /dev/null
+++ b/codes/models/spleeter/estimator.py
@@ -0,0 +1,137 @@
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch import istft
+
+from .unet import UNet
+from .util import tf2pytorch
+
+
+def load_ckpt(model, ckpt):
+    state_dict = model.state_dict()
+    for k, v in ckpt.items():
+        if k in state_dict:
+            target_shape = state_dict[k].shape
+            assert target_shape == v.shape
+            state_dict.update({k: torch.from_numpy(v)})
+        else:
+            print('Ignore ', k)
+
+    model.load_state_dict(state_dict)
+    return model
+
+
+def pad_and_partition(tensor, T):
+    """
+    pads zero and partition tensor into segments of length T
+
+    Args:
+        tensor(Tensor): BxCxFxL
+
+    Returns:
+        tensor of size (B*[L/T] x C x F x T)
+    """
+    old_size = tensor.size(3)
+    new_size = math.ceil(old_size/T) * T
+    tensor = F.pad(tensor, [0, new_size - old_size])
+    [b, c, t, f] = tensor.shape
+    split = new_size // T
+    return torch.cat(torch.split(tensor, T, dim=3), dim=0)
+
+
+class Estimator(nn.Module):
+    def __init__(self, num_instrumments, checkpoint_path):
+        super(Estimator, self).__init__()
+
+        # stft config
+        self.F = 1024
+        self.T = 512
+        self.win_length = 4096
+        self.hop_length = 1024
+        self.win = torch.hann_window(self.win_length)
+
+        ckpts = tf2pytorch(checkpoint_path, num_instrumments)
+
+        # filter
+        self.instruments = nn.ModuleList()
+        for i in range(num_instrumments):
+            print('Loading model for instrumment {}'.format(i))
+            net = UNet(2)
+            ckpt = ckpts[i]
+            net = load_ckpt(net, ckpt)
+            net.eval()  # change mode to eval
+            self.instruments.append(net)
+
+    def compute_stft(self, wav):
+        """
+        Computes stft feature from wav
+
+        Args:
+            wav (Tensor): B x L
+        """
+
+        stft = torch.stft(
+            wav, self.win_length, hop_length=self.hop_length, window=self.win.to(wav.device))
+
+        # only keep freqs smaller than self.F
+        stft = stft[:, :self.F, :, :]
+        real = stft[:, :, :, 0]
+        im = stft[:, :, :, 1]
+        mag = torch.sqrt(real ** 2 + im ** 2)
+
+        return stft, mag
+
+    def inverse_stft(self, stft):
+        """Inverses stft to wave form"""
+
+        pad = self.win_length // 2 + 1 - stft.size(1)
+        stft = F.pad(stft, (0, 0, 0, 0, 0, pad))
+        wav = istft(stft, self.win_length, hop_length=self.hop_length,
+                    window=self.win.to(stft.device))
+        return wav.detach()
+
+    def separate(self, wav):
+        """
+        Separates stereo wav into different tracks corresponding to different instruments
+
+        Args:
+            wav (tensor): B x L
+        """
+
+        # stft - B X F x L x 2
+        # stft_mag - B X F x L
+        stft, stft_mag = self.compute_stft(wav)
+
+        L = stft.size(2)
+
+        stft_mag = stft_mag.unsqueeze(1).repeat(1,2,1,1)  # B x 2 x F x T
+        stft_mag = pad_and_partition(stft_mag, self.T)  # B x 2 x F x T
+        stft_mag = stft_mag.transpose(2, 3)  # B x 2 x T x F
+
+        # compute instruments' mask
+        masks = []
+        for net in self.instruments:
+            mask = net(stft_mag)
+            masks.append(mask)
+
+        # compute denominator
+        mask_sum = sum([m ** 2 for m in masks])
+        mask_sum += 1e-10
+
+        wavs = []
+        for mask in masks:
+            mask = (mask ** 2 + 1e-10/2)/(mask_sum)
+            mask = mask.transpose(2, 3)  # B x 2 X F x T
+
+            mask = torch.cat(
+                torch.split(mask, 1, dim=0), dim=3)
+
+            mask = mask[:,0,:,:L].unsqueeze(-1)  # 2 x F x L x 1
+            stft_masked = stft * mask
+            wav_masked = self.inverse_stft(stft_masked)
+
+            wavs.append(wav_masked)
+
+        return wavs
\ No newline at end of file
diff --git a/codes/models/spleeter/separator.py b/codes/models/spleeter/separator.py
new file mode 100644
index 00000000..ceafd17f
--- /dev/null
+++ b/codes/models/spleeter/separator.py
@@ -0,0 +1,32 @@
+import torch
+import torch.nn.functional as F
+
+from models.spleeter.estimator import Estimator
+
+
+class Separator:
+    def __init__(self, model_path, input_sr=44100, device='cuda'):
+        self.model = Estimator(2, model_path).to(device)
+        self.device = device
+        self.input_sr = input_sr
+
+    def separate(self, npwav, normalize=False):
+        if not isinstance(npwav, torch.Tensor):
+            assert len(npwav.shape) == 1
+            wav = torch.tensor(npwav, device=self.device)
+            wav = wav.view(1,-1)
+        else:
+            assert len(npwav.shape) == 2  # Input should be BxL
+            wav = npwav.to(self.device)
+
+        if normalize:
+            wav = wav / (wav.max() + 1e-8)
+
+        # Spleeter expects audio input to be 44.1kHz.
+        wav = F.interpolate(wav.unsqueeze(1), mode='nearest', scale_factor=44100/self.input_sr).squeeze(1)
+        res = self.model.separate(wav)
+        res = [F.interpolate(r.unsqueeze(1), mode='nearest', scale_factor=self.input_sr/44100)[:,0] for r in res]
+        return {
+            'vocals': res[0].cpu().numpy(),
+            'accompaniment': res[1].cpu().numpy()
+        }
\ No newline at end of file
diff --git a/codes/models/spleeter/unet.py b/codes/models/spleeter/unet.py
new file mode 100644
index 00000000..21c36805
--- /dev/null
+++ b/codes/models/spleeter/unet.py
@@ -0,0 +1,80 @@
+import torch
+from torch import nn
+
+
+def down_block(in_filters, out_filters):
+    return nn.Conv2d(in_filters, out_filters, kernel_size=5,
+                     stride=2, padding=2,
+                     ), nn.Sequential(
+        nn.BatchNorm2d(out_filters, track_running_stats=True, eps=1e-3, momentum=0.01),
+        nn.LeakyReLU(0.2)
+    )
+
+
+def up_block(in_filters, out_filters, dropout=False):
+    layers = [
+        nn.ConvTranspose2d(in_filters, out_filters, kernel_size=5,
+                           stride=2, padding=2, output_padding=1
+                           ),
+        nn.ReLU(),
+        nn.BatchNorm2d(out_filters, track_running_stats=True, eps=1e-3, momentum=0.01)
+    ]
+    if dropout:
+        layers.append(nn.Dropout(0.5))
+
+    return nn.Sequential(*layers)
+
+
+class UNet(nn.Module):
+    def __init__(self, in_channels=2):
+        super(UNet, self).__init__()
+        self.down1_conv, self.down1_act = down_block(in_channels, 16)
+        self.down2_conv, self.down2_act = down_block(16, 32)
+        self.down3_conv, self.down3_act = down_block(32, 64)
+        self.down4_conv, self.down4_act = down_block(64, 128)
+        self.down5_conv, self.down5_act = down_block(128, 256)
+        self.down6_conv, self.down6_act = down_block(256, 512)
+
+        self.up1 = up_block(512, 256, dropout=True)
+        self.up2 = up_block(512, 128, dropout=True)
+        self.up3 = up_block(256, 64, dropout=True)
+        self.up4 = up_block(128, 32)
+        self.up5 = up_block(64, 16)
+        self.up6 = up_block(32, 1)
+        self.up7 = nn.Sequential(
+            nn.Conv2d(1, 2, kernel_size=4, dilation=2, padding=3),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        d1_conv = self.down1_conv(x)
+        d1 = self.down1_act(d1_conv)
+
+        d2_conv = self.down2_conv(d1)
+        d2 = self.down2_act(d2_conv)
+
+        d3_conv = self.down3_conv(d2)
+        d3 = self.down3_act(d3_conv)
+
+        d4_conv = self.down4_conv(d3)
+        d4 = self.down4_act(d4_conv)
+
+        d5_conv = self.down5_conv(d4)
+        d5 = self.down5_act(d5_conv)
+
+        d6_conv = self.down6_conv(d5)
+        d6 = self.down6_act(d6_conv)
+
+        u1 = self.up1(d6)
+        u2 = self.up2(torch.cat([d5_conv, u1], axis=1))
+        u3 = self.up3(torch.cat([d4_conv, u2], axis=1))
+        u4 = self.up4(torch.cat([d3_conv, u3], axis=1))
+        u5 = self.up5(torch.cat([d2_conv, u4], axis=1))
+        u6 = self.up6(torch.cat([d1_conv, u5], axis=1))
+        u7 = self.up7(u6)
+        return u7 * x
+
+
+if __name__ == '__main__':
+    net = UNet(14)
+    print(net(torch.rand(1, 14, 20, 48)).shape)
\ No newline at end of file
diff --git a/codes/models/spleeter/util.py b/codes/models/spleeter/util.py
new file mode 100644
index 00000000..036f5a74
--- /dev/null
+++ b/codes/models/spleeter/util.py
@@ -0,0 +1,91 @@
+import numpy as np
+import tensorflow as tf
+
+from .unet import UNet
+
+
+def tf2pytorch(checkpoint_path, num_instrumments):
+    tf_vars = {}
+    init_vars = tf.train.list_variables(checkpoint_path)
+    # print(init_vars)
+    for name, shape in init_vars:
+        try:
+            # print('Loading TF Weight {} with shape {}'.format(name, shape))
+            data = tf.train.load_variable(checkpoint_path, name)
+            tf_vars[name] = data
+        except Exception as e:
+            print('Load error')
+    conv_idx = 0
+    tconv_idx = 0
+    bn_idx = 0
+    outputs = []
+    for i in range(num_instrumments):
+        output = {}
+        outputs.append(output)
+
+        for j in range(1,7):
+            if conv_idx == 0:
+                conv_suffix = ""
+            else:
+                conv_suffix = "_" + str(conv_idx)
+
+            if bn_idx == 0:
+                bn_suffix = ""
+            else:
+                bn_suffix = "_" + str(bn_idx)
+
+            output['down{}_conv.weight'.format(j)] = np.transpose(
+                tf_vars["conv2d{}/kernel".format(conv_suffix)], (3, 2, 0, 1))
+            # print('conv dtype: ',output['down{}.0.weight'.format(j)].dtype)
+            output['down{}_conv.bias'.format(
+                j)] = tf_vars["conv2d{}/bias".format(conv_suffix)]
+
+            output['down{}_act.0.weight'.format(
+                j)] = tf_vars["batch_normalization{}/gamma".format(bn_suffix)]
+            output['down{}_act.0.bias'.format(
+                j)] = tf_vars["batch_normalization{}/beta".format(bn_suffix)]
+            output['down{}_act.0.running_mean'.format(
+                j)] = tf_vars['batch_normalization{}/moving_mean'.format(bn_suffix)]
+            output['down{}_act.0.running_var'.format(
+                j)] = tf_vars['batch_normalization{}/moving_variance'.format(bn_suffix)]
+
+            conv_idx += 1
+            bn_idx += 1
+
+        # up blocks
+        for j in range(1, 7):
+            if tconv_idx == 0:
+                tconv_suffix = ""
+            else:
+                tconv_suffix = "_" + str(tconv_idx)
+
+            if bn_idx == 0:
+                bn_suffix = ""
+            else:
+                bn_suffix= "_" + str(bn_idx)
+
+            output['up{}.0.weight'.format(j)] = np.transpose(
+                tf_vars["conv2d_transpose{}/kernel".format(tconv_suffix)], (3,2,0, 1))
+            output['up{}.0.bias'.format(
+                j)] = tf_vars["conv2d_transpose{}/bias".format(tconv_suffix)]
+            output['up{}.2.weight'.format(
+                j)] = tf_vars["batch_normalization{}/gamma".format(bn_suffix)]
+            output['up{}.2.bias'.format(
+                j)] = tf_vars["batch_normalization{}/beta".format(bn_suffix)]
+            output['up{}.2.running_mean'.format(
+                j)] = tf_vars['batch_normalization{}/moving_mean'.format(bn_suffix)]
+            output['up{}.2.running_var'.format(
+                j)] = tf_vars['batch_normalization{}/moving_variance'.format(bn_suffix)]
+            tconv_idx += 1
+            bn_idx += 1
+
+        if conv_idx == 0:
+            suffix = ""
+        else:
+            suffix = "_" + str(conv_idx)
+        output['up7.0.weight'] = np.transpose(
+            tf_vars['conv2d{}/kernel'.format(suffix)], (3, 2, 0, 1))
+        output['up7.0.bias'] = tf_vars['conv2d{}/bias'.format(suffix)]
+        conv_idx += 1
+
+    return outputs
\ No newline at end of file
diff --git a/codes/scripts/audio/preparation/spleeter_dataset.py b/codes/scripts/audio/preparation/spleeter_dataset.py
new file mode 100644
index 00000000..a0a198fc
--- /dev/null
+++ b/codes/scripts/audio/preparation/spleeter_dataset.py
@@ -0,0 +1,37 @@
+import torch
+import torch.nn as nn
+from spleeter.audio.adapter import AudioAdapter
+from torch.utils.data import Dataset
+
+from data.util import find_audio_files
+
+
+class SpleeterDataset(Dataset):
+    def __init__(self, src_dir, sample_rate=22050, max_duration=20):
+        self.files = find_audio_files(src_dir, include_nonwav=True)
+        self.audio_loader = AudioAdapter.default()
+        self.sample_rate = sample_rate
+        self.max_duration = max_duration
+
+    def __getitem__(self, item):
+        file = self.files[item]
+        try:
+            wave, sample_rate = self.audio_loader.load(file, sample_rate=self.sample_rate)
+            assert sample_rate == self.sample_rate
+            wave = wave[:,0]  # strip off channels
+            wave = torch.tensor(wave)
+        except:
+            wave = torch.zeros(self.sample_rate * self.max_duration)
+            print(f"Error with {file}")
+        original_duration = wave.shape[0]
+        padding_needed = self.sample_rate * self.max_duration - original_duration
+        if padding_needed > 0:
+            wave = nn.functional.pad(wave, (0, padding_needed))
+        return {
+            'path': file,
+            'wave': wave,
+            'duration': original_duration,
+        }
+
+    def __len__(self):
+        return len(self.files)
\ No newline at end of file
diff --git a/codes/scripts/audio/preparation/spleeter_split_voice_and_background_2.py b/codes/scripts/audio/preparation/spleeter_split_voice_and_background_2.py
new file mode 100644
index 00000000..a647d405
--- /dev/null
+++ b/codes/scripts/audio/preparation/spleeter_split_voice_and_background_2.py
@@ -0,0 +1,67 @@
+from scipy.io import wavfile
+import os
+
+import numpy as np
+from scipy.io import wavfile
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from models.spleeter.separator import Separator
+from scripts.audio.preparation.spleeter_dataset import SpleeterDataset
+
+
+def main():
+    src_dir = 'F:\\split\\podcast-dump0'
+    output_dir = 'F:\\tmp\\out'
+    output_dir_bg = 'F:\\tmp\\bg'
+    output_sample_rate=22050
+    batch_size=24
+
+    dl = DataLoader(SpleeterDataset(src_dir, output_sample_rate), batch_size=batch_size, shuffle=False, num_workers=1, pin_memory=True)
+    separator = Separator('pretrained_models/2stems', input_sr=output_sample_rate)
+    for e, batch in enumerate(tqdm(dl)):
+        #if e < 406500:
+        #    continue
+        waves = batch['wave']
+        paths = batch['path']
+        durations = batch['duration']
+
+        sep = separator.separate(waves)
+        for j in range(sep['vocals'].shape[0]):
+            vocals = sep['vocals'][j]
+            bg = sep['accompaniment'][j]
+            vmax = np.abs(vocals).mean()
+            bmax = np.abs(bg).mean()
+
+            # Only output to the "good" sample dir if the ratio of background noise to vocal noise is high enough.
+            ratio = vmax / (bmax+.0000001)
+            if ratio >= 25:  # These values were derived empirically
+                od = output_dir
+                out_sound = waves[j].cpu().numpy()
+            elif ratio <= 1:
+                od = output_dir_bg
+                out_sound = bg
+            else:
+                continue
+
+            # Strip out channels.
+            if len(out_sound.shape) > 1:
+                out_sound = out_sound[:, 0]  # Just use the first channel.
+            # Resize to true duration
+            out_sound = out_sound[:durations[j]]
+
+            # Compile an output path.
+            path = paths[j]
+            reld = os.path.relpath(os.path.dirname(path), src_dir)
+            os.makedirs(os.path.join(od, reld), exist_ok=True)
+            relp = os.path.relpath(path, src_dir)
+            output_path = os.path.join(od, relp)
+
+            wavfile.write(output_path, output_sample_rate, out_sound)
+
+
+# Uses torch spleeter to divide audio clips into one of two bins:
+# 1. Audio has little to no background noise, saved to "output_dir"
+# 2. Audio has a lot of background noise, bg noise split off and saved to "output_dir_bg"
+if __name__ == '__main__':
+    main()
diff --git a/codes/scripts/audio/spleeter_split_voice_and_background.py b/codes/scripts/audio/spleeter_split_voice_and_background.py
index ef3d32e1..4cfa77d2 100644
--- a/codes/scripts/audio/spleeter_split_voice_and_background.py
+++ b/codes/scripts/audio/spleeter_split_voice_and_background.py
@@ -25,8 +25,8 @@ if __name__ == '__main__':
     separator = Separator('spleeter:2stems')
     files = find_audio_files(src_dir, include_nonwav=True)
     for e, file in enumerate(tqdm(files)):
-        if e < 406500:
-            continue
+        #if e < 406500:
+        #    continue
         file_basis = osp.relpath(file, src_dir)\
             .replace('/', '_')\
             .replace('\\', '_')\
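
For reference, a minimal sketch of driving the batched separator added by this patch directly, outside of spleeter_split_voice_and_background_2.py. It assumes the Spleeter 2-stem checkpoint has already been downloaded to pretrained_models/2stems (the same path the new script uses) and substitutes random noise for real audio:

import torch

from models.spleeter.separator import Separator

# Assumes the 2-stem TF checkpoint is available at this path, as in the new script above.
separator = Separator('pretrained_models/2stems', input_sr=22050)

# Separator.separate() accepts a BxL tensor (or a 1-D numpy array); here a dummy
# batch of four 10-second mono clips at 22.05kHz stands in for real audio.
wav = torch.randn(4, 22050 * 10)
res = separator.separate(wav)
vocals = res['vocals']                # numpy array, roughly B x L (vocal stem)
accompaniment = res['accompaniment']  # numpy array, roughly B x L (background stem)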