import torch
import torch.nn.functional as F

from models.spleeter.estimator import Estimator


class Separator:
    def __init__(self, model_path, input_sr=44100, device='cuda'):
        # Two-stem Spleeter model (vocals + accompaniment).
        self.model = Estimator(2, model_path).to(device)
        self.device = device
        self.input_sr = input_sr

    def separate(self, npwav, normalize=False):
        if not isinstance(npwav, torch.Tensor):
            # Accept a 1-D numpy array and promote it to a batch of one.
            assert len(npwav.shape) == 1
            wav = torch.tensor(npwav, device=self.device)
            wav = wav.view(1, -1)
        else:
            # Tensor input should already be batched as BxL.
            assert len(npwav.shape) == 2
            wav = npwav.to(self.device)

        if normalize:
            # Peak-normalize; the epsilon guards against division by zero on silence.
            wav = wav / (wav.max() + 1e-8)

        # Spleeter expects 44.1 kHz audio, so resample from input_sr before separation.
        wav = F.interpolate(wav.unsqueeze(1), mode='nearest',
                            scale_factor=44100 / self.input_sr).squeeze(1)
        res = self.model.separate(wav)
        # Resample each separated stem back to the original sample rate.
        res = [F.interpolate(r.unsqueeze(1), mode='nearest',
                             scale_factor=self.input_sr / 44100)[:, 0]
               for r in res]
        return {
            'vocals': res[0].cpu().numpy(),
            'accompaniment': res[1].cpu().numpy(),
        }
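

# Minimal usage sketch (an assumption, not part of the original module): the checkpoint
# path 'pretrained/2stems', the 22.05 kHz sample rate, and the random input are all
# hypothetical stand-ins; adjust them to your own setup.
if __name__ == '__main__':
    import numpy as np

    separator = Separator('pretrained/2stems', input_sr=22050, device='cuda')
    dummy_wav = np.random.randn(22050 * 5).astype(np.float32)  # 5 seconds of noise as a placeholder signal
    stems = separator.separate(dummy_wav, normalize=True)
    print(stems['vocals'].shape, stems['accompaniment'].shape)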