2022-04-15 14:26:11 +00:00
import os
from glob import glob
2022-05-04 02:44:31 +00:00
import librosa
2022-01-28 06:19:29 +00:00
import torch
import torchaudio
2022-02-04 05:18:21 +00:00
import numpy as np
from scipy . io . wavfile import read
2022-01-28 06:19:29 +00:00
2022-05-01 22:24:24 +00:00
from tortoise . utils . stft import STFT
2022-03-22 17:52:46 +00:00
2022-01-28 06:19:29 +00:00
2022-05-29 01:10:19 +00:00
# Voices shipped with the package: "<directory containing this file>/../voices".
BUILTIN_VOICES_DIR = os.path.join(os.path.split(os.path.realpath(__file__))[0], '../voices')
2022-01-28 06:19:29 +00:00
def load_wav_to_torch(full_path):
    """Read a WAV file with SciPy and return its samples as a float tensor.

    Integer PCM is divided by the full-scale value of its type (2**31 for int32,
    2**15 for int16). Floating-point data (float16/32/64) is passed through unscaled.

    Args:
        full_path: Path to the ``.wav`` file.

    Returns:
        Tuple ``(audio, sampling_rate)``. ``audio`` is a ``torch.FloatTensor`` with the
        same shape as the array returned by ``scipy.io.wavfile.read``.
        ``sampling_rate`` is the file's sample rate.

    Raises:
        NotImplementedError: If the sample dtype is not int16, int32, float16, float32
            or float64.
    """
    sampling_rate, data = read(full_path)
    if data.dtype == np.int32:
        norm_fix = 2 ** 31
    elif data.dtype == np.int16:
        norm_fix = 2 ** 15
    elif data.dtype in (np.float16, np.float32, np.float64):
        # Float samples are taken as-is; no rescaling is applied.
        norm_fix = 1.
    else:
        # NotImplementedError is the exception class. The similarly named NotImplemented
        # constant is not callable and would raise an unrelated TypeError instead.
        raise NotImplementedError(f"Provided data dtype not supported: {data.dtype}")
    return (torch.FloatTensor(data.astype(np.float32)) / norm_fix, sampling_rate)
def load_audio(audiopath, sampling_rate):
    """Load a ``.wav`` or ``.mp3`` file as a single-channel float tensor.

    Args:
        audiopath: Path to the audio file. Its last four characters, compared
            case-insensitively, select the decoder: ``load_wav_to_torch`` for
            ``.wav`` and ``librosa.load`` for ``.mp3``.
        sampling_rate: Target sample rate in Hz. Audio is resampled with torchaudio
            when the decoded rate differs.

    Returns:
        Tensor of shape ``(1, num_samples)`` with values clipped to [-1, 1].

    Raises:
        ValueError: If the extension is neither ``.wav`` nor ``.mp3``.
    """
    extension = audiopath[-4:].lower()
    if extension == '.wav':
        audio, lsr = load_wav_to_torch(audiopath)
    elif extension == '.mp3':
        # librosa.load resamples to ``sr`` itself, so ``lsr`` comes back equal to sampling_rate.
        audio, lsr = librosa.load(audiopath, sr=sampling_rate)
        audio = torch.FloatTensor(audio)
    else:
        # An explicit raise still fires under ``python -O``. A stripped ``assert False``
        # would instead fall through with ``audio`` unbound.
        raise ValueError(f"Unsupported audio format provided: {audiopath[-4:]}")

    # Remove any channel data: keep only the first channel. A short leading dimension
    # (< 5) is treated as the channel axis; otherwise channels are expected last.
    if len(audio.shape) > 1:
        if audio.shape[0] < 5:
            audio = audio[0]
        else:
            assert audio.shape[1] < 5
            audio = audio[:, 0]

    if lsr != sampling_rate:
        audio = torchaudio.functional.resample(audio, lsr, sampling_rate)

    # Check some assumptions about audio range. This should be automatically fixed in
    # load_wav_to_torch, but might not be in some edge cases, where we should squawk.
    # '2' is arbitrarily chosen since it seems like audio will often "overdrive" the
    # [-1,1] bounds; the check is applied in both directions. A signal with no negative
    # samples at all is also flagged (presumably unsigned/offset data).
    if torch.any(audio > 2) or torch.any(audio < -2) or not torch.any(audio < 0):
        print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
    audio.clip_(-1, 1)

    return audio.unsqueeze(0)
# Bounds of the log-mel range mapped onto [-1, 1] by the helpers below.
# TACOTRON_MEL_MIN is approximately log(1e-5), the floor that dynamic_range_compression's
# default clip_val produces.
TACOTRON_MEL_MAX = 2.3143386840820312
TACOTRON_MEL_MIN = -11.512925148010254


