DL-Art-School/codes/models/spleeter/estimator.py
2021-09-09 23:14:32 -06:00

137 lines
3.7 KiB
Python

import math
import torch
import torch.nn.functional as F
from torch import nn
from torch import istft
from .unet import UNet
from .util import tf2pytorch
def load_ckpt(model, ckpt):
    """
    Copies matching numpy weights from a checkpoint mapping into a model.

    Args:
        model (nn.Module): network whose parameters/buffers are overwritten
        ckpt (dict): maps state-dict key -> numpy array

    Returns:
        the same model instance, with the matching entries loaded

    Keys of ``ckpt`` that the model does not have are reported on stdout and
    skipped; a key whose array shape differs from the model's trips an
    assertion.
    """
    weights = model.state_dict()

    for name, array in ckpt.items():
        if name not in weights:
            print('Ignore ', name)
            continue

        expected_shape = weights[name].shape
        assert expected_shape == array.shape
        weights[name] = torch.from_numpy(array)

    model.load_state_dict(weights)
    return model
def pad_and_partition(tensor, T):
    """
    Zero-pads the last axis up to a multiple of T, then cuts it into segments
    of length T that are stacked along the batch axis.

    Args:
        tensor (Tensor): B x C x F x L
        T (int): segment length

    Returns:
        tensor of size (ceil(L/T)*B) x C x F x T. The segments are stacked
        segment-major: rows [k*B, (k+1)*B) hold segment k of every batch item.

    Raises:
        ValueError: if tensor has more than four dimensions.
    """
    old_size = tensor.size(3)
    if tensor.dim() != 4:
        # F.pad below pads the *last* axis while the split runs along axis 3;
        # the two only refer to the same axis for 4-D input.
        raise ValueError(
            'expected a 4-D tensor (B x C x F x L), got {} dimensions'.format(
                tensor.dim()))

    new_size = math.ceil(old_size / T) * T
    tensor = F.pad(tensor, [0, new_size - old_size])
    return torch.cat(torch.split(tensor, T, dim=3), dim=0)
class Estimator(nn.Module):
    """
    Mask-based source separator.

    Holds one UNet per instrument. Each network receives the STFT magnitude of
    the input and produces a mask; the masks are normalised against each other
    and applied to the input STFT, which is then inverted back to a waveform.
    """

    def __init__(self, num_instrumments, checkpoint_path):
        """
        Args:
            num_instrumments (int): number of instrument networks to create
            checkpoint_path: handed to ``tf2pytorch`` to obtain one weight
                mapping per instrument
        """
        super(Estimator, self).__init__()

        # stft config
        self.F = 1024           # number of low-frequency bins kept
        self.T = 512            # frames per segment fed to the networks
        self.win_length = 4096  # also used as n_fft
        self.hop_length = 1024
        self.win = torch.hann_window(self.win_length)

        ckpts = tf2pytorch(checkpoint_path, num_instrumments)

        # filter
        self.instruments = nn.ModuleList()
        for i in range(num_instrumments):
            print('Loading model for instrumment {}'.format(i))
            net = UNet(2)
            ckpt = ckpts[i]
            net = load_ckpt(net, ckpt)
            net.eval()  # change mode to eval
            self.instruments.append(net)

    def compute_stft(self, wav):
        """
        Computes stft feature from wav

        Args:
            wav (Tensor): B x L

        Returns:
            stft (Tensor): B x F x N x 2, real and imaginary parts in the last
                axis, restricted to the first ``self.F`` frequency bins
            mag (Tensor): B x F x N, magnitude of ``stft``
        """
        # Request a complex result explicitly (current PyTorch refuses the
        # implicit real-valued output) and convert it to the (..., 2) layout
        # the rest of this class works with.
        spec = torch.stft(
            wav, self.win_length, hop_length=self.hop_length,
            window=self.win.to(wav.device), return_complex=True)
        stft = torch.view_as_real(spec)

        # only keep freqs smaller than self.F
        stft = stft[:, :self.F, :, :]
        real = stft[:, :, :, 0]
        im = stft[:, :, :, 1]
        mag = torch.sqrt(real ** 2 + im ** 2)
        return stft, mag

    def inverse_stft(self, stft):
        """
        Inverses stft to wave form

        Args:
            stft (Tensor): B x F x N x 2 (real/imaginary in the last axis)

        Returns:
            detached waveform tensor
        """
        # Restore the frequency bins dropped in compute_stft with zeros so the
        # one-sided spectrum has win_length // 2 + 1 bins again.
        pad = self.win_length // 2 + 1 - stft.size(1)
        stft = F.pad(stft, (0, 0, 0, 0, 0, pad))

        # istft expects a complex tensor; view_as_complex needs a contiguous
        # (..., 2) layout.
        spec = torch.view_as_complex(stft.contiguous())
        wav = istft(spec, self.win_length, hop_length=self.hop_length,
                    window=self.win.to(stft.device))
        return wav.detach()

    @staticmethod
    def _merge_segments(mask, batch, length):
        """
        Reassembles per-segment masks into one mask per batch item.

        Args:
            mask (Tensor): (S*B) x C x F x T, stacked segment-major, i.e. row
                k*B + b holds segment k of batch item b
            batch (int): B
            length (int): number of frames to keep (drops the zero padding)

        Returns:
            Tensor B x F x length x 1, taken from channel 0 of the mask
        """
        n_seg = mask.size(0) // batch
        mask = mask.reshape(n_seg, batch, *mask.shape[1:])  # S x B x C x F x T
        mask = mask.permute(1, 2, 3, 0, 4)                  # B x C x F x S x T
        mask = mask.reshape(batch, mask.size(1), mask.size(2), -1)
        return mask[:, 0, :, :length].unsqueeze(-1)

    def separate(self, wav):
        """
        Separates wav into different tracks corresponding to different instruments

        Args:
            wav (tensor): B x L

        Returns:
            list with one detached waveform tensor per instrument
        """
        # Every returned tensor is detached, so no autograd graph is needed.
        with torch.no_grad():
            # stft - B X F x L x 2
            # stft_mag - B X F x L
            stft, stft_mag = self.compute_stft(wav)
            batch = stft.size(0)
            L = stft.size(2)

            stft_mag = stft_mag.unsqueeze(1).repeat(1, 2, 1, 1)  # B x 2 x F x L
            stft_mag = pad_and_partition(stft_mag, self.T)  # (S*B) x 2 x F x T
            stft_mag = stft_mag.transpose(2, 3)  # (S*B) x 2 x T x F

            # compute instruments' mask
            masks = [net(stft_mag) for net in self.instruments]

            # compute denominator
            eps = 1e-10
            mask_sum = sum(m ** 2 for m in masks) + eps

            wavs = []
            for mask in masks:
                # Each instrument gets an equal share of eps so the
                # normalised masks still sum to one.
                mask = (mask ** 2 + eps / len(masks)) / mask_sum
                mask = mask.transpose(2, 3)  # (S*B) x 2 x F x T
                mask = self._merge_segments(mask, batch, L)  # B x F x L x 1
                stft_masked = stft * mask
                wav_masked = self.inverse_stft(stft_masked)
                wavs.append(wav_masked)

        return wavs