@ -22,6 +22,7 @@ import psutil
import yaml
import hashlib
import string
import random
from tqdm import tqdm
import torch
@ -34,7 +35,7 @@ import pandas as pd
from datetime import datetime
from datetime import timedelta
from tortoise . api import TextToSpeech , MODELS , get_model_path , pad_or_truncate
from tortoise . api import TextToSpeech as TorToise_TTS , MODELS , get_model_path , pad_or_truncate
from tortoise . utils . audio import load_audio , load_voice , load_voices , get_voice_dir , get_voices
from tortoise . utils . text import split_and_recombine_text
from tortoise . utils . device import get_device_name , set_device_name , get_device_count , get_device_vram , get_device_batch_size , do_gc
@ -68,6 +69,10 @@ try:
from vall_e . emb . qnt import encode as valle_quantize
from vall_e . emb . g2p import encode as valle_phonemize
from vall_e . inference import TTS as VALLE_TTS
import soundfile
VALLE_ENABLED = True
except Exception as e :
pass
@ -111,6 +116,12 @@ def resample( waveform, input_rate, output_rate=44100 ):
return RESAMPLERS [ key ] ( waveform ) , output_rate
def generate(**kwargs):
    """Dispatch a generation request to the configured TTS backend.

    All keyword arguments are forwarded unchanged to ``generate_tortoise``
    or ``generate_valle`` depending on ``args.tts_backend``, and the
    selected backend's return value is passed straight back.

    Raises:
        Exception: if ``args.tts_backend`` names neither supported backend.
            Previously this case fell through and returned ``None``.
    """
    if args.tts_backend == "tortoise":
        return generate_tortoise(**kwargs)
    if args.tts_backend == "vall-e":
        return generate_valle(**kwargs)
    # Fail loudly instead of implicitly returning None for a bad backend name.
    raise Exception(f"Unsupported backend: {args.tts_backend}")
def generate_valle ( * * kwargs ) :
parameters = { }
parameters . update ( kwargs )
@ -140,7 +151,298 @@ def generate(**kwargs):
do_gc ( )
voice_samples = None
conditioning_latents = None
conditioning_latents = None
sample_voice = None
voice_cache = { }
def fetch_voice(voice):
    """Return the paths of every ``.wav`` clip stored for *voice*.

    Clips are looked up in ``./voices/<voice>/`` relative to the current
    working directory. All matching files are returned (not a random pick),
    in sorted order so the result is the same on every call.

    Args:
        voice: name of the sub-directory under ``./voices/``.

    Returns:
        A list of file paths, one per ``.wav`` file found.

    Raises:
        OSError: propagated from ``os.listdir`` when the directory is missing.
    """
    voice_dir = f'./voices/{voice}/'
    return [
        # os.path.join avoids the doubled slash the old f-string produced.
        os.path.join(voice_dir, entry)
        for entry in sorted(os.listdir(voice_dir))
        if entry.lower().endswith(".wav")
    ]
def get_settings(override=None):
    """Build the inference settings for one line of text.

    The base settings are the AR and NAR temperatures (both taken from
    ``parameters['temperature']``) and ``max_ar_samples`` (taken from
    ``parameters['num_autoregressive_samples']``). Keys in *override* that
    match one of those base settings replace them; any other key is ignored,
    except ``'voice'``, which selects the voice whose reference clips are
    fetched.

    Returns:
        The settings dict, with the reference clip list under ``'reference'``.
    """
    temperature = float(parameters['temperature'])
    resolved = {
        'ar_temp': temperature,
        'nar_temp': temperature,
        'max_ar_samples': parameters['num_autoregressive_samples'],
    }

    chosen_voice = voice
    if override is not None:
        chosen_voice = override['voice'] if 'voice' in override else voice
        # Only keys that already exist in the base settings may be overridden.
        resolved.update({key: override[key] for key in override if key in resolved})

    resolved['reference'] = fetch_voice(voice=chosen_voice)
    return resolved