def denormalize_tacotron_mel(norm_mel):
    """Map values from [-1, 1] back onto [TACOTRON_MEL_MIN, TACOTRON_MEL_MAX].

    This is the inverse of ``normalize_tacotron_mel``.
    """
    unit_range = (norm_mel + 1) / 2
    return unit_range * (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN) + TACOTRON_MEL_MIN


def normalize_tacotron_mel(mel):
    """Linearly map [TACOTRON_MEL_MIN, TACOTRON_MEL_MAX] onto [-1, 1]."""
    mel_span = TACOTRON_MEL_MAX - TACOTRON_MEL_MIN
    unit_range = (mel - TACOTRON_MEL_MIN) / mel_span
    return 2 * unit_range - 1
def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """Compress magnitudes into the natural-log domain.

    Entries of ``x`` below ``clip_val`` are raised to ``clip_val`` before the log, so
    zeros do not turn into ``-inf`` (for a positive ``C``).

    PARAMS
    ------
    C: compression factor, multiplied into the clamped input before the log
    clip_val: lower bound applied to ``x`` before the log
    """
    floored = x.clamp(min=clip_val)
    return (floored * C).log()
def dynamic_range_decompression(x, C=1):
    """Reverse ``dynamic_range_compression``: exponentiate, then divide out ``C``.

    PARAMS
    ------
    C: compression factor that was used when compressing
    """
    expanded = x.exp()
    return expanded / C
2022-05-19 11:35:57 +00:00
def get_voices(extra_voice_dirs=None):
    """Map each available voice name to the files found in its directory.

    Every immediate subdirectory of ``BUILTIN_VOICES_DIR`` and of each directory in
    ``extra_voice_dirs`` is one voice, named after the subdirectory. Its ``.wav``,
    ``.mp3`` and ``.pth`` files are listed in that order, each group sorted so results
    are reproducible. A voice found in a later directory replaces one with the same
    name from an earlier directory.

    Args:
        extra_voice_dirs: Optional iterable of extra directories, searched after the
            built-in voices. ``None`` (the default) means no extra directories.

    Returns:
        Dict mapping voice name to a list of file paths.
    """
    # Only the function ``glob`` is imported at file level; escape() lives on the module.
    from glob import escape as glob_escape

    search_dirs = [BUILTIN_VOICES_DIR, *(extra_voice_dirs or ())]
    voices = {}
    for voice_root in search_dirs:
        for name in sorted(os.listdir(voice_root)):
            voice_dir = os.path.join(voice_root, name)
            if not os.path.isdir(voice_dir):
                continue
            # Escape the directory so names containing glob metacharacters
            # ('[', '*', '?') are matched literally rather than as patterns.
            escaped_dir = glob_escape(voice_dir)
            files = []
            for ext in ('wav', 'mp3', 'pth'):
                files.extend(sorted(glob(f'{escaped_dir}/*.{ext}')))
            voices[name] = files
    return voices
2022-05-19 11:35:57 +00:00
def load_voice(voice, extra_voice_dirs=None):
    """Load the conditioning data for one named voice.

    Args:
        voice: Voice name (a subdirectory of a voice directory), or ``'random'``.
        extra_voice_dirs: Optional iterable of extra voice directories passed on to
            ``get_voices``. ``None`` (the default) means none.

    Returns:
        One of three tuples:
        * ``(None, None)`` for ``'random'``.
        * ``(None, latents)`` when the voice is a single ``.pth`` file, loaded with
          ``torch.load``.
        * Otherwise ``(clips, None)``, where ``clips`` is a list of ``load_audio``
          tensors at 22050 Hz.

    Raises:
        KeyError: If no voice named ``voice`` exists in the searched directories.
    """
    if voice == 'random':
        return None, None

    # Convert None to a fresh list here, so get_voices never receives None and no
    # mutable default list is shared between calls.
    voices = get_voices([] if extra_voice_dirs is None else list(extra_voice_dirs))
    try:
        paths = voices[voice]
    except KeyError:
        available = ', '.join(sorted(voices))
        raise KeyError(f"Voice {voice!r} not found. Available voices: {available}") from None

    if len(paths) == 1 and paths[0].endswith('.pth'):
        # NOTE(review): torch.load unpickles the file; only load latents from trusted sources.
        return None, torch.load(paths[0])
    return [load_audio(cond_path, 22050) for cond_path in paths], None
