2022-04-15 14:26:11 +00:00
import os
from glob import glob
2022-05-04 02:44:31 +00:00
import librosa
2023-04-01 13:08:31 +00:00
import soundfile as sf
2022-01-28 06:19:29 +00:00
import torch
import torchaudio
2022-02-04 05:18:21 +00:00
import numpy as np
from scipy . io . wavfile import read
2022-01-28 06:19:29 +00:00
2022-05-01 22:24:24 +00:00
from tortoise . utils . stft import STFT
2022-03-22 17:52:46 +00:00
2023-02-10 20:11:56 +00:00
def get_voice_dir():
    """Return the directory that holds the voice sub-folders, creating it if needed.

    The preferred location is ``../voices`` relative to the directory containing
    this module. When that path does not exist, the function falls back to
    ``./voices`` relative to the current working directory.
    """
    module_dir = os.path.dirname(os.path.realpath(__file__))
    preferred = os.path.join(module_dir, '../voices')
    # os.path.dirname('./voices/') evaluates to './voices' (the trailing slash is dropped).
    voice_dir = preferred if os.path.exists(preferred) else os.path.dirname('./voices/')
    os.makedirs(voice_dir, exist_ok=True)
    return voice_dir
2022-01-28 06:19:29 +00:00
def load_audio(audiopath, sampling_rate):
    """Load an audio file as a single-channel float tensor at ``sampling_rate``.

    The decoder is chosen by file extension, case-insensitively:
    ``.wav`` via torchaudio, ``.mp3`` via librosa, ``.flac`` via soundfile.

    Args:
        audiopath: Path to the audio file (``str`` or any ``os.PathLike``).
        sampling_rate: Target sample rate in Hz. The audio is resampled when the
            file's native rate differs.

    Returns:
        A tensor of shape ``(1, num_samples)``, clipped in place to [-1, 1].
        Only the first channel of multi-channel audio is kept.

    Raises:
        AssertionError: if the extension is not a supported format. It is raised
            explicitly rather than via ``assert`` so the check still fires under
            ``python -O``. The type is kept for callers that already catch it.
    """
    audiopath = os.fspath(audiopath)
    extension = os.path.splitext(audiopath)[1].lower()

    if extension == '.wav':
        audio, lsr = torchaudio.load(audiopath)
    elif extension == '.mp3':
        # librosa resamples while loading, so lsr already equals sampling_rate here.
        audio, lsr = librosa.load(audiopath, sr=sampling_rate)
        audio = torch.FloatTensor(audio)
    elif extension == '.flac':
        audio, lsr = sf.read(audiopath)
        audio = torch.FloatTensor(audio)
    else:
        raise AssertionError(f"Unsupported audio format provided: {extension or audiopath}")

    # Remove any channel data by keeping the first channel only. The channel axis
    # is whichever axis is small: torchaudio.load is channels-first by default,
    # while soundfile.read returns (frames, channels).
    if audio.dim() > 1:
        if audio.shape[0] < 5:
            audio = audio[0]
        else:
            assert audio.shape[1] < 5
            audio = audio[:, 0]

    if lsr != sampling_rate:
        audio = torchaudio.functional.resample(audio, lsr, sampling_rate)

    # Sanity-check the amplitude range and report suspicious files.
    # The bound 2 is arbitrary, because decoded audio often slightly "overdrives"
    # [-1, 1]. A signal with no negative samples at all is also treated as suspect.
    if torch.any(audio > 2) or not torch.any(audio < 0):
        print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
    audio.clip_(-1, 1)

    return audio.unsqueeze(0)
# Bounds of the natural-log mel values that normalize_tacotron_mel maps onto [-1, 1].
# TACOTRON_MEL_MIN equals log(1e-5) rounded to float32, i.e. the log of the default
# clip_val floor used by dynamic_range_compression (presumably chosen for that reason).
TACOTRON_MEL_MAX = 2.3143386840820312
TACOTRON_MEL_MIN = -11.512925148010254


def denormalize_tacotron_mel(norm_mel):
    """Inverse of normalize_tacotron_mel: map values in [-1, 1] back to the log-mel range."""
    mel_span = TACOTRON_MEL_MAX - TACOTRON_MEL_MIN
    unit = (norm_mel + 1) / 2  # [-1, 1] -> [0, 1]
    return unit * mel_span + TACOTRON_MEL_MIN


def normalize_tacotron_mel(mel):
    """Linearly rescale log-mel values from [TACOTRON_MEL_MIN, TACOTRON_MEL_MAX] to [-1, 1]."""
    mel_span = TACOTRON_MEL_MAX - TACOTRON_MEL_MIN
    unit = (mel - TACOTRON_MEL_MIN) / mel_span  # -> [0, 1]
    return 2 * unit - 1