if not parameters [ ' delimiter ' ] :
parameters [ ' delimiter ' ] = " \n "
elif parameters [ ' delimiter ' ] == " \\ n " :
parameters [ ' delimiter ' ] = " \n "
if parameters [ ' delimiter ' ] and parameters [ ' delimiter ' ] != " " and parameters [ ' delimiter ' ] in parameters [ ' text ' ] :
texts = parameters [ ' text ' ] . split ( parameters [ ' delimiter ' ] )
else :
texts = split_and_recombine_text ( parameters [ ' text ' ] )
full_start_time = time . time ( )
outdir = f " { args . results_folder } / { voice } / "
os . makedirs ( outdir , exist_ok = True )
audio_cache = { }
volume_adjust = torchaudio . transforms . Vol ( gain = args . output_volume , gain_type = " amplitude " ) if args . output_volume != 1 else None
idx = 0
idx_cache = { }
for i , file in enumerate ( os . listdir ( outdir ) ) :
filename = os . path . basename ( file )
extension = os . path . splitext ( filename ) [ 1 ]
if extension != " .json " and extension != " .wav " :
continue
match = re . findall ( rf " ^ { voice } _( \ d+)(?:.+?)? { extension } $ " , filename )
if match and len ( match ) > 0 :
key = int ( match [ 0 ] )
idx_cache [ key ] = True
if len ( idx_cache ) > 0 :
keys = sorted ( list ( idx_cache . keys ( ) ) )
idx = keys [ - 1 ] + 1
idx = pad ( idx , 4 )
def get_name(line=0, candidate=0, combined=False):
    """Compose the output file stem for a generated clip.

    The stem always starts with the run index ``idx``. It is followed by
    ``combined`` for a stitched clip, or otherwise by the line number when
    the prompt was split into more than one text. The candidate number is
    appended last when more than one candidate was requested.
    """
    pieces = [f"{idx}"]
    if combined:
        pieces.append("combined")
    elif len(texts) > 1:
        pieces.append(f"{line}")
    if parameters['candidates'] > 1:
        pieces.append(f"{candidate}")
    return "_".join(pieces)
def get_info(voice, settings=None, latents=True):
    """Assemble the metadata dict describing a generation.

    Starts from a copy of ``parameters``, adds the elapsed time since
    ``full_start_time`` and the current ISO timestamp, and drops the
    ``progress`` entry. A newline delimiter is stored in its escaped form.
    Values from *settings* replace entries that already exist in the result;
    settings keys not present there are ignored.

    ``voice`` and ``latents`` are accepted but not used by this variant.
    """
    info = dict(parameters)
    info['time'] = time.time() - full_start_time
    info['datetime'] = datetime.now().isoformat()
    # The progress entry is removed whether or not the caller supplied one.
    info.pop('progress', None)

    if info['delimiter'] == "\n":
        info['delimiter'] = "\\n"

    if settings is not None:
        info.update({key: value for key, value in settings.items() if key in info})

    return info
