import os
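# Point framework caches at local ./models/ directories so downloads stay inside
# this repo; these are set up top, before importing the libraries that consult them.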
if 'XDG_CACHE_HOME' not in os.environ:
    os.environ['XDG_CACHE_HOME'] = os.path.realpath(os.path.join(os.getcwd(), './models/'))

if 'TORTOISE_MODELS_DIR' not in os.environ:
    os.environ['TORTOISE_MODELS_DIR'] = os.path.realpath(os.path.join(os.getcwd(), './models/tortoise/'))

if 'TRANSFORMERS_CACHE' not in os.environ:
    os.environ['TRANSFORMERS_CACHE'] = os.path.realpath(os.path.join(os.getcwd(), './models/transformers/'))
import argparse
import time
import json
import base64
import re
import urllib.request
import signal
import gc

import tqdm
import torch
import torchaudio
import music_tag
import gradio as gr
import gradio.utils
from datetime import datetime

from tortoise.api import TextToSpeech, MODELS, get_model_path
from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir
from tortoise.utils.text import split_and_recombine_text
from tortoise.utils.device import get_device_name, set_device_name

import whisper

MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"

args = None
tts = None
webui = None
voicefixer = None
whisper_model = None

def do_gc():
    gc.collect()

def get_args():
    global args
    return args

def setup_args():
    global args

    default_arguments = {
        'share': False,
        'listen': None,
        'check-for-updates': False,
        'models-from-local-only': False,
        'low-vram': False,
        'sample-batch-size': None,
        'embed-output-metadata': True,
        'latents-lean-and-mean': True,
        'voice-fixer': False, # getting tired of long initialization times in a Colab for downloading a large dataset for it
        'voice-fixer-use-cuda': True,
        'force-cpu-for-conditioning-latents': False,
        'defer-tts-load': False,
        'device-override': None,
        'whisper-model': "base",
        'autoregressive-model': None,
        'concurrency-count': 2,
        'output-sample-rate': 44100,
        'output-volume': 1,
    }

    if os.path.isfile('./config/exec.json'):
        with open('./config/exec.json', 'r', encoding="utf-8") as f:
            overrides = json.load(f)
            for k in overrides:
                default_arguments[k] = overrides[k]

    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action='store_true', default=default_arguments['share'], help="Lets Gradio return a public URL to use anywhere")
    parser.add_argument("--listen", default=default_arguments['listen'], help="Path for Gradio to listen on")
    parser.add_argument("--check-for-updates", action='store_true', default=default_arguments['check-for-updates'], help="Checks for updates on startup")
    parser.add_argument("--models-from-local-only", action='store_true', default=default_arguments['models-from-local-only'], help="Only loads models from disk, does not check for updates for models")
    parser.add_argument("--low-vram", action='store_true', default=default_arguments['low-vram'], help="Disables some optimizations that increase VRAM usage")
    parser.add_argument("--no-embed-output-metadata", action='store_false', default=not default_arguments['embed-output-metadata'], help="Disables embedding output metadata into resulting WAV files for easily fetching the settings used with the web UI (data is stored in the lyrics metadata tag)")
    parser.add_argument("--latents-lean-and-mean", action='store_true', default=default_arguments['latents-lean-and-mean'], help="Exports the bare essentials for latents.")
    parser.add_argument("--voice-fixer", action='store_true', default=default_arguments['voice-fixer'], help="Uses python module 'voicefixer' to improve audio quality, if available.")
    parser.add_argument("--voice-fixer-use-cuda", action='store_true', default=default_arguments['voice-fixer-use-cuda'], help="Hints to voicefixer to use CUDA, if available.")
    parser.add_argument("--force-cpu-for-conditioning-latents", default=default_arguments['force-cpu-for-conditioning-latents'], action='store_true', help="Forces computing conditional latents to be done on the CPU (if you constantly OOM on low chunk counts)")
    parser.add_argument("--defer-tts-load", default=default_arguments['defer-tts-load'], action='store_true', help="Defers loading TTS model")
    parser.add_argument("--device-override", default=default_arguments['device-override'], help="A device string to pass through to Torch, overriding the autodetected device")
    parser.add_argument("--whisper-model", default=default_arguments['whisper-model'], help="Specifies which whisper model to use for transcription.")
    parser.add_argument("--autoregressive-model", default=default_arguments['autoregressive-model'], help="Specifies which autoregressive model to use for sampling.")
    parser.add_argument("--sample-batch-size", default=default_arguments['sample-batch-size'], type=int, help="Sets how many batches to use during the autoregressive samples pass")
    parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once")
    parser.add_argument("--output-sample-rate", type=int, default=default_arguments['output-sample-rate'], help="Sample rate to resample the output to (from 24KHz)")
    parser.add_argument("--output-volume", type=float, default=default_arguments['output-volume'], help="Adjusts volume of output")
    parser.add_argument("--os", default="unix", help="Specifies which OS the WebUI is running under")
    args = parser.parse_args()

    args.embed_output_metadata = not args.no_embed_output_metadata

    set_device_name(args.device_override)

    args.listen_host = None
    args.listen_port = None
    args.listen_path = None
    if args.listen:
        try:
            match = re.findall(r"^(?:(.+?):(\d+))?(\/.+?)?$", args.listen)[0]

            args.listen_host = match[0] if match[0] != "" else "127.0.0.1"
            args.listen_port = match[1] if match[1] != "" else None
            args.listen_path = match[2] if match[2] != "" else "/"
        except Exception as e:
            pass

    if args.listen_port is not None:
        args.listen_port = int(args.listen_port)

    return args
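# e.g. `--listen 0.0.0.0:8008/gradio` parses out to listen_host="0.0.0.0",
# listen_port=8008, listen_path="/gradio"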

def pad(num, zeroes):
    return str(num).zfill(zeroes + 1)
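# e.g. pad(3, 4) -> "00003"; zfill(zeroes + 1) yields one more digit than `zeroes`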

def generate(
    text,
    delimiter,
    emotion,
    prompt,
    voice,
    mic_audio,
    voice_latents_chunks,
    seed,
    candidates,
    num_autoregressive_samples,
    diffusion_iterations,
    temperature,
    diffusion_sampler,
    breathing_room,
    cvvp_weight,
    top_p,
    diffusion_temperature,
    length_penalty,
    repetition_penalty,
    cond_free_k,
    experimental_checkboxes,
    progress=None
):
    global args
    global tts

    if not tts:
        raise Exception("TTS is uninitialized or still initializing...")

    do_gc()

if voice != " microphone " :
voices = [ voice ]
else :
voices = [ ]
if voice == " microphone " :
if mic_audio is None :
2023-02-17 03:05:27 +00:00
raise Exception ( " Please provide audio from mic when choosing `microphone` as a voice input " )
2023-02-17 00:08:27 +00:00
mic = load_audio ( mic_audio , tts . input_sample_rate )
voice_samples , conditioning_latents = [ mic ] , None
elif voice == " random " :
voice_samples , conditioning_latents = None , tts . get_random_conditioning_latents ( )
else :
progress ( 0 , desc = " Loading voice... " )
voice_samples , conditioning_latents = load_voice ( voice )
2023-02-18 17:23:44 +00:00
if voice_samples and len ( voice_samples ) > 0 :
2023-02-17 00:08:27 +00:00
sample_voice = torch . cat ( voice_samples , dim = - 1 ) . squeeze ( ) . cpu ( )
conditioning_latents = tts . get_conditioning_latents ( voice_samples , return_mels = not args . latents_lean_and_mean , progress = progress , slices = voice_latents_chunks , force_cpu = args . force_cpu_for_conditioning_latents )
if len ( conditioning_latents ) == 4 :
conditioning_latents = ( conditioning_latents [ 0 ] , conditioning_latents [ 1 ] , conditioning_latents [ 2 ] , None )
if voice != " microphone " :
torch . save ( conditioning_latents , f ' { get_voice_dir ( ) } / { voice } /cond_latents.pth ' )
voice_samples = None
else :
if conditioning_latents is not None :
sample_voice , _ = load_voice ( voice , load_latents = False )
2023-02-18 17:23:44 +00:00
if sample_voice and len ( sample_voice ) > 0 :
sample_voice = torch . cat ( sample_voice , dim = - 1 ) . squeeze ( ) . cpu ( )
2023-02-17 00:08:27 +00:00
else :
sample_voice = None
if seed == 0 :
seed = None
if conditioning_latents is not None and len ( conditioning_latents ) == 2 and cvvp_weight > 0 :
print ( " Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents. " )
cvvp_weight = 0
    settings = {
        'temperature': float(temperature),
        'top_p': float(top_p),
        'diffusion_temperature': float(diffusion_temperature),
        'length_penalty': float(length_penalty),
        'repetition_penalty': float(repetition_penalty),
        'cond_free_k': float(cond_free_k),

        'num_autoregressive_samples': num_autoregressive_samples,
        'sample_batch_size': args.sample_batch_size,
        'diffusion_iterations': diffusion_iterations,

        'voice_samples': voice_samples,
        'conditioning_latents': conditioning_latents,
        'use_deterministic_seed': seed,
        'return_deterministic_state': True,
        'k': candidates,
        'diffusion_sampler': diffusion_sampler,
        'breathing_room': breathing_room,
        'progress': progress,
        'half_p': "Half Precision" in experimental_checkboxes,
        'cond_free': "Conditioning-Free" in experimental_checkboxes,
        'cvvp_amount': cvvp_weight,
    }
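    # everything in `settings` is splatted into tts.tts(**settings) below, so the
    # keys must line up with the keyword arguments TextToSpeech.tts accepts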

    if delimiter is None:
        delimiter = "\n"
    elif delimiter == "\\n":
        delimiter = "\n"

    if delimiter and delimiter != "" and delimiter in text:
        texts = text.split(delimiter)
    else:
        texts = split_and_recombine_text(text)

    full_start_time = time.time()

    outdir = f"./results/{voice}/"
    os.makedirs(outdir, exist_ok=True)

    audio_cache = {}

    resampler = None
    # not a ternary, in case I ever want to rely on librosa's upsampling interpolator rather than torchaudio's
    if tts.output_sample_rate != args.output_sample_rate:
        resampler = torchaudio.transforms.Resample(
            tts.output_sample_rate,
            args.output_sample_rate,
            lowpass_filter_width=16,
            rolloff=0.85,
            resampling_method="kaiser_window",
            beta=8.555504641634386,
        )

    volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None

    idx = 0
    idx_cache = {}
    for i, file in enumerate(os.listdir(outdir)):
        filename = os.path.basename(file)
        extension = os.path.splitext(filename)[1]
        if extension != ".json" and extension != ".wav":
            continue
        match = re.findall(rf"^{voice}_(\d+)(?:.+?)?{extension}$", filename)
        if not match:
            continue
        key = int(match[0])
        idx_cache[key] = True

    if len(idx_cache) > 0:
        keys = sorted(list(idx_cache.keys()))
        idx = keys[-1] + 1

    idx = pad(idx, 4)

    def get_name(line=0, candidate=0, combined=False):
        name = f"{idx}"
        if combined:
            name = f"{name}_combined"
        elif len(texts) > 1:
            name = f"{name}_{line}"
        if candidates > 1:
            name = f"{name}_{candidate}"
        return name
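
    # Tortoise "prompt engineering": a bracketed lead-in such as "[I am really sad,]"
    # steers the delivery of a line without being spoken in the output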
    for line, cut_text in enumerate(texts):
        if emotion == "Custom":
            if prompt.strip() != "":
                cut_text = f"[{prompt},] {cut_text}"
        else:
            cut_text = f"[I am really {emotion.lower()},] {cut_text}"

        progress.msg_prefix = f'[{str(line+1)}/{str(len(texts))}]'
        print(f"{progress.msg_prefix} Generating line: {cut_text}")

        start_time = time.time()
        gen, additionals = tts.tts(cut_text, **settings)
        seed = additionals[0]
        run_time = time.time() - start_time
        print(f"Generating line took {run_time} seconds")

        if not isinstance(gen, list):
            gen = [gen]

        for j, g in enumerate(gen):
            audio = g.squeeze(0).cpu()
            name = get_name(line=line, candidate=j)
            audio_cache[name] = {
                'audio': audio,
                'text': cut_text,
                'time': run_time
            }
            # save here in case some error happens mid-batch
            torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, tts.output_sample_rate)

        del gen
        do_gc()

    for k in audio_cache:
        audio = audio_cache[k]['audio']

        if resampler is not None:
            audio = resampler(audio)
        if volume_adjust is not None:
            audio = volume_adjust(audio)

        audio_cache[k]['audio'] = audio
        torchaudio.save(f'{outdir}/{voice}_{k}.wav', audio, args.output_sample_rate)

    output_voices = []
    for candidate in range(candidates):
        if len(texts) > 1:
            audio_clips = []
            for line in range(len(texts)):
                name = get_name(line=line, candidate=candidate)
                audio = audio_cache[name]['audio']
                audio_clips.append(audio)

            name = get_name(candidate=candidate, combined=True)
            audio = torch.cat(audio_clips, dim=-1)
            torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, args.output_sample_rate)

            audio = audio.squeeze(0).cpu()
            audio_cache[name] = {
                'audio': audio,
                'text': text,
                'time': time.time() - full_start_time,
                'output': True
            }
        else:
            name = get_name(candidate=candidate)
            audio_cache[name]['output'] = True

    info = {
        'text': text,
        'delimiter': '\\n' if delimiter and delimiter == "\n" else delimiter,
        'emotion': emotion,
        'prompt': prompt,
        'voice': voice,
        'seed': seed,
        'candidates': candidates,
        'num_autoregressive_samples': num_autoregressive_samples,
        'diffusion_iterations': diffusion_iterations,
        'temperature': temperature,
        'diffusion_sampler': diffusion_sampler,
        'breathing_room': breathing_room,
        'cvvp_weight': cvvp_weight,
        'top_p': top_p,
        'diffusion_temperature': diffusion_temperature,
        'length_penalty': length_penalty,
        'repetition_penalty': repetition_penalty,
        'cond_free_k': cond_free_k,
        'experimentals': experimental_checkboxes,
        'time': time.time() - full_start_time,
    }

    # kludgy yucky codesmells
    for name in audio_cache:
        if 'output' not in audio_cache[name]:
            continue

        output_voices.append(f'{outdir}/{voice}_{name}.wav')
        with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
            f.write(json.dumps(info, indent='\t'))

    if args.voice_fixer and voicefixer is not None:
        fixed_output_voices = []
        for path in progress.tqdm(output_voices, desc="Running voicefix..."):
            fixed = path.replace(".wav", "_fixed.wav")
            voicefixer.restore(
                input=path,
                output=fixed,
                cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
                #mode=mode,
            )
            fixed_output_voices.append(fixed)
        output_voices = fixed_output_voices

    if voice and voice != "random" and conditioning_latents is not None:
        with open(f'{get_voice_dir()}/{voice}/cond_latents.pth', 'rb') as f:
            info['latents'] = base64.b64encode(f.read()).decode("ascii")

    if args.embed_output_metadata:
        for name in progress.tqdm(audio_cache, desc="Embedding metadata..."):
            info['text'] = audio_cache[name]['text']
            info['time'] = audio_cache[name]['time']

            metadata = music_tag.load_file(f"{outdir}/{voice}_{name}.wav")
            metadata['lyrics'] = json.dumps(info)
            metadata.save()

    if sample_voice is not None:
        sample_voice = (tts.input_sample_rate, sample_voice.numpy())

    print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n")

    info['seed'] = settings['use_deterministic_seed']
    if 'latents' in info:
        del info['latents']

    os.makedirs('./config/', exist_ok=True)
    with open('./config/generate.json', 'w', encoding="utf-8") as f:
        f.write(json.dumps(info, indent='\t'))

    stats = [
        [seed, "{:.3f}".format(info['time'])]
    ]

    return (
        sample_voice,
        output_voices,
        stats,
    )
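
# Outputs land in ./results/{voice}/ as {voice}_{idx}[_{line}][_{candidate}].wav
# (or {voice}_{idx}_combined.wav when lines are joined), each alongside a JSON
# sidecar of the settings used.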

import subprocess

training_process = None

def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)):
    try:
        print("Unloading TTS to save VRAM.")
        global tts
        del tts
        tts = None

        torch.cuda.empty_cache()
    except Exception as e:
        pass

    global training_process
    torch.multiprocessing.freeze_support()

    do_gc()

    cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', config_path]
    print("Spawning process: ", " ".join(cmd))
    training_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)

    # parse the config to get the iteration and checkpoint-save counts to report against
    import yaml
    with open(config_path, 'r') as file:
        config = yaml.safe_load(file)

    it = 0
    its = config['train']['niter']

    checkpoint = 0
    checkpoints = config['logger']['save_checkpoint_freq']

    open_state = False
    training_started = False

    yield " ".join(cmd)

    buffer = []
    infos = []
    yields = True

    for line in iter(training_process.stdout.readline, ""):
        buffer.append(f'{line}')

        # rip out iteration info by scraping the trainer's tqdm bars and log lines
        if not training_started:
            if line.find('Start training from epoch') >= 0:
                training_started = True
        elif progress is not None:
            if line.find('0%|') == 0:
                open_state = True
            elif line.find('100%|') == 0 and open_state:
                open_state = False
                it = it + 1
                progress(it / float(its), f'[{it}/{its}] Training...')
            elif line.find('INFO: [epoch:') >= 0:
                infos.append(f'{line}')
            elif line.find('Saving models and training states') >= 0:
                checkpoint = checkpoint + 1
                progress(checkpoint / float(checkpoints), f'[{checkpoint}/{checkpoints}] Saving checkpoint...')

        print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")

        if verbose:
            yield "".join(buffer[-buffer_size:])

    training_process.stdout.close()
    return_code = training_process.wait()
    training_process = None

    #if return_code:
    #    raise subprocess.CalledProcessError(return_code, cmd)

    return "".join(buffer[-buffer_size:])

def stop_training():
    global training_process
    if training_process is None:
        return "No training in progress"
    training_process.kill()
    training_process = None
    return "Training cancelled"

def setup_voicefixer(restart=False):
    global voicefixer
    if restart:
        del voicefixer
        voicefixer = None

    try:
        print("Initializing voice-fixer")
        from voicefixer import VoiceFixer
        voicefixer = VoiceFixer()
        print("Initialized voice-fixer")
    except Exception as e:
        print(f"Error occurred while trying to initialize voicefixer: {e}")

def setup_tortoise(restart=False):
    global args
    global tts

    do_gc()

    if args.voice_fixer:
        setup_voicefixer(restart=restart)

    if restart:
        del tts
        tts = None

    print(f"Initializing TorToiSe... (using model: {args.autoregressive_model})")
    try:
        tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=args.autoregressive_model)
    except Exception as e:
        # older tortoise packages do not accept the model path in the constructor
        tts = TextToSpeech(minor_optimizations=not args.low_vram)
        update_autoregressive_model(args.autoregressive_model)

    get_model_path('dvae.pth')
    print("TorToiSe initialized, ready for generation.")
    return tts

def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
    global tts
    global args

    if not tts:
        raise Exception("TTS is uninitialized or still initializing...")

    do_gc()

    voice_samples, conditioning_latents = load_voice(voice, load_latents=False)

    if voice_samples is None:
        return

    conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, progress=progress, slices=voice_latents_chunks, force_cpu=args.force_cpu_for_conditioning_latents)

    if len(conditioning_latents) == 4:
        conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)

    torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')

    return voice

def save_training_settings(iterations=None, batch_size=None, learning_rate=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None):
    settings = {
        "iterations": iterations if iterations else 500,
        "batch_size": batch_size if batch_size else 64,
        "learning_rate": learning_rate if learning_rate else 1e-5,
        "print_rate": print_rate if print_rate else 50,
        "save_rate": save_rate if save_rate else 50,
        "name": name if name else "finetune",
        "dataset_name": dataset_name if dataset_name else "finetune",
        "dataset_path": dataset_path if dataset_path else "./training/finetune/train.txt",
        "validation_name": validation_name if validation_name else "finetune",
        "validation_path": validation_path if validation_path else "./training/finetune/train.txt",
    }

    if not output_name:
        output_name = f'{settings["name"]}.yaml'
    outfile = f'./training/{output_name}'

    with open('./models/.template.yaml', 'r', encoding="utf-8") as f:
        yaml = f.read()

    for k in settings:
        yaml = yaml.replace(f"${{{k}}}", str(settings[k]))

    with open(outfile, 'w', encoding="utf-8") as f:
        f.write(yaml)

    return f"Training settings saved to: {outfile}"
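
# e.g. save_training_settings(name="myvoice", dataset_path="./training/myvoice/train.txt")
# substitutes the ${...} placeholders in ./models/.template.yaml and writes the
# result to ./training/myvoice.yaml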

def prepare_dataset(files, outdir, language=None, progress=None):
    global whisper_model
    if whisper_model is None:
        notify_progress(f"Loading Whisper model: {args.whisper_model}", progress)
        whisper_model = whisper.load_model(args.whisper_model)

    os.makedirs(outdir, exist_ok=True)

    idx = 0
    results = {}
    transcription = []

    for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
        print(f"Transcribing file: {file}")

        result = whisper_model.transcribe(file, language=language if language else "English")
        results[os.path.basename(file)] = result
        print(f"Transcribed file: {file}, {len(result['segments'])} segments found.")

        waveform, sampling_rate = torchaudio.load(file)
        num_channels, num_frames = waveform.shape

        for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
            start = int(segment['start'] * sampling_rate)
            end = int(segment['end'] * sampling_rate)

            sliced_waveform = waveform[:, start:end]
            sliced_name = f"{pad(idx, 4)}.wav"

            torchaudio.save(f"{outdir}/{sliced_name}", sliced_waveform, sampling_rate)

            transcription.append(f"{sliced_name}|{segment['text'].strip()}")
            idx = idx + 1

    with open(f'{outdir}/whisper.json', 'w', encoding="utf-8") as f:
        f.write(json.dumps(results, indent='\t'))

    with open(f'{outdir}/train.txt', 'w', encoding="utf-8") as f:
        f.write("\n".join(transcription))

    return f"Processed dataset to: {outdir}"
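
# train.txt pairs each sliced clip with its transcription, LJSpeech-style, e.g.:
#   00000.wav|Transcribed text of the first segment.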

def reset_generation_settings():
    with open('./config/generate.json', 'w', encoding="utf-8") as f:
        f.write(json.dumps({}, indent='\t'))
    return import_generate_settings()

def import_voices(files, saveAs=None, progress=None):
    global args

    if not isinstance(files, list):
        files = [files]

    for file in enumerate_progress(files, desc="Importing voice files", progress=progress):
        j, latents = read_generate_settings(file, read_latents=True)

        if j is not None and saveAs is None:
            saveAs = j['voice']
        if saveAs is None or saveAs == "":
            raise Exception("Specify a voice name")

        outdir = f'{get_voice_dir()}/{saveAs}/'
        os.makedirs(outdir, exist_ok=True)

        if latents:
            print(f"Importing latents to {outdir}/cond_latents.pth")
            with open(f'{outdir}/cond_latents.pth', 'wb') as f:
                f.write(latents)
            latents = f'{outdir}/cond_latents.pth'
            print(f"Imported latents to {latents}")
        else:
            filename = file.name
            if filename[-4:] != ".wav":
                raise Exception("Please convert to a WAV first")

            path = f"{outdir}/{os.path.basename(filename)}"
            print(f"Importing voice to {path}")

            waveform, sampling_rate = torchaudio.load(filename)

            if args.voice_fixer and voicefixer is not None:
                # resample to best bandwidth since voicefixer will do it anyways through librosa
                if sampling_rate != 44100:
                    print(f"Resampling imported voice sample: {path}")
                    resampler = torchaudio.transforms.Resample(
                        sampling_rate,
                        44100,
                        lowpass_filter_width=16,
                        rolloff=0.85,
                        resampling_method="kaiser_window",
                        beta=8.555504641634386,
                    )
                    waveform = resampler(waveform)
                    sampling_rate = 44100

                torchaudio.save(path, waveform, sampling_rate)

                print(f"Running 'voicefixer' on voice sample: {path}")
                voicefixer.restore(
                    input=path,
                    output=path,
                    cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
                    #mode=mode,
                )
            else:
                torchaudio.save(path, waveform, sampling_rate)

            print(f"Imported voice to {path}")

def import_generate_settings(file="./config/generate.json"):
    settings, _ = read_generate_settings(file, read_latents=False)

    if settings is None:
        return None

    return (
        None if 'text' not in settings else settings['text'],
        None if 'delimiter' not in settings else settings['delimiter'],
        None if 'emotion' not in settings else settings['emotion'],
        None if 'prompt' not in settings else settings['prompt'],
        None if 'voice' not in settings else settings['voice'],
        None,
        None,
        None if 'seed' not in settings else settings['seed'],
        None if 'candidates' not in settings else settings['candidates'],
        None if 'num_autoregressive_samples' not in settings else settings['num_autoregressive_samples'],
        None if 'diffusion_iterations' not in settings else settings['diffusion_iterations'],
        0.8 if 'temperature' not in settings else settings['temperature'],
        "DDIM" if 'diffusion_sampler' not in settings else settings['diffusion_sampler'],
        8 if 'breathing_room' not in settings else settings['breathing_room'],
        0.0 if 'cvvp_weight' not in settings else settings['cvvp_weight'],
        0.8 if 'top_p' not in settings else settings['top_p'],
        1.0 if 'diffusion_temperature' not in settings else settings['diffusion_temperature'],
        1.0 if 'length_penalty' not in settings else settings['length_penalty'],
        2.0 if 'repetition_penalty' not in settings else settings['repetition_penalty'],
        2.0 if 'cond_free_k' not in settings else settings['cond_free_k'],
        None if 'experimentals' not in settings else settings['experimentals'],
    )
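
# The tuple above mirrors generate()'s parameter order (the two bare Nones stand in
# for mic_audio and voice_latents_chunks), so it can be unpacked straight onto the
# UI fields in order.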

def curl(url):
    try:
        req = urllib.request.Request(url, headers={'User-Agent': 'Python'})
        conn = urllib.request.urlopen(req)
        data = conn.read()
        data = data.decode()
        data = json.loads(data)
        conn.close()
        return data
    except Exception as e:
        print(e)
        return None

def check_for_updates():
    if not os.path.isfile('./.git/FETCH_HEAD'):
        print("Cannot check for updates: not from a git repo")
        return False

    with open('./.git/FETCH_HEAD', 'r', encoding="utf-8") as f:
        head = f.read()

    match = re.findall(r"^([a-f0-9]+).+?https:\/\/(.+?)\/(.+?)\/(.+?)\n", head)
    if match is None or len(match) == 0:
        print("Cannot check for updates: cannot parse FETCH_HEAD")
        return False

    match = match[0]
    local = match[0]
    host = match[1]
    owner = match[2]
    repo = match[3]

    res = curl(f"https://{host}/api/v1/repos/{owner}/{repo}/branches/") # this only works for gitea instances

    if res is None or len(res) == 0:
        print("Cannot check for updates: cannot fetch from remote")
        return False

    remote = res[0]["commit"]["id"]

    if remote != local:
        print(f"New version found: {local[:8]} => {remote[:8]}")
        return True

    return False

def reload_tts():
    setup_tortoise(restart=True)

def cancel_generate():
    # `from tortoise.api import STOP_SIGNAL` would only rebind a local name here;
    # set the flag through the module itself so the generation loop actually sees it
    import tortoise.api
    tortoise.api.STOP_SIGNAL = True

def get_voice_list(dir=get_voice_dir()):
    os.makedirs(dir, exist_ok=True)
    return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0]) + ["microphone", "random"]
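# "microphone" and "random" are pseudo-voices handled specially by generate(), so
# they are appended after the voice folders found on disk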

def get_autoregressive_models(dir="./models/finetunes/"):
    os.makedirs(dir, exist_ok=True)
    return [get_model_path('autoregressive.pth')] + sorted([f'{dir}/{d}' for d in os.listdir(dir) if d[-4:] == ".pth"])

def get_dataset_list(dir="./training/"):
    return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.txt" in os.listdir(os.path.join(dir, d))])

def get_training_list(dir="./training/"):
    return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.yaml" in os.listdir(os.path.join(dir, d))])

def update_whisper_model(name):
    global whisper_model
    if whisper_model:
        del whisper_model
        whisper_model = None

    args.whisper_model = name
    print(f"Loading Whisper model: {args.whisper_model}")
    whisper_model = whisper.load_model(args.whisper_model)

def update_autoregressive_model(path_name):
    global tts
    if not tts:
        raise Exception("TTS is uninitialized or still initializing...")

    print(f"Loading model: {path_name}")

    if hasattr(tts, 'load_autoregressive_model'):
        tts.load_autoregressive_model(path_name)
    # polyfill in case a user did NOT update the packages
    else:
        from tortoise.models.autoregressive import UnifiedVoice

        tts.autoregressive_model_path = path_name if path_name and os.path.exists(path_name) else get_model_path('autoregressive.pth', tts.models_dir)

        del tts.autoregressive
        tts.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
                                          model_dim=1024,
                                          heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
                                          train_solo_embeddings=False).cpu().eval()
        tts.autoregressive.load_state_dict(torch.load(tts.autoregressive_model_path))
        tts.autoregressive.post_init_gpt2_config(kv_cache=tts.use_kv_cache)
        if tts.preloaded_tensors:
            tts.autoregressive = tts.autoregressive.to(tts.device)

    print(f"Loaded model: {tts.autoregressive_model_path}")

    args.autoregressive_model = path_name
    save_args_settings()

    return path_name

def update_args(listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume):
    global args

    args.listen = listen
    args.share = share
    args.check_for_updates = check_for_updates
    args.models_from_local_only = models_from_local_only
    args.low_vram = low_vram
    args.force_cpu_for_conditioning_latents = force_cpu_for_conditioning_latents
    args.defer_tts_load = defer_tts_load
    args.device_override = device_override
    args.sample_batch_size = sample_batch_size
    args.embed_output_metadata = embed_output_metadata
    args.latents_lean_and_mean = latents_lean_and_mean
    args.voice_fixer = voice_fixer
    args.voice_fixer_use_cuda = voice_fixer_use_cuda
    args.concurrency_count = concurrency_count
    args.output_sample_rate = output_sample_rate
    args.output_volume = output_volume

    save_args_settings()

def save_args_settings():
    settings = {
        'listen': None if not args.listen else args.listen,
        'share': args.share,
        'low-vram': args.low_vram,
        'check-for-updates': args.check_for_updates,
        'models-from-local-only': args.models_from_local_only,
        'force-cpu-for-conditioning-latents': args.force_cpu_for_conditioning_latents,
        'defer-tts-load': args.defer_tts_load,
        'device-override': args.device_override,
        'whisper-model': args.whisper_model,
        'autoregressive-model': args.autoregressive_model,
        'sample-batch-size': args.sample_batch_size,
        'embed-output-metadata': args.embed_output_metadata,
        'latents-lean-and-mean': args.latents_lean_and_mean,
        'voice-fixer': args.voice_fixer,
        'voice-fixer-use-cuda': args.voice_fixer_use_cuda,
        'concurrency-count': args.concurrency_count,
        'output-sample-rate': args.output_sample_rate,
        'output-volume': args.output_volume,
    }

    os.makedirs('./config/', exist_ok=True)
    with open('./config/exec.json', 'w', encoding="utf-8") as f:
        f.write(json.dumps(settings, indent='\t'))

def read_generate_settings(file, read_latents=True, read_json=True):
    j = None
    latents = None

    if file is not None:
        if hasattr(file, 'name'):
            file = file.name

        if file[-4:] == ".wav":
            metadata = music_tag.load_file(file)
            if 'lyrics' in metadata:
                j = json.loads(str(metadata['lyrics']))
        elif file[-5:] == ".json":
            with open(file, 'r') as f:
                j = json.load(f)

    if j is None:
        print("No metadata found in audio file to read")
    else:
        if 'latents' in j:
            if read_latents:
                latents = base64.b64decode(j['latents'])
            del j['latents']

        if "time" in j:
            j["time"] = "{:.3f}".format(j["time"])

    return (
        j,
        latents,
    )

def enumerate_progress(iterable, desc=None, progress=None, verbose=None):
    if verbose and desc is not None:
        print(desc)

    if progress is None:
        # `import tqdm` pulls in the module, so the class lives at tqdm.tqdm
        return tqdm.tqdm(iterable, disable=not verbose)

    return progress.tqdm(iterable, desc=f'{progress.msg_prefix} {desc}' if hasattr(progress, 'msg_prefix') else desc, track_tqdm=True)

def notify_progress(message, progress=None, verbose=True):
    if verbose:
        print(message)

    if progress is None:
        return

    progress(0, desc=message)