def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """Log-compress a tensor: ``log(max(x, clip_val) * C)``.

    PARAMS
    ------
    x: tensor of magnitudes. Values below clip_val are raised to clip_val so the
       log never sees zero.
    C: compression factor (multiplies the floored magnitudes before the log)
    clip_val: lower floor applied before the log
    """
    floored = x.clamp(min=clip_val)
    return (floored * C).log()
def dynamic_range_decompression(x, C=1):
    """Undo dynamic_range_compression: ``exp(x) / C``.

    PARAMS
    ------
    x: log-compressed tensor
    C: the compression factor that was used to compress
    """
    expanded = x.exp()
    return expanded / C
2023-02-17 04:50:02 +00:00
def get_voices(extra_voice_dirs=None, load_latents=True):
    """Map each voice name to the files found in its directory.

    Every immediate sub-directory of get_voice_dir() and of each extra directory
    is a voice, named after the sub-directory. If the same name occurs in several
    directories, the directory scanned last wins.

    Args:
        extra_voice_dirs: Additional directories scanned after the default one.
            This may be a single path or any iterable of paths. ``None`` (the
            default) means no extra directories.
        load_latents: Also include ``*.pth`` files (cached conditioning latents).

    Returns:
        dict mapping voice name -> list of paths. The list holds the ``.wav``
        files, then ``.mp3``, then ``.flac``, followed by ``.pth`` when
        ``load_latents`` is true.
    """
    from glob import escape as glob_escape

    if extra_voice_dirs is None:
        extra_voice_dirs = []
    elif isinstance(extra_voice_dirs, (str, os.PathLike)):
        extra_voice_dirs = [extra_voice_dirs]
    dirs = [get_voice_dir(), *extra_voice_dirs]

    patterns = ['*.wav', '*.mp3', '*.flac']
    if load_latents:
        patterns.append('*.pth')

    voices = {}
    for voice_root in dirs:
        for sub in os.listdir(voice_root):
            subj = os.path.join(voice_root, sub)
            if not os.path.isdir(subj):
                continue
            # Escape the directory part so names containing glob metacharacters
            # ('[', '*', '?') are matched literally instead of as patterns.
            escaped = glob_escape(subj)
            files = []
            for pattern in patterns:
                files.extend(glob(f'{escaped}/{pattern}'))
            voices[sub] = files
    return voices
2023-03-02 00:44:42 +00:00
def load_voice(voice, extra_voice_dirs=None, load_latents=True, sample_rate=22050, device='cpu', model_hash=None):
    """Load one voice as either raw audio clips or cached conditioning latents.

    Args:
        voice: Voice (sub-directory) name. ``'random'`` short-circuits to ``(None, None)``.
        extra_voice_dirs: Extra directories searched by get_voices. ``None`` means none.
        load_latents: Allow returning a cached latent file instead of audio.
        sample_rate: Sample rate passed to load_audio for every clip.
        device: ``map_location`` used when loading a latent file with ``torch.load``.
        model_hash: When set, only ``cond_latents_<first 8 chars>.pth`` counts as
            the latent file. Otherwise only ``cond_latents.pth`` does.

    Returns:
        ``(None, None)`` for the random voice.
        ``(None, latents)`` when the latent file is newer than every audio clip.
        ``(clips, None)`` otherwise, where clips is a list of load_audio results.

    Raises:
        KeyError: if no voice directory named ``voice`` was found.
    """
    if voice == 'random':
        return None, None

    if extra_voice_dirs is None:
        extra_voice_dirs = []
    paths = get_voices(extra_voice_dirs=extra_voice_dirs, load_latents=load_latents)[voice]

    wanted_latent = f"cond_latents_{model_hash[:8]}.pth" if model_hash else "cond_latents.pth"
    audio_paths = []
    latent = None
    newest_audio_mtime = 0
    for path in paths:
        filename = os.path.basename(path)
        if filename.endswith(".pth"):
            # .pth files are torch checkpoints, never audio. Keep the matching latent
            # and skip the rest (e.g. latents for another model), since load_audio
            # would reject them.
            if filename == wanted_latent:
                latent = path
            continue
        audio_paths.append(path)
        newest_audio_mtime = max(newest_audio_mtime, os.path.getmtime(path))

    if load_latents and latent is not None:
        # Trust the cached latent only if it is newer than every clip.
        # Otherwise the clips changed after it was computed.
        if os.path.getmtime(latent) > newest_audio_mtime:
            print(f"Reading from latent: {latent}")
            # NOTE(security): torch.load unpickles the file, so only use latents
            # from trusted voice directories.
            return None, torch.load(latent, map_location=device)
        print(f"Latent file out of date: {latent}")

    return [load_audio(path, sample_rate) for path in audio_paths], None