INFERENCING = True
for line , cut_text in enumerate ( texts ) :
progress . msg_prefix = f ' [ { str ( line + 1 ) } / { str ( len ( texts ) ) } ] '
print ( f " { progress . msg_prefix } Generating line: { cut_text } " )
start_time = time . time ( )
# do setting editing
match = re . findall ( r ' ^( \ { .+ \ }) (.+?)$ ' , cut_text )
override = None
if match and len ( match ) > 0 :
match = match [ 0 ]
try :
override = json . loads ( match [ 0 ] )
cut_text = match [ 1 ] . strip ( )
except Exception as e :
raise Exception ( " Prompt settings editing requested, but received invalid JSON " )
settings = get_settings ( override = override )
reference = settings [ ' reference ' ]
settings . pop ( " reference " )
gen = tts . inference ( cut_text , reference , * * settings )
run_time = time . time ( ) - start_time
print ( f " Generating line took { run_time } seconds " )
if not isinstance ( gen , list ) :
gen = [ gen ]
for j , g in enumerate ( gen ) :
wav , sr = g
name = get_name ( line = line , candidate = j )
settings [ ' text ' ] = cut_text
settings [ ' time ' ] = run_time
settings [ ' datetime ' ] = datetime . now ( ) . isoformat ( )
# save here in case some error happens mid-batch
#torchaudio.save(f'{outdir}/{voice}_{name}.wav', wav.cpu(), sr)
soundfile . write ( f ' { outdir } / { voice } _ { name } .wav ' , wav . cpu ( ) [ 0 , 0 ] , sr )
wav , sr = torchaudio . load ( f ' { outdir } / { voice } _ { name } .wav ' )
audio_cache [ name ] = {
' audio ' : wav ,
' settings ' : get_info ( voice = override [ ' voice ' ] if override and ' voice ' in override else voice , settings = settings )
}
del gen
do_gc ( )
INFERENCING = False
for k in audio_cache :
audio = audio_cache [ k ] [ ' audio ' ]
audio , _ = resample ( audio , tts . output_sample_rate , args . output_sample_rate )
if volume_adjust is not None :
audio = volume_adjust ( audio )
audio_cache [ k ] [ ' audio ' ] = audio
torchaudio . save ( f ' { outdir } / { voice } _ { k } .wav ' , audio , args . output_sample_rate )
output_voices = [ ]
for candidate in range ( parameters [ ' candidates ' ] ) :
if len ( texts ) > 1 :
audio_clips = [ ]
for line in range ( len ( texts ) ) :
name = get_name ( line = line , candidate = candidate )
audio = audio_cache [ name ] [ ' audio ' ]
audio_clips . append ( audio )
name = get_name ( candidate = candidate , combined = True )
audio = torch . cat ( audio_clips , dim = - 1 )
torchaudio . save ( f ' { outdir } / { voice } _ { name } .wav ' , audio , args . output_sample_rate )
audio = audio . squeeze ( 0 ) . cpu ( )
audio_cache [ name ] = {
' audio ' : audio ,
' settings ' : get_info ( voice = voice ) ,
' output ' : True
}
else :
name = get_name ( candidate = candidate )
audio_cache [ name ] [ ' output ' ] = True
if args . voice_fixer :
if not voicefixer :
progress ( 0 , " Loading voicefix... " )
load_voicefixer ( )
try :
fixed_cache = { }
for name in progress . tqdm ( audio_cache , desc = " Running voicefix... " ) :
del audio_cache [ name ] [ ' audio ' ]
if ' output ' not in audio_cache [ name ] or not audio_cache [ name ] [ ' output ' ] :
continue
path = f ' { outdir } / { voice } _ { name } .wav '
fixed = f ' { outdir } / { voice } _ { name } _fixed.wav '
voicefixer . restore (
input = path ,
output = fixed ,
cuda = get_device_name ( ) == " cuda " and args . voice_fixer_use_cuda ,
#mode=mode,
)
fixed_cache [ f ' { name } _fixed ' ] = {
' settings ' : audio_cache [ name ] [ ' settings ' ] ,
' output ' : True
}
audio_cache [ name ] [ ' output ' ] = False
for name in fixed_cache :
audio_cache [ name ] = fixed_cache [ name ]
except Exception as e :
print ( e )
print ( " \n Failed to run Voicefixer " )
for name in audio_cache :
if ' output ' not in audio_cache [ name ] or not audio_cache [ name ] [ ' output ' ] :
if args . prune_nonfinal_outputs :
audio_cache [ name ] [ ' pruned ' ] = True
os . remove ( f ' { outdir } / { voice } _ { name } .wav ' )
continue
output_voices . append ( f ' { outdir } / { voice } _ { name } .wav ' )
if not args . embed_output_metadata :
with open ( f ' { outdir } / { voice } _ { name } .json ' , ' w ' , encoding = " utf-8 " ) as f :
f . write ( json . dumps ( audio_cache [ name ] [ ' settings ' ] , indent = ' \t ' ) )
if args . embed_output_metadata :
for name in progress . tqdm ( audio_cache , desc = " Embedding metadata... " ) :
if ' pruned ' in audio_cache [ name ] and audio_cache [ name ] [ ' pruned ' ] :
continue
metadata = music_tag . load_file ( f " { outdir } / { voice } _ { name } .wav " )
metadata [ ' lyrics ' ] = json . dumps ( audio_cache [ name ] [ ' settings ' ] )
metadata . save ( )
if sample_voice is not None :
sample_voice = ( tts . input_sample_rate , sample_voice . numpy ( ) )
info = get_info ( voice = voice , latents = False )
print ( f " Generation took { info [ ' time ' ] } seconds, saved to ' { output_voices [ 0 ] } ' \n " )
info [ ' seed ' ] = usedSeed
if ' latents ' in info :
del info [ ' latents ' ]
os . makedirs ( ' ./config/ ' , exist_ok = True )
with open ( f ' ./config/generate.json ' , ' w ' , encoding = " utf-8 " ) as f :
f . write ( json . dumps ( info , indent = ' \t ' ) )
stats = [
[ parameters [ ' seed ' ] , " {:.3f} " . format ( info [ ' time ' ] ) ]
]
return (
sample_voice ,
output_voices ,
stats ,
)
def generate_tortoise ( * * kwargs ) :
parameters = { }
parameters . update ( kwargs )
voice = parameters [ ' voice ' ]
progress = parameters [ ' progress ' ] if ' progress ' in parameters else None
if parameters [ ' seed ' ] == 0 :
parameters [ ' seed ' ] = None
usedSeed = parameters [ ' seed ' ]
global args
global tts
unload_whisper ( )
unload_voicefixer ( )
if not tts :
# should check if it's loading or unloaded, and load it if it's unloaded
if tts_loading :
raise Exception ( " TTS is still initializing... " )
load_tts ( )
if hasattr ( tts , " loading " ) and tts . loading :
raise Exception ( " TTS is still initializing... " )
do_gc ( )
voice_samples = None
conditioning_latents = None
sample_voice = None
voice_cache = { }
@ -295,11 +597,13 @@ def generate(**kwargs):
def get_info ( voice , settings = None , latents = True ) :
info = { }
info . update ( parameters )
info [ ' time ' ] = time . time ( ) - full_start_time
info [ ' time ' ] = time . time ( ) - full_start_time
info [ ' datetime ' ] = datetime . now ( ) . isoformat ( )
info [ ' model ' ] = tts . autoregressive_model_path
info [ ' model_hash ' ] = tts . autoregressive_model_hash
info [ ' progress ' ] = None
del info [ ' progress ' ]
@ -381,9 +685,10 @@ def generate(**kwargs):
settings [ ' text ' ] = cut_text
settings [ ' time ' ] = run_time
settings [ ' datetime ' ] = datetime . now ( ) . isoformat ( ) ,
settings [ ' model ' ] = tts . autoregressive_model_path
settings [ ' model_hash ' ] = tts . autoregressive_model_hash
settings [ ' datetime ' ] = datetime . now ( ) . isoformat ( )
if args . tts_backend == " tortoise " :
settings [ ' model ' ] = tts . autoregressive_model_path
settings [ ' model_hash ' ] = tts . autoregressive_model_hash
audio_cache [ name ] = {
' audio ' : audio ,
@ -745,8 +1050,8 @@ class TrainingState():
self . it_rate = f ' { " {:.3f} " . format ( 1 / it_rate ) } it/s ' if 0 < it_rate and it_rate < 1 else f ' { " {:.3f} " . format ( it_rate ) } s/it '
self . it_rates + = it_rate
epoch_rate = self . it_rates / self . it * self . steps
if epoch_rate > 0 :
if self . it_rates > 0 and self . it * self . steps > 0 :
epoch_rate = self . it_rates / self . it * self . steps
self . epoch_rate = f ' { " {:.3f} " . format ( 1 / epoch_rate ) } epoch/s ' if 0 < epoch_rate and epoch_rate < 1 else f ' { " {:.3f} " . format ( epoch_rate ) } s/epoch '
try :
@ -925,6 +1230,7 @@ class TrainingState():
self . it_rates = 0
unq = { }
averager = None
for log in logs :
with open ( log , ' r ' , encoding = " utf-8 " ) as f :
@ -941,16 +1247,18 @@ class TrainingState():
if line . find ( ' Training Metrics: ' ) > = 0 :
split = line . split ( " Training Metrics: " ) [ - 1 ]
data = json . loads ( split )
data [ ' mode ' ] = " training "
name = " train "
mode = " training "
elif line . find ( ' Validation Metrics: ' ) > = 0 :
data = json . loads ( line . split ( " Validation Metrics: " ) [ - 1 ] )
data [ ' mode ' ] = " validation "
if " it " not in data :
data [ ' it ' ] = it
if " epoch " not in data :
data [ ' epoch ' ] = epoch
name = data [ ' name ' ] if ' name ' in data else " val "
mode = " validation "
else :
continue
@ -960,14 +1268,39 @@ class TrainingState():
it = data [ ' it ' ]
epoch = data [ ' epoch ' ]
# this method should have it at least
unq [ f ' { it } _ { name } ' ] = data
if args . tts_backend == " vall-e " :
if not averager or averager [ ' key ' ] != f ' { it } _ { name } ' or averager [ ' mode ' ] != mode :
averager = {
' key ' : f ' { it } _ { name } ' ,
' mode ' : mode ,
" metrics " : { }
}
for k in data :
if data [ k ] is None :
continue
averager [ ' metrics ' ] [ k ] = [ data [ k ] ]
else :
for k in data :
if data [ k ] is None :
continue
averager [ ' metrics ' ] [ k ] . append ( data [ k ] )
unq [ f ' { it } _ { mode } _ { name } ' ] = averager
else :
unq [ f ' { it } _ { mode } _ { name } ' ] = data
if update and it < = self . last_info_check_at :
continue
for it in unq :
self . parse_metrics ( unq [ it ] )
if args.tts_backend == "vall-e":
    # For VALL-E each entry is an averager: {'key', 'mode', 'metrics'} where
    # 'metrics' maps a metric name to the list of values collected for it.
    stats = unq[it]
    # Collapse each metric's samples to their mean.
    # NOTE(review): assumes every collected metric value is numeric - confirm.
    data = {k: sum(v) / len(v) for k, v in stats['metrics'].items()}
    # Carry over the mode string ("training"/"validation"); the previous code
    # assigned the whole averager dict here by mistake.
    data['mode'] = stats['mode']
    data['steps'] = len(stats['metrics']['it'])
else:
    data = unq[it]
self.parse_metrics(data)
self . last_info_check_at = highest_step
@ -1087,7 +1420,8 @@ def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress
# ensure we have the dvae.pth
get_model_path ( ' dvae.pth ' )
if args . tts_backend == " tortoise " :
get_model_path ( ' dvae.pth ' )
# I don't know if this is still necessary, as it was bitching at me for not doing this, despite it being in a separate process
torch . multiprocessing . freeze_support ( )
@ -2086,6 +2420,8 @@ def get_voice_list(dir=get_voice_dir(), append_defaults=False):
res = res + defaults
return res
def get_valle_models(dir="./training/"):
    """List the ``config.yaml`` files found one level below *dir*.

    Each immediate entry of *dir* containing a ``config.yaml`` contributes
    one path of the form ``<dir>/<entry>/config.yaml``. Entries are visited
    in sorted order so the result is the same on every call.

    Args:
        dir: directory to scan.

    Returns:
        A list of config paths; empty when *dir* does not exist, where the
        previous version raised ``FileNotFoundError``.
    """
    if not os.path.isdir(dir):
        return []
    # The path format is kept exactly as before in case callers compare
    # these strings against stored values - TODO confirm.
    candidates = (f'{dir}/{d}/config.yaml' for d in sorted(os.listdir(dir)))
    return [path for path in candidates if os.path.exists(path)]
def get_autoregressive_models ( dir = " ./models/finetunes/ " , prefixed = False , auto = False ) :
os . makedirs ( dir , exist_ok = True )
@ -2268,6 +2604,8 @@ def setup_args():
' tokenizer-json ' : None ,
' phonemizer-backend ' : ' espeak ' ,
' valle-model ' : None ,
' whisper-backend ' : ' openai/whisper ' ,
' whisper-model ' : " base " ,
@ -2319,6 +2657,8 @@ def setup_args():
parser . add_argument ( " --phonemizer-backend " , default = default_arguments [ ' phonemizer-backend ' ] , help = " Specifies which phonemizer backend to use. " )
parser . add_argument ( " --valle-model " , default = default_arguments [ ' valle-model ' ] , help = " Specifies which VALL-E model to use for sampling. " )
# --whisper-backend takes a backend name, so it must not be a store_true flag:
# with store_true the option accepts no value and becomes True when present.
parser.add_argument("--whisper-backend", default=default_arguments['whisper-backend'], help="Picks which whisper backend to use (openai/whisper, lightmare/whispercpp)")
parser . add_argument ( " --whisper-model " , default = default_arguments [ ' whisper-model ' ] , help = " Specifies which whisper model to use for transcription. " )
parser . add_argument ( " --whisper-batchsize " , type = int , default = default_arguments [ ' whisper-batchsize ' ] , help = " Specifies batch size for WhisperX " )
@ -2389,6 +2729,8 @@ def get_default_settings( hypenated=True ):
' tokenizer-json ' : args . tokenizer_json ,
' phonemizer-backend ' : args . phonemizer_backend ,
' valle-model ' : args . valle_model ,
' whisper-backend ' : args . whisper_backend ,
' whisper-model ' : args . whisper_model ,
@ -2439,6 +2781,8 @@ def update_args( **kwargs ):
args . tokenizer_json = settings [ ' tokenizer_json ' ]
args . phonemizer_backend = settings [ ' phonemizer_backend ' ]
args . valle_model = settings [ ' valle_model ' ]
args . whisper_backend = settings [ ' whisper_backend ' ]
args . whisper_model = settings [ ' whisper_model ' ]
@ -2553,49 +2897,60 @@ def version_check_tts( min_version ):
return True
return False
def load_tts ( restart = False , autoregressive_model = None , diffusion_model = None , vocoder_model = None , tokenizer_json = None ) :
def load_tts ( restart = False ,
# TorToiSe configs
autoregressive_model = None , diffusion_model = None , vocoder_model = None , tokenizer_json = None ,
# VALL-E configs
valle_model = None ,
) :
global args
global tts
if restart :
unload_tts ( )
if autoregressive_model :
args . autoregressive_model = autoregressive_model
else :
autoregressive_model = args . autoregressive_model
tts_loading = True
if args . tts_backend == " tortoise " :
if autoregressive_model :
args . autoregressive_model = autoregressive_model
else :
autoregressive_model = args . autoregressive_model
if autoregressive_model == " auto " :
autoregressive_model = deduce_autoregressive_model ( )
if autoregressive_model == " auto " :
autoregressive_model = deduce_autoregressive_model ( )
if diffusion_model :
args . diffusion_model = diffusion_model
else :
diffusion_model = args . diffusion_model
if diffusion_model :
args . diffusion_model = diffusion_model
else :
diffusion_model = args . diffusion_model
if vocoder_model :
args . vocoder_model = vocoder_model
else :
vocoder_model = args . vocoder_model
if vocoder_model :
args . vocoder_model = vocoder_model
else :
vocoder_model = args . vocoder_model
if tokenizer_json :
args . tokenizer_json = tokenizer_json
else :
tokenizer_json = args . tokenizer_json
if tokenizer_json :
args . tokenizer_json = tokenizer_json
else :
tokenizer_json = args . tokenizer_json
if get_device_name ( ) == " cpu " :
print ( " !!!! WARNING !!!! No GPU available in PyTorch. You may need to reinstall PyTorch. " )
if get_device_name ( ) == " cpu " :
print ( " !!!! WARNING !!!! No GPU available in PyTorch. You may need to reinstall PyTorch. " )
tts_loading = True
print ( f " Loading TorToiSe... (AR: { autoregressive_model } , vocoder: { vocoder_model } ) " )
tts = TextToSpeech ( minor_optimizations = not args . low_vram , autoregressive_model_path = autoregressive_model , diffusion_model_path = diffusion_model , vocoder_model = vocoder_model , tokenizer_json = tokenizer_json , unsqueeze_sample_batches = args . unsqueeze_sample_batches )
tts_loading = False
print ( f " Loading TorToiSe... (AR: { autoregressive_model } , diffusion: { diffusion_model } , vocoder: { vocoder_model } ) " )
tts = TorToise_TTS ( minor_optimizations = not args . low_vram , autoregressive_model_path = autoregressive_model , diffusion_model_path = diffusion_model , vocoder_model = vocoder_model , tokenizer_json = tokenizer_json , unsqueeze_sample_batches = args . unsqueeze_sample_batches )
elif args . tts_backend == " vall-e " :
if valle_model :
args . valle_model = valle_model
else :
valle_model = args . valle_model
get_model_path ( ' dvae.pth ' )
print ( " Loaded TorToiSe, ready for generation. " )
return tts
print ( f " Loading VALL-E... (Config: { valle_model } ) " )
tts = VALLE_TTS ( config = args . valle_model )
setup_tortoise = load_tts
print ( " Loaded TTS, ready for generation. " )
tts_loading = False
return tts
def unload_tts ( ) :
global tts
@ -2643,6 +2998,9 @@ def deduce_autoregressive_model(voice=None):
return get_model_path ( ' autoregressive.pth ' )
def update_autoregressive_model ( autoregressive_model_path ) :
# Only the TorToiSe backend has an autoregressive model to swap.
if args.tts_backend != "tortoise":
    # Raising a bare str is itself a TypeError; raise a real exception.
    raise Exception(f"Unsupported backend: {args.tts_backend}")
match = re . findall ( r ' ^ \ [[a-fA-F0-9] {8} \ ] (.+?)$ ' , autoregressive_model_path )
if match :
autoregressive_model_path = match [ 0 ]
@ -2677,6 +3035,9 @@ def update_autoregressive_model(autoregressive_model_path):
return autoregressive_model_path
def update_diffusion_model ( diffusion_model_path ) :
# Only the TorToiSe backend has a diffusion model to swap.
if args.tts_backend != "tortoise":
    # Raising a bare str is itself a TypeError; raise a real exception.
    raise Exception(f"Unsupported backend: {args.tts_backend}")
match = re . findall ( r ' ^ \ [[a-fA-F0-9] {8} \ ] (.+?)$ ' , diffusion_model_path )
if match :
diffusion_model_path = match [ 0 ]
@ -2711,6 +3072,9 @@ def update_diffusion_model(diffusion_model_path):
return diffusion_model_path
def update_vocoder_model ( vocoder_model ) :
# Only the TorToiSe backend has a vocoder model to swap.
if args.tts_backend != "tortoise":
    # Raising a bare str is itself a TypeError; raise a real exception.
    raise Exception(f"Unsupported backend: {args.tts_backend}")
args . vocoder_model = vocoder_model
save_args_settings ( )
print ( f ' Stored vocoder model to settings: { vocoder_model } ' )
@ -2733,6 +3097,9 @@ def update_vocoder_model(vocoder_model):
return vocoder_model
def update_tokenizer ( tokenizer_json ) :
# Only the TorToiSe backend uses a tokenizer JSON.
if args.tts_backend != "tortoise":
    # Raising a bare str is itself a TypeError; raise a real exception.
    raise Exception(f"Unsupported backend: {args.tts_backend}")
args . tokenizer_json = tokenizer_json
save_args_settings ( )
print ( f ' Stored tokenizer to settings: { tokenizer_json } ' )