2022-05-19 11:35:57 +00:00
def load_voices(voices, extra_voice_dirs=None):
    """Load several voices and combine them into one conditioning input.

    Raw-audio voices are combined by concatenating their clips. Latent (``.pth``)
    voices are combined by averaging element 0 and element 1 of each latent across
    voices. The two kinds cannot be mixed. When ``'random'`` is reached,
    ``(None, None)`` is returned immediately.

    Args:
        voices: Sequence of voice names.
        extra_voice_dirs: Optional iterable of extra voice directories. ``None`` (the
            default) means none.

    Returns:
        One of three tuples:
        * ``(clips, None)`` for raw-audio voices.
        * ``(None, (latents_0, latents_1))`` for latent voices.
        * ``(None, None)`` for a random voice.

    Raises:
        ValueError: If raw-audio voices and latent voices are mixed.
    """
    # Always hand load_voice a real list; never share a mutable default between calls.
    extra_voice_dirs = [] if extra_voice_dirs is None else list(extra_voice_dirs)
    mix_error = "Can only combine raw audio voices or latent voices, not both. Do it yourself if you want this."

    latents = []
    clips = []
    for voice in voices:
        if voice == 'random':
            if len(voices) > 1:
                print("Cannot combine a random voice with a non-random voice. Just using a random voice.")
            return None, None
        clip, latent = load_voice(voice, extra_voice_dirs)
        # Explicit raises (not asserts) so the no-mixing rule also holds under ``python -O``.
        if latent is None:
            if latents:
                raise ValueError(mix_error)
            clips.extend(clip)
        elif clip is None:
            if clips:
                raise ValueError(mix_error)
            latents.append(latent)

    if not latents:
        return clips, None
    latents_0 = torch.stack([lat[0] for lat in latents], dim=0).mean(dim=0)
    latents_1 = torch.stack([lat[1] for lat in latents], dim=0).mean(dim=0)
    return None, (latents_0, latents_1)
2022-05-01 23:25:18 +00:00
2022-03-22 17:52:46 +00:00
class TacotronSTFT(torch.nn.Module):
    """Mel-spectrogram front end built on the project's ``STFT`` module.

    The mel filterbank is built with ``librosa.filters.mel``. It is registered as the
    ``mel_basis`` buffer, so it moves with the module on ``.to(...)`` and appears in
    the state dict.
    """

    def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
                 n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
                 mel_fmax=8000.0):
        super().__init__()
        from librosa.filters import mel as librosa_mel_fn

        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.stft_fn = STFT(filter_length, hop_length, win_length)
        filterbank = librosa_mel_fn(
            sr=sampling_rate,
            n_fft=filter_length,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
        )
        self.register_buffer('mel_basis', torch.from_numpy(filterbank).float())

    def spectral_normalize(self, magnitudes):
        """Log-compress ``magnitudes`` with ``dynamic_range_compression``."""
        return dynamic_range_compression(magnitudes)

    def spectral_de_normalize(self, magnitudes):
        """Undo ``spectral_normalize`` with ``dynamic_range_decompression``."""
        return dynamic_range_decompression(magnitudes)

    def mel_spectrogram(self, y):
        """Compute log-mel spectrograms for a batch of waveforms.

        PARAMS
        ------
        y: torch.FloatTensor of shape (B, T), expected in [-1, 1]. Any value outside
           [-10, 10] fails an assertion; the rest is clipped to [-1, 1] before the STFT.

        RETURNS
        -------
        mel_output: torch.FloatTensor of shape (B, n_mel_channels, num_frames); the frame
                    count is determined by ``STFT.transform``.
        """
        assert torch.min(y.data) >= -10
        assert torch.max(y.data) <= 10
        waveform = torch.clip(y, min=-1, max=1)
        magnitudes, _phases = self.stft_fn.transform(waveform)
        # ``.data`` drops autograd history, so the mel output is detached from ``y``.
        mel = torch.matmul(self.mel_basis, magnitudes.data)
        return self.spectral_normalize(mel)
2022-06-11 11:03:14 +00:00
# One UnivNet-configured TacotronSTFT per device, built on first use and then reused.
_UNIVNET_STFT_CACHE = {}


def _univnet_stft(device):
    """Return the cached UnivNet ``TacotronSTFT`` for ``device``, creating it if needed."""
    stft = _UNIVNET_STFT_CACHE.get(device)
    if stft is None:
        # filter_length=1024, hop_length=256, win_length=1024, n_mel_channels=100,
        # sampling_rate=24000, mel_fmin=0, mel_fmax=12000.
        stft = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000).to(device)
        _UNIVNET_STFT_CACHE[device] = stft
    return stft


def wav_to_univnet_mel(wav, do_normalization=False, device='cuda'):
    """Compute a 100-band, 24 kHz log-mel spectrogram with the UnivNet STFT settings.

    The STFT module (mel filterbank plus transform) is identical on every call for a
    given device. It is therefore built once per device and cached, instead of being
    reconstructed and moved each time.

    Args:
        wav: Waveform batch accepted by ``TacotronSTFT.mel_spectrogram``. Presumably it
            must already be on ``device``; TODO confirm against callers.
        do_normalization: If True, map [TACOTRON_MEL_MIN, TACOTRON_MEL_MAX] onto
            [-1, 1] with ``normalize_tacotron_mel``.
        device: Device the STFT module is placed on.

    Returns:
        The log-mel spectrogram tensor.
    """
    mel = _univnet_stft(device).mel_spectrogram(wav)
    if do_normalization:
        mel = normalize_tacotron_mel(mel)
    return mel