2022-05-01 23:25:18 +00:00
2022-05-19 11:35:57 +00:00
def load_voices(voices, extra_voice_dirs=None):
    """Load several voices and combine them.

    All voices must resolve to the same kind. Raw audio clips are concatenated
    into one list, and latent voices are averaged element-wise. ``'random'``
    anywhere in ``voices`` overrides everything else.

    Args:
        voices: Iterable of voice names.
        extra_voice_dirs: Extra directories passed through to load_voice.
            ``None`` means none.

    Returns:
        ``(None, None)`` if ``'random'`` is requested.
        ``(clips, None)`` for audio voices.
        ``(None, (latent_0, latent_1))`` for latent voices, each component being
        the mean over voices.

    Raises:
        AssertionError: when audio voices and latent voices are mixed. It is
            raised explicitly so the check survives ``python -O``.
    """
    voices = list(voices)  # allow iterators, since the list is inspected before loading

    # 'random' wins regardless of position, so check for it before loading anything.
    if 'random' in voices:
        if len(voices) > 1:
            print("Cannot combine a random voice with a non-random voice. Just using a random voice.")
        return None, None

    if extra_voice_dirs is None:
        extra_voice_dirs = []

    mix_error = "Can only combine raw audio voices or latent voices, not both. Do it yourself if you want this."
    clips = []
    latents = []
    for voice in voices:
        clip, latent = load_voice(voice, extra_voice_dirs)
        if latent is None:
            if latents:
                raise AssertionError(mix_error)
            clips.extend(clip)
        elif clip is None:
            if clips:
                raise AssertionError(mix_error)
            latents.append(latent)

    if not latents:
        return clips, None
    # Average the first two components of the latent tuples across all voices.
    averaged = tuple(
        torch.stack([lat[i] for lat in latents], dim=0).mean(dim=0) for i in (0, 1)
    )
    return None, averaged
2022-05-01 23:25:18 +00:00
2022-03-22 17:52:46 +00:00
class TacotronSTFT(torch.nn.Module):
    """Tacotron-style front end that turns waveforms into log-compressed mel spectrograms.

    Wraps the project's STFT helper and a fixed librosa mel filter bank. The filter
    bank is registered as the ``mel_basis`` buffer so it follows the module's device.
    """

    def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
                 n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
                 mel_fmax=8000.0):
        super().__init__()
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.stft_fn = STFT(filter_length, hop_length, win_length)

        from librosa.filters import mel as librosa_mel_fn
        filter_bank = librosa_mel_fn(
            sr=sampling_rate,
            n_fft=filter_length,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
        )
        self.register_buffer('mel_basis', torch.from_numpy(filter_bank).float())

    def spectral_normalize(self, magnitudes):
        """Log-compress magnitudes with dynamic_range_compression (default C and clip_val)."""
        return dynamic_range_compression(magnitudes)

    def spectral_de_normalize(self, magnitudes):
        """Invert spectral_normalize with dynamic_range_decompression."""
        return dynamic_range_decompression(magnitudes)

    def mel_spectrogram(self, y):
        """Compute mel spectrograms from a batch of waveforms.

        PARAMS
        ------
        y: torch.FloatTensor of shape (B, T), expected in [-1, 1]. Values outside
           [-10, 10] fail an assertion. Anything outside [-1, 1] is clipped.

        RETURNS
        -------
        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T'), where T'
            is presumably the number of STFT frames produced by self.stft_fn.
        """
        samples = y.data
        assert torch.min(samples) >= -10
        assert torch.max(samples) <= 10
        y = torch.clamp(y, min=-1, max=1)

        magnitudes, _phases = self.stft_fn.transform(y)
        # .data detaches the magnitudes from autograd before projecting onto the mel bands.
        mel = torch.matmul(self.mel_basis, magnitudes.data)
        return self.spectral_normalize(mel)
2023-02-09 01:53:25 +00:00
def wav_to_univnet_mel(wav, do_normalization=False, device='cpu', sample_rate=24000):
    """Compute the 100-band mel spectrogram used for the UnivNet vocoder.

    A fresh TacotronSTFT is built on every call with filter/window length 1024,
    hop 256, 100 mel channels and a 0-12000 Hz mel range.

    Args:
        wav: Waveform batch accepted by TacotronSTFT.mel_spectrogram (expected on ``device``).
        do_normalization: Also rescale the result with normalize_tacotron_mel.
        device: Device the STFT module is moved to.
        sample_rate: Sample rate of ``wav`` in Hz.
    """
    mel_extractor = TacotronSTFT(
        filter_length=1024,
        hop_length=256,
        win_length=1024,
        n_mel_channels=100,
        sampling_rate=sample_rate,
        mel_fmin=0,
        mel_fmax=12000,
    ).to(device)
    mel = mel_extractor.mel_spectrogram(wav)
    return normalize_tacotron_mel(mel) if do_normalization else mel