2023-02-17 00:08:27 +00:00
import os
import argparse
import time
import json
import base64
import re
import urllib . request
import torch
import torchaudio
import music_tag
import gradio as gr
import gradio . utils
from datetime import datetime
import tortoise . api
2023-02-17 05:42:55 +00:00
from tortoise . utils . audio import get_voice_dir , get_voices
2023-03-04 20:42:54 +00:00
from tortoise . utils . device import get_device_count
2023-02-17 00:08:27 +00:00
from utils import *
# Parse command-line options once, at import time. `setup_args` is not defined
# in this file (presumably provided by `from utils import *`); `setup_gradio`
# reads the resulting namespace through `global args`.
args = setup_args()
2023-02-17 03:05:27 +00:00
def run_generation(
    text,
    delimiter,
    emotion,
    prompt,
    voice,
    mic_audio,
    voice_latents_chunks,
    seed,
    candidates,
    num_autoregressive_samples,
    diffusion_iterations,
    temperature,
    diffusion_sampler,
    breathing_room,
    cvvp_weight,
    top_p,
    diffusion_temperature,
    length_penalty,
    repetition_penalty,
    cond_free_k,
    experimental_checkboxes,
    progress=gr.Progress(track_tqdm=True)
):
    """Gradio callback for the Generate button: validate, synthesize, and build UI updates.

    Every parameter is forwarded unchanged, by keyword, to ``generate``.

    Returns a 4-tuple:
        - ``outputs[0]`` from ``generate``;
        - an update for the source-sample player, visible only when a sample
          was returned;
        - an update for the candidates dropdown listing every output, visible
          only when there is more than one;
        - an update that shows the stats table.

    Raises:
        gr.Error: when ``text`` or ``voice`` is empty, or when ``generate``
            raises. Its message is re-raised as a ``gr.Error``, and a
            "Kill signal detected" failure first calls ``unload_tts()``.
    """
    if not text:
        raise gr.Error("Please provide text.")
    if not voice:
        raise gr.Error("Please provide a voice.")

    generation_kwargs = dict(
        text=text,
        delimiter=delimiter,
        emotion=emotion,
        prompt=prompt,
        voice=voice,
        mic_audio=mic_audio,
        voice_latents_chunks=voice_latents_chunks,
        seed=seed,
        candidates=candidates,
        num_autoregressive_samples=num_autoregressive_samples,
        diffusion_iterations=diffusion_iterations,
        temperature=temperature,
        diffusion_sampler=diffusion_sampler,
        breathing_room=breathing_room,
        cvvp_weight=cvvp_weight,
        top_p=top_p,
        diffusion_temperature=diffusion_temperature,
        length_penalty=length_penalty,
        repetition_penalty=repetition_penalty,
        cond_free_k=cond_free_k,
        experimental_checkboxes=experimental_checkboxes,
        progress=progress,
    )

    try:
        sample, outputs, stats = generate(**generation_kwargs)
    except Exception as err:
        reason = str(err)
        # A user-requested stop: free the TTS model before reporting it.
        if reason == "Kill signal detected":
            unload_tts()
        raise gr.Error(reason)

    first_output = outputs[0]
    return (
        first_output,
        gr.update(value=sample, visible=sample is not None),
        gr.update(choices=outputs, value=first_output, visible=len(outputs) > 1, interactive=True),
        gr.update(value=stats, visible=True),
    )
2023-02-17 00:08:27 +00:00
def update_presets(value):
    """Map a quality-preset name onto the Samples and Iterations sliders.

    Returns a pair of Gradio updates ``(samples_update, iterations_update)``.
    Unknown preset names (including ``None``) produce two no-op updates.
    """
    presets = {
        'Ultra Fast': {'num_autoregressive_samples': 16, 'diffusion_iterations': 30, 'cond_free': False},
        'Fast': {'num_autoregressive_samples': 96, 'diffusion_iterations': 80},
        'Standard': {'num_autoregressive_samples': 256, 'diffusion_iterations': 200},
        'High Quality': {'num_autoregressive_samples': 256, 'diffusion_iterations': 400},
    }
    chosen = presets.get(value)
    if chosen is None:
        return (gr.update(), gr.update())

    samples_update = gr.update(value=chosen['num_autoregressive_samples'])
    iterations_update = gr.update(value=chosen['diffusion_iterations'])
    return (samples_update, iterations_update)
2023-02-17 19:06:05 +00:00
def get_training_configs():
    """Return the paths of the ``.yaml`` training configs in ``./training/``.

    Results are sorted by filename and formatted as ``./training/<name>.yaml``.
    Hidden files (names starting with ``.``) are skipped. Returns an empty
    list when ``./training/`` does not exist yet, instead of raising
    ``FileNotFoundError``.
    """
    training_dir = "./training/"
    if not os.path.isdir(training_dir):
        return []
    return [
        f"{training_dir}{name}"
        for name in sorted(os.listdir(training_dir))
        if name.endswith(".yaml") and not name.startswith(".")
    ]
def update_training_configs():
    """Refresh a dropdown's choices with the current training configurations.

    Note: this uses ``get_training_list()`` (defined outside this file), not
    ``get_training_configs()`` above.
    """
    available = get_training_list()
    return gr.update(choices=available)
2023-02-17 19:06:05 +00:00
2023-02-18 02:07:22 +00:00
# Column header shown in the History tab -> metadata key that column reads.
# "Name" maps to "" because that column shows the result filename instead of
# a metadata value. Insertion order defines the column order.
history_headers = dict([
    ("Name", ""),
    ("Samples", "num_autoregressive_samples"),
    ("Iterations", "diffusion_iterations"),
    ("Temp.", "temperature"),
    ("Sampler", "diffusion_sampler"),
    ("CVVP", "cvvp_weight"),
    ("Top P", "top_p"),
    ("Diff. Temp.", "diffusion_temperature"),
    ("Len Pen", "length_penalty"),
    ("Rep Pen", "repetition_penalty"),
    ("Cond-Free K", "cond_free_k"),
    ("Time", "time"),
    ("Datetime", "datetime"),
    ("Model", "model"),
    ("Model Hash", "model_hash"),
])
2023-02-17 19:06:05 +00:00
def history_view_results(voice):
    """Build the History tab table and file list for ``./results/{voice}/``.

    Each ``.wav`` file whose embedded settings can be read becomes one row.
    Columns follow ``history_headers`` order: "Name" holds the filename, and
    every other column holds the matching metadata value, or ``'?'`` when
    the key is absent. Files whose metadata reads back as ``None`` are skipped.

    Returns ``(rows, dropdown_update)``, where the dropdown update lists the
    included filenames. A voice with no results directory yields no rows
    instead of raising ``FileNotFoundError``.
    """
    results = []
    files = []
    outdir = f"./results/{voice}/"
    if not os.path.isdir(outdir):
        return (results, gr.Dropdown.update(choices=files))

    for file in sorted(os.listdir(outdir)):
        if not file.endswith(".wav"):
            continue
        metadata, _ = read_generate_settings(f"{outdir}/{file}", read_latents=False)
        if metadata is None:
            continue
        values = [
            file if header == "Name" else metadata.get(key, '?')
            for header, key in history_headers.items()
        ]
        files.append(file)
        results.append(values)

    return (
        results,
        gr.Dropdown.update(choices=sorted(files))
    )
2023-03-07 03:55:35 +00:00
def compute_latents_proxy(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
    """(Re)compute conditioning latents for ``voice`` and return the voice name.

    The name is returned unchanged so it can be written back to the output
    component the callback is bound to.
    """
    options = {
        "voice": voice,
        "voice_latents_chunks": voice_latents_chunks,
        "progress": progress,
    }
    compute_latents(**options)
    return voice
2023-02-18 02:07:22 +00:00
def import_voices_proxy(files, name, progress=gr.Progress(track_tqdm=True)):
    """Import the uploaded ``files`` as voice ``name``.

    Returns a no-op Gradio update so the bound output is left unchanged.
    """
    import_voices(files, name, progress)
    no_change = gr.update()
    return no_change
2023-02-17 19:06:05 +00:00
def read_generate_settings_proxy(file, saveAs='.temp'):
    """Read settings and latents embedded in an uploaded audio file for the Utilities tab.

    If latents are present, they are written to
    ``{get_voice_dir()}/{saveAs}/cond_latents.pth`` and that path is shown.

    Returns a 4-tuple:
        - metadata JSON update, visible only when metadata was found;
        - latents file update;
        - the voice name from the metadata (``None`` when there is no
          metadata or it has no ``voice`` entry);
        - visibility update for the metadata column.
    """
    j, latents = read_generate_settings(file)

    if latents:
        outdir = f'{get_voice_dir()}/{saveAs}/'
        os.makedirs(outdir, exist_ok=True)
        latents_path = f'{outdir}/cond_latents.pth'
        with open(latents_path, 'wb') as f:
            f.write(latents)
        latents = latents_path

    return (
        gr.update(value=j, visible=j is not None),
        gr.update(value=latents, visible=latents is not None),
        # .get: metadata without a 'voice' key previously raised KeyError.
        None if j is None else j.get('voice'),
        gr.update(visible=j is not None),
    )
2023-03-06 10:47:06 +00:00
def prepare_dataset_proxy(voice, language, skip_existings, progress=gr.Progress(track_tqdm=True)):
    """Prepare a training dataset from ``voice``'s audio into ``./training/{voice}/``.

    Returns whatever ``prepare_dataset`` returns.
    """
    voice_files = get_voices(load_latents=False)[voice]
    destination = f"./training/{voice}/"
    return prepare_dataset(
        voice_files,
        outdir=destination,
        language=language,
        skip_existings=skip_existings,
        progress=progress,
    )
2023-02-17 20:43:12 +00:00
2023-02-19 20:22:03 +00:00
def optimize_training_settings_proxy(*args, **kwargs):
    """Wrap ``optimize_training_settings`` for Gradio.

    All arguments are passed through unchanged. The first eight returned
    values become ``gr.update(value=...)`` updates. The ninth (an iterable of
    message strings) is joined with newlines into a single string.
    """
    result = optimize_training_settings(*args, **kwargs)
    field_updates = tuple(gr.update(value=result[index]) for index in range(8))
    return field_updates + ("\n".join(result[8]),)
2023-02-26 01:57:56 +00:00
def import_training_settings_proxy(voice):
    """Load an existing training YAML for ``voice`` and convert it back into UI values.

    Config selection: the lexicographically last ``*.yaml`` in
    ``./training/{voice}-finetune/`` is used if one exists; otherwise
    ``./training/{voice}/train.yaml``.

    Iteration-based values in the YAML (``niter``, ``gen_lr_steps``,
    ``print_freq``, ``save_checkpoint_freq``) are converted to epochs. The
    conversion uses the line count of ``./training/{voice}/train.txt``
    divided by the configured batch size.

    Returns a 14-tuple:
        (epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule,
        batch_size, gradient_accumulation_size, print_rate, save_rate,
        resume_path, half_p, bnb, workers, source_model, messages)
    where ``messages`` is a newline-joined log string.

    Raises:
        FileNotFoundError: if the chosen config or the dataset file is missing.
        KeyError: if a required key is absent from the config.
    """
    indir = f'./training/{voice}/'
    outdir = f'./training/{voice}-finetune/'

    in_config_path = f"{indir}/train.yaml"
    out_config_path = None
    out_configs = []
    if os.path.isdir(outdir):
        out_configs = sorted([d[:-5] for d in os.listdir(outdir) if d[-5:] == ".yaml"])
    if len(out_configs) > 0:
        out_config_path = f'{outdir}/{out_configs[-1]}.yaml'

    config_path = out_config_path if out_config_path else in_config_path

    messages = []
    with open(config_path, 'r') as file:
        config = yaml.safe_load(file)
    messages.append(f"Importing from: {config_path}")

    dataset_path = f"./training/{voice}/train.txt"
    with open(dataset_path, 'r', encoding="utf-8") as f:
        lines = len(f.readlines())
    messages.append(f"Basing epoch size to {lines} lines")

    batch_size = config['datasets']['train']['batch_size']
    gradient_accumulation_size = config['train']['mega_batch_factor']

    iterations = config['train']['niter']
    # Batches per pass over the dataset. Clamped to 1 so a dataset smaller than
    # one batch doesn't make every division below a ZeroDivisionError.
    steps_per_epoch = max(1, int(lines / batch_size))
    epochs = int(iterations / steps_per_epoch)
    learning_rate = config['steps']['gpt_train']['optimizer_params']['lr']
    text_ce_lr_weight = config['steps']['gpt_train']['losses']['text_ce']['weight']
    learning_rate_schedule = [int(x / steps_per_epoch) for x in config['train']['gen_lr_steps']]

    print_rate = int(config['logger']['print_freq'] / steps_per_epoch)
    save_rate = int(config['logger']['save_checkpoint_freq'] / steps_per_epoch)

    statedir = f'{outdir}/training_state/'
    resumes = []
    resume_path = None
    source_model = None

    if "pretrain_model_gpt" in config['path']:
        source_model = config['path']['pretrain_model_gpt']
    elif "resume_state" in config['path']:
        resume_path = config['path']['resume_state']

    if os.path.isdir(statedir):
        # State files are named "<step>.state". Skip non-numeric names so a
        # stray file can't raise ValueError from int().
        resumes = sorted([
            int(d[:-6]) for d in os.listdir(statedir)
            if d[-6:] == ".state" and d[:-6].isdecimal()
        ])
    if len(resumes) > 0:
        # The newest training state overrides any resume path from the config.
        resume_path = f'{statedir}/{resumes[-1]}.state'
        messages.append(f"Latest resume found: {resume_path}")

    half_p = config['fp16']
    bnb = True
    if "ext" in config and "bitsandbytes" in config["ext"]:
        bnb = config["ext"]["bitsandbytes"]

    workers = config['datasets']['train']['n_workers']

    messages = "\n".join(messages)

    return (
        epochs,
        learning_rate,
        text_ce_lr_weight,
        learning_rate_schedule,
        batch_size,
        gradient_accumulation_size,
        print_rate,
        save_rate,
        resume_path,
        half_p,
        bnb,
        workers,
        source_model,
        messages
    )
2023-03-05 05:17:19 +00:00
def save_training_settings_proxy(epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, resume_path, half_p, bnb, workers, source_model, voice):
    """Convert epoch-based UI settings to iterations and write ``{voice}/train.yaml``.

    The dataset is ``./training/{voice}/train.txt``. If
    ``./training/{voice}/validation.txt`` is missing, the training set is
    reused for validation and validation runs once, at the final iteration.

    ``print_rate`` and ``save_rate`` arrive in epochs and are converted to
    iterations (minimum 1). ``iterations`` is then trimmed down to a multiple
    of ``save_rate``. When fewer iterations than one save period remain, the
    save rate is lowered to the iteration count instead.

    ``learning_rate_schedule`` may be empty (``EPOCH_SCHEDULE`` is used), a
    JSON string, or a list. It is passed through ``schedule_learning_rate``
    with the iterations-per-epoch ratio.

    Returns the newline-joined log messages, including the message returned
    by ``save_training_settings``.
    """
    name = f"{voice}-finetune"
    dataset_name = f"{voice}-train"
    dataset_path = f"./training/{voice}/train.txt"
    validation_name = f"{voice}-val"
    validation_path = f"./training/{voice}/validation.txt"

    with open(dataset_path, 'r', encoding="utf-8") as f:
        lines = len(f.readlines())

    messages = []
    iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size)
    messages.append(f"For {epochs} epochs with {lines} lines, iterating for {iterations} steps")

    # Epochs -> iterations. Clamp to >= 1: with few iterations per epoch these
    # can round down to 0, which made `iterations % save_rate` divide by zero.
    print_rate = max(1, int(print_rate * iterations / epochs))
    save_rate = max(1, int(save_rate * iterations / epochs))
    validation_rate = save_rate

    if iterations % save_rate != 0:
        adjustment = int(iterations / save_rate) * save_rate
        if adjustment > 0:
            messages.append(f"Iteration rate is not evenly divisible by save rate, adjusting: {iterations} => {adjustment}")
            iterations = adjustment
        else:
            # Fewer iterations than one save period. Trimming would leave 0
            # iterations, so save (and validate) once at the end instead.
            messages.append(f"Save rate exceeds iteration count, adjusting save rate: {save_rate} => {iterations}")
            save_rate = iterations
            validation_rate = iterations

    if not os.path.exists(validation_path):
        validation_rate = iterations
        validation_path = dataset_path

    if not learning_rate_schedule:
        learning_rate_schedule = EPOCH_SCHEDULE
    elif isinstance(learning_rate_schedule, str):
        learning_rate_schedule = json.loads(learning_rate_schedule)
    learning_rate_schedule = schedule_learning_rate(iterations / epochs, learning_rate_schedule)

    messages.append(save_training_settings(
        iterations=iterations,
        batch_size=batch_size,
        learning_rate=learning_rate,
        text_ce_lr_weight=text_ce_lr_weight,
        learning_rate_schedule=learning_rate_schedule,
        gradient_accumulation_size=gradient_accumulation_size,
        print_rate=print_rate,
        save_rate=save_rate,
        validation_rate=validation_rate,
        name=name,
        dataset_name=dataset_name,
        dataset_path=dataset_path,
        validation_name=validation_name,
        validation_path=validation_path,
        output_name=f"{voice}/train.yaml",
        resume_path=resume_path,
        half_p=half_p,
        bnb=bnb,
        workers=workers,
        source_model=source_model,
    ))
    return "\n".join(messages)
2023-02-18 14:51:00 +00:00
2023-02-17 19:06:05 +00:00
def update_voices():
    """Re-scan voices and return refreshed choices for three dropdowns.

    Returns, in order:
        - voices including the appended defaults;
        - voices without the defaults;
        - voices that have entries under ``./results/``.
    """
    with_defaults = get_voice_list(append_defaults=True)
    plain_voices = get_voice_list()
    result_voices = get_voice_list("./results/")
    return (
        gr.Dropdown.update(choices=with_defaults),
        gr.Dropdown.update(choices=plain_voices),
        gr.Dropdown.update(choices=result_voices),
    )
def history_copy_settings(voice, file):
    """Load the generation settings embedded in ``./results/{voice}/{file}``."""
    result_path = f"./results/{voice}/{file}"
    return import_generate_settings(result_path)
2023-02-17 00:08:27 +00:00
def setup_gradio ( ) :
global args
global ui
if not args . share :
def noop ( function , return_value = None ) :
def wrapped ( * args , * * kwargs ) :
return return_value
return wrapped
gradio . utils . version_check = noop ( gradio . utils . version_check )
gradio . utils . initiated_analytics = noop ( gradio . utils . initiated_analytics )
gradio . utils . launch_analytics = noop ( gradio . utils . launch_analytics )
gradio . utils . integration_analytics = noop ( gradio . utils . integration_analytics )
gradio . utils . error_analytics = noop ( gradio . utils . error_analytics )
gradio . utils . log_feature_analytics = noop ( gradio . utils . log_feature_analytics )
#gradio.utils.get_local_ip_address = noop(gradio.utils.get_local_ip_address, 'localhost')
if args . models_from_local_only :
os . environ [ ' TRANSFORMERS_OFFLINE ' ] = ' 1 '
2023-03-01 01:17:38 +00:00
voice_list_with_defaults = get_voice_list ( append_defaults = True )
voice_list = get_voice_list ( )
result_voices = get_voice_list ( " ./results/ " )
autoregressive_models = get_autoregressive_models ( )
dataset_list = get_dataset_list ( )
2023-02-17 00:08:27 +00:00
with gr . Blocks ( ) as ui :
with gr . Tab ( " Generate " ) :
with gr . Row ( ) :
with gr . Column ( ) :
2023-03-05 23:55:27 +00:00
text = gr . Textbox ( lines = 4 , label = " Input Prompt " )
2023-02-17 00:08:27 +00:00
with gr . Row ( ) :
with gr . Column ( ) :
delimiter = gr . Textbox ( lines = 1 , label = " Line Delimiter " , placeholder = " \\ n " )
2023-03-05 23:55:27 +00:00
emotion = gr . Radio ( [ " Happy " , " Sad " , " Angry " , " Disgusted " , " Arrogant " , " Custom " , " None " ] , value = " None " , label = " Emotion " , type = " value " , interactive = True )
prompt = gr . Textbox ( lines = 1 , label = " Custom Emotion " )
2023-03-01 01:17:38 +00:00
voice = gr . Dropdown ( choices = voice_list_with_defaults , label = " Voice " , type = " value " , value = voice_list_with_defaults [ 0 ] ) # it'd be very cash money if gradio was able to default to the first value in the list without this shit
2023-03-05 23:55:27 +00:00
mic_audio = gr . Audio ( label = " Microphone Source " , source = " microphone " , type = " filepath " , visible = False )
2023-03-07 03:55:35 +00:00
voice_latents_chunks = gr . Number ( label = " Voice Chunks " , precision = 0 , value = 0 )
2023-02-24 12:58:41 +00:00
with gr . Row ( ) :
refresh_voices = gr . Button ( value = " Refresh Voice List " )
recompute_voice_latents = gr . Button ( value = " (Re)Compute Voice Latents " )
2023-02-22 03:31:46 +00:00
voice . change (
fn = update_baseline_for_latents_chunks ,
inputs = voice ,
outputs = voice_latents_chunks
)
2023-03-05 23:55:27 +00:00
voice . change (
fn = lambda value : gr . update ( visible = value == " microphone " ) ,
inputs = voice ,
outputs = mic_audio ,
)
2023-02-17 00:08:27 +00:00
with gr . Column ( ) :
candidates = gr . Slider ( value = 1 , minimum = 1 , maximum = 6 , step = 1 , label = " Candidates " )
seed = gr . Number ( value = 0 , precision = 0 , label = " Seed " )
2023-02-18 02:07:22 +00:00
preset = gr . Radio ( [ " Ultra Fast " , " Fast " , " Standard " , " High Quality " ] , label = " Preset " , type = " value " )
2023-02-21 22:13:30 +00:00
num_autoregressive_samples = gr . Slider ( value = 128 , minimum = 2 , maximum = 512 , step = 1 , label = " Samples " )
2023-02-17 00:08:27 +00:00
diffusion_iterations = gr . Slider ( value = 128 , minimum = 0 , maximum = 512 , step = 1 , label = " Iterations " )
temperature = gr . Slider ( value = 0.2 , minimum = 0 , maximum = 1 , step = 0.1 , label = " Temperature " )
show_experimental_settings = gr . Checkbox ( label = " Show Experimental Settings " )
reset_generation_settings_button = gr . Button ( value = " Reset to Default " )
with gr . Column ( visible = False ) as col :
experimental_column = col
experimental_checkboxes = gr . CheckboxGroup ( [ " Half Precision " , " Conditioning-Free " ] , value = [ " Conditioning-Free " ] , label = " Experimental Flags " )
2023-03-05 23:55:27 +00:00
breathing_room = gr . Slider ( value = 8 , minimum = 1 , maximum = 32 , step = 1 , label = " Pause Size " )
diffusion_sampler = gr . Radio (
[ " P " , " DDIM " ] , # + ["K_Euler_A", "DPM++2M"],
value = " DDIM " , label = " Diffusion Samplers " , type = " value "
)
2023-02-17 00:08:27 +00:00
cvvp_weight = gr . Slider ( value = 0 , minimum = 0 , maximum = 1 , label = " CVVP Weight " )
top_p = gr . Slider ( value = 0.8 , minimum = 0 , maximum = 1 , label = " Top P " )
diffusion_temperature = gr . Slider ( value = 1.0 , minimum = 0 , maximum = 1 , label = " Diffusion Temperature " )
length_penalty = gr . Slider ( value = 1.0 , minimum = 0 , maximum = 8 , label = " Length Penalty " )
repetition_penalty = gr . Slider ( value = 2.0 , minimum = 0 , maximum = 8 , label = " Repetition Penalty " )
cond_free_k = gr . Slider ( value = 2.0 , minimum = 0 , maximum = 4 , label = " Conditioning-Free K " )
with gr . Column ( ) :
2023-02-23 23:22:23 +00:00
with gr . Row ( ) :
submit = gr . Button ( value = " Generate " )
stop = gr . Button ( value = " Stop " )
2023-02-17 00:08:27 +00:00
generation_results = gr . Dataframe ( label = " Results " , headers = [ " Seed " , " Time " ] , visible = False )
source_sample = gr . Audio ( label = " Source Sample " , visible = False )
output_audio = gr . Audio ( label = " Output " )
2023-02-21 03:00:45 +00:00
candidates_list = gr . Dropdown ( label = " Candidates " , type = " value " , visible = False , choices = [ " " ] , value = " " )
def change_candidate ( val ) :
if not val :
return
return val
candidates_list . change (
fn = change_candidate ,
inputs = candidates_list ,
outputs = output_audio ,
)
2023-02-17 00:08:27 +00:00
with gr . Tab ( " History " ) :
with gr . Row ( ) :
with gr . Column ( ) :
2023-02-18 02:07:22 +00:00
history_info = gr . Dataframe ( label = " Results " , headers = list ( history_headers . keys ( ) ) )
2023-02-17 00:08:27 +00:00
with gr . Row ( ) :
with gr . Column ( ) :
2023-02-22 03:27:28 +00:00
history_voices = gr . Dropdown ( choices = result_voices , label = " Voice " , type = " value " , value = result_voices [ 0 ] if len ( result_voices ) > 0 else " " )
2023-02-17 00:08:27 +00:00
with gr . Column ( ) :
2023-02-21 03:00:45 +00:00
history_results_list = gr . Dropdown ( label = " Results " , type = " value " , interactive = True , value = " " )
2023-02-17 00:08:27 +00:00
with gr . Column ( ) :
history_audio = gr . Audio ( )
history_copy_settings_button = gr . Button ( value = " Copy Settings " )
with gr . Tab ( " Utilities " ) :
with gr . Row ( ) :
with gr . Column ( ) :
2023-02-18 02:07:22 +00:00
audio_in = gr . Files ( type = " file " , label = " Audio Input " , file_types = [ " audio " ] )
2023-02-17 00:08:27 +00:00
import_voice_name = gr . Textbox ( label = " Voice Name " )
import_voice_button = gr . Button ( value = " Import Voice " )
2023-03-05 23:55:27 +00:00
with gr . Column ( visible = False ) as col :
utilities_metadata_column = col
metadata_out = gr . JSON ( label = " Audio Metadata " )
copy_button = gr . Button ( value = " Copy Settings " )
latents_out = gr . File ( type = " binary " , label = " Voice Latents " )
2023-02-17 03:05:27 +00:00
with gr . Tab ( " Training " ) :
2023-02-17 06:01:14 +00:00
with gr . Tab ( " Prepare Dataset " ) :
2023-02-17 03:05:27 +00:00
with gr . Row ( ) :
2023-02-17 05:42:55 +00:00
with gr . Column ( ) :
dataset_settings = [
2023-03-01 01:17:38 +00:00
gr . Dropdown ( choices = voice_list , label = " Dataset Source " , type = " value " , value = voice_list [ 0 ] if len ( voice_list ) > 0 else " " ) ,
2023-03-06 10:47:06 +00:00
gr . Textbox ( label = " Language " , value = " en " ) ,
gr . Checkbox ( label = " Skip Already Transcribed " , value = False )
2023-02-17 05:42:55 +00:00
]
prepare_dataset_button = gr . Button ( value = " Prepare " )
2023-02-18 02:07:22 +00:00
with gr . Column ( ) :
prepare_dataset_output = gr . TextArea ( label = " Console Output " , interactive = False , max_lines = 8 )
2023-02-17 06:01:14 +00:00
with gr . Tab ( " Generate Configuration " ) :
with gr . Row ( ) :
2023-02-17 03:05:27 +00:00
with gr . Column ( ) :
training_settings = [
2023-02-19 21:06:14 +00:00
gr . Number ( label = " Epochs " , value = 500 , precision = 0 ) ,
2023-02-23 23:22:23 +00:00
]
with gr . Row ( ) :
2023-03-01 01:17:38 +00:00
with gr . Column ( ) :
training_settings = training_settings + [
gr . Slider ( label = " Learning Rate " , value = 1e-5 , minimum = 0 , maximum = 1e-4 , step = 1e-6 ) ,
gr . Slider ( label = " Text_CE LR Ratio " , value = 0.01 , minimum = 0 , maximum = 1 ) ,
]
2023-02-23 23:22:23 +00:00
training_settings = training_settings + [
gr . Textbox ( label = " Learning Rate Schedule " , placeholder = str ( EPOCH_SCHEDULE ) ) ,
]
with gr . Row ( ) :
training_settings = training_settings + [
gr . Number ( label = " Batch Size " , value = 128 , precision = 0 ) ,
2023-03-04 15:55:06 +00:00
gr . Number ( label = " Gradient Accumulation Size " , value = 4 , precision = 0 ) ,
2023-02-23 23:22:23 +00:00
]
with gr . Row ( ) :
training_settings = training_settings + [
gr . Number ( label = " Print Frequency (in epochs) " , value = 5 , precision = 0 ) ,
gr . Number ( label = " Save Frequency (in epochs) " , value = 5 , precision = 0 ) ,
]
training_settings = training_settings + [
2023-02-19 16:16:44 +00:00
gr . Textbox ( label = " Resume State Path " , placeholder = " ./training/$ {voice} -finetune/training_state/$ {last_state} .state " ) ,
2023-02-17 03:05:27 +00:00
]
2023-03-05 05:17:19 +00:00
with gr . Row ( ) :
training_halfp = gr . Checkbox ( label = " Half Precision " , value = args . training_default_halfp )
training_bnb = gr . Checkbox ( label = " BitsAndBytes " , value = args . training_default_bnb )
training_workers = gr . Number ( label = " Worker Processes " , value = 2 , precision = 0 )
2023-03-01 01:17:38 +00:00
source_model = gr . Dropdown ( choices = autoregressive_models , label = " Source Model " , type = " value " , value = autoregressive_models [ 0 ] )
dataset_list_dropdown = gr . Dropdown ( choices = dataset_list , label = " Dataset " , type = " value " , value = dataset_list [ 0 ] if len ( dataset_list ) else " " )
2023-03-05 05:17:19 +00:00
training_settings = training_settings + [ training_halfp , training_bnb , training_workers , source_model , dataset_list_dropdown ]
2023-02-26 01:57:56 +00:00
2023-02-23 23:22:23 +00:00
with gr . Row ( ) :
refresh_dataset_list = gr . Button ( value = " Refresh Dataset List " )
2023-02-28 22:13:21 +00:00
import_dataset_button = gr . Button ( value = " Reuse/Import Dataset " )
2023-02-18 02:07:22 +00:00
with gr . Column ( ) :
save_yaml_output = gr . TextArea ( label = " Console Output " , interactive = False , max_lines = 8 )
2023-02-23 23:22:23 +00:00
with gr . Row ( ) :
optimize_yaml_button = gr . Button ( value = " Validate Training Configuration " )
save_yaml_button = gr . Button ( value = " Save Training Configuration " )
2023-02-18 02:07:22 +00:00
with gr . Tab ( " Run Training " ) :
2023-02-17 16:29:27 +00:00
with gr . Row ( ) :
with gr . Column ( ) :
2023-02-18 14:51:00 +00:00
training_configs = gr . Dropdown ( label = " Training Configuration " , choices = get_training_list ( ) )
2023-02-23 23:22:23 +00:00
with gr . Row ( ) :
2023-03-02 01:35:12 +00:00
refresh_configs = gr . Button ( value = " Refresh Configurations " )
2023-03-01 19:32:11 +00:00
training_loss_graph = gr . LinePlot ( label = " Training Metrics " ,
x = " step " ,
y = " value " ,
title = " Training Metrics " ,
2023-02-28 06:18:18 +00:00
color = " type " ,
2023-03-01 19:32:11 +00:00
tooltip = [ ' step ' , ' value ' , ' type ' ] ,
2023-02-28 01:01:50 +00:00
width = 600 ,
2023-03-02 01:35:12 +00:00
height = 350 ,
2023-02-28 01:01:50 +00:00
)
2023-03-02 01:35:12 +00:00
view_losses = gr . Button ( value = " View Losses " )
with gr . Column ( ) :
training_output = gr . TextArea ( label = " Console Output " , interactive = False , max_lines = 8 )
verbose_training = gr . Checkbox ( label = " Verbose Console Output " , value = True )
2023-03-04 15:55:06 +00:00
with gr . Row ( ) :
2023-03-07 20:38:31 +00:00
training_keep_x_past_datasets = gr . Slider ( label = " Keep X Previous States " , minimum = 0 , maximum = 8 , value = 0 , step = 1 )
2023-03-04 20:42:54 +00:00
training_gpu_count = gr . Number ( label = " GPUs " , value = get_device_count ( ) )
2023-03-02 01:35:12 +00:00
with gr . Row ( ) :
start_training_button = gr . Button ( value = " Train " )
stop_training_button = gr . Button ( value = " Stop " )
reconnect_training_button = gr . Button ( value = " Reconnect " )
2023-02-17 00:08:27 +00:00
with gr . Tab ( " Settings " ) :
with gr . Row ( ) :
exec_inputs = [ ]
with gr . Column ( ) :
exec_inputs = exec_inputs + [
gr . Textbox ( label = " Listen " , value = args . listen , placeholder = " 127.0.0.1:7860/ " ) ,
gr . Checkbox ( label = " Public Share Gradio " , value = args . share ) ,
gr . Checkbox ( label = " Check For Updates " , value = args . check_for_updates ) ,
gr . Checkbox ( label = " Only Load Models Locally " , value = args . models_from_local_only ) ,
gr . Checkbox ( label = " Low VRAM " , value = args . low_vram ) ,
gr . Checkbox ( label = " Embed Output Metadata " , value = args . embed_output_metadata ) ,
gr . Checkbox ( label = " Slimmer Computed Latents " , value = args . latents_lean_and_mean ) ,
2023-02-21 03:00:45 +00:00
gr . Checkbox ( label = " Use Voice Fixer on Generated Output " , value = args . voice_fixer ) ,
2023-02-17 00:08:27 +00:00
gr . Checkbox ( label = " Use CUDA for Voice Fixer " , value = args . voice_fixer_use_cuda ) ,
gr . Checkbox ( label = " Force CPU for Conditioning Latents " , value = args . force_cpu_for_conditioning_latents ) ,
2023-02-21 03:00:45 +00:00
gr . Checkbox ( label = " Do Not Load TTS On Startup " , value = args . defer_tts_load ) ,
2023-02-21 21:50:05 +00:00
gr . Checkbox ( label = " Delete Non-Final Output " , value = args . prune_nonfinal_outputs ) ,
2023-02-17 00:08:27 +00:00
gr . Textbox ( label = " Device Override " , value = args . device_override ) ,
]
with gr . Column ( ) :
exec_inputs = exec_inputs + [
gr . Number ( label = " Sample Batch Size " , precision = 0 , value = args . sample_batch_size ) ,
2023-02-28 06:18:18 +00:00
gr . Number ( label = " Gradio Concurrency Count " , precision = 0 , value = args . concurrency_count ) ,
2023-03-03 21:13:48 +00:00
gr . Number ( label = " Auto-Calculate Voice Chunk Duration (in seconds) " , precision = 0 , value = args . autocalculate_voice_chunk_duration_size ) ,
2023-02-21 21:50:05 +00:00
gr . Slider ( label = " Output Volume " , minimum = 0 , maximum = 2 , value = args . output_volume ) ,
2023-02-17 00:08:27 +00:00
]
2023-02-18 14:10:26 +00:00
2023-02-21 03:00:45 +00:00
autoregressive_model_dropdown = gr . Dropdown ( choices = autoregressive_models , label = " Autoregressive Model " , value = args . autoregressive_model if args . autoregressive_model else autoregressive_models [ 0 ] )
2023-02-27 19:20:06 +00:00
2023-03-07 02:45:22 +00:00
vocoder_models = gr . Dropdown ( VOCODERS , label = " Vocoder " , value = args . vocoder_model if args . vocoder_model else VOCODERS [ - 1 ] )
2023-03-06 05:21:33 +00:00
whisper_backend = gr . Dropdown ( WHISPER_BACKENDS , label = " Whisper Backends " , value = args . whisper_backend )
2023-03-05 05:17:19 +00:00
whisper_model_dropdown = gr . Dropdown ( WHISPER_MODELS , label = " Whisper Model " , value = args . whisper_model )
2023-02-27 19:20:06 +00:00
2023-03-07 02:45:22 +00:00
exec_inputs = exec_inputs + [ autoregressive_model_dropdown , vocoder_models , whisper_backend , whisper_model_dropdown , training_halfp , training_bnb ]
2023-02-27 19:20:06 +00:00
2023-02-24 12:58:41 +00:00
with gr . Row ( ) :
autoregressive_models_update_button = gr . Button ( value = " Refresh Model List " )
gr . Button ( value = " Check for Updates " ) . click ( check_for_updates )
gr . Button ( value = " (Re)Load TTS " ) . click (
reload_tts ,
inputs = autoregressive_model_dropdown ,
outputs = None
)
2023-03-03 21:13:48 +00:00
# kill_button = gr.Button(value="Close UI")
2023-02-24 12:58:41 +00:00
def update_model_list_proxy ( val ) :
autoregressive_models = get_autoregressive_models ( )
if val not in autoregressive_models :
val = autoregressive_models [ 0 ]
return gr . update ( choices = autoregressive_models , value = val )
autoregressive_models_update_button . click (
update_model_list_proxy ,
inputs = autoregressive_model_dropdown ,
outputs = autoregressive_model_dropdown ,
)
2023-02-21 03:00:45 +00:00
2023-02-17 00:08:27 +00:00
for i in exec_inputs :
2023-02-18 14:10:26 +00:00
i . change ( fn = update_args , inputs = exec_inputs )
2023-02-27 19:20:06 +00:00
autoregressive_model_dropdown . change (
fn = update_autoregressive_model ,
inputs = autoregressive_model_dropdown ,
outputs = None
)
2023-02-17 00:08:27 +00:00
2023-03-07 02:45:22 +00:00
vocoder_models . change (
fn = update_vocoder_model ,
inputs = vocoder_models ,
outputs = None
)
2023-02-17 00:08:27 +00:00
input_settings = [
text ,
delimiter ,
emotion ,
prompt ,
voice ,
mic_audio ,
voice_latents_chunks ,
seed ,
candidates ,
num_autoregressive_samples ,
diffusion_iterations ,
temperature ,
diffusion_sampler ,
breathing_room ,
cvvp_weight ,
top_p ,
diffusion_temperature ,
length_penalty ,
repetition_penalty ,
cond_free_k ,
experimental_checkboxes ,
]
2023-02-21 03:00:45 +00:00
history_voices . change (
2023-02-18 02:07:22 +00:00
fn = history_view_results ,
inputs = history_voices ,
outputs = [
history_info ,
history_results_list ,
]
)
2023-02-21 03:00:45 +00:00
history_results_list . change (
2023-02-18 02:07:22 +00:00
fn = lambda voice , file : f " ./results/ { voice } / { file } " ,
inputs = [
history_voices ,
history_results_list ,
] ,
outputs = history_audio
)
audio_in . upload (
fn = read_generate_settings_proxy ,
inputs = audio_in ,
outputs = [
metadata_out ,
latents_out ,
2023-03-05 23:55:27 +00:00
import_voice_name ,
utilities_metadata_column ,
2023-02-18 02:07:22 +00:00
]
)
import_voice_button . click (
fn = import_voices_proxy ,
inputs = [
audio_in ,
import_voice_name ,
] ,
outputs = import_voice_name #console_output
)
# --- Generate tab: reactive visibility / value tweaks ---
# The experimental settings column follows the checkbox's visibility state.
show_experimental_settings . change (
fn = lambda x : gr . update ( visible = x ) ,
inputs = show_experimental_settings ,
outputs = experimental_column
)
# Choosing a preset overwrites num_autoregressive_samples and
# diffusion_iterations with the values returned by update_presets.
preset . change ( fn = update_presets ,
inputs = preset ,
outputs = [
num_autoregressive_samples ,
diffusion_iterations ,
] ,
)
2023-03-07 03:55:35 +00:00
# Recompute latents for the selected voice using voice_latents_chunks; the
# handler's return value is written back to the voice component.
recompute_voice_latents . click ( compute_latents_proxy ,
2023-02-18 02:07:22 +00:00
inputs = [
voice ,
voice_latents_chunks ,
] ,
outputs = voice ,
)
2023-03-05 23:55:27 +00:00
# The free-text prompt box is only visible when emotion is "Custom".
emotion . change (
fn = lambda value : gr . update ( visible = value == " Custom " ) ,
inputs = emotion ,
outputs = prompt
2023-02-18 02:07:22 +00:00
)
# A change to the microphone recording switches the voice selector to
# "microphone".
# NOTE(review): the lambda ignores `value`. If gradio also fires `change` when
# the recording is cleared, this still selects "microphone". Confirm that is intended.
mic_audio . change ( fn = lambda value : gr . update ( value = " microphone " ) ,
inputs = mic_audio ,
outputs = voice
)
2023-02-17 00:08:27 +00:00
# Refresh the voice lists in the generate tab (voice), the dataset tab
# (dataset_settings[0]) and the history tab from update_voices.
refresh_voices . click ( update_voices ,
inputs = None ,
outputs = [
voice ,
2023-02-18 02:07:22 +00:00
dataset_settings [ 0 ] ,
2023-02-17 00:08:27 +00:00
history_voices
]
)
# Two listeners are registered on the same Generate button, in this order.
# The first hides the previous run's source sample, candidates list and results.
submit . click (
2023-02-21 03:00:45 +00:00
lambda : ( gr . update ( visible = False ) , gr . update ( visible = False ) , gr . update ( visible = False ) ) ,
outputs = [ source_sample , candidates_list , generation_results ] ,
2023-02-17 00:08:27 +00:00
)
# The second runs run_generation with input_settings (positional). api_name
# also names this handler's endpoint on the gradio API ("generate").
submit_event = submit . click ( run_generation ,
inputs = input_settings ,
2023-02-21 03:00:45 +00:00
outputs = [ output_audio , source_sample , candidates_list , generation_results ] ,
2023-03-06 05:21:33 +00:00
api_name = " generate " ,
2023-02-17 00:08:27 +00:00
)
# Copy settings from the uploaded file (audio_in) into every generation input.
copy_button . click ( import_generate_settings ,
inputs = audio_in , # JSON elements cannot be used as inputs
outputs = input_settings
)
# Reset every generation input to the values returned by reset_generation_settings.
reset_generation_settings_button . click (
fn = reset_generation_settings ,
inputs = None ,
outputs = input_settings
)
# Copy the settings for the selected history voice/result pair into the
# generation inputs via history_copy_settings.
history_copy_settings_button . click ( history_copy_settings ,
inputs = [
history_voices ,
history_results_list ,
] ,
outputs = input_settings
)
2023-02-18 14:51:00 +00:00
# --- Training tab ---
# Refresh the training config dropdown from get_training_list().
refresh_configs . click (
lambda : gr . update ( choices = get_training_list ( ) ) ,
inputs = None ,
outputs = training_configs
)
2023-02-18 02:07:22 +00:00
# Start training with the selected config, verbosity, GPU count and
# keep-x-past-datasets setting; run_training's output goes to training_output.
start_training_button . click ( run_training ,
2023-02-19 05:05:30 +00:00
inputs = [
training_configs ,
verbose_training ,
2023-03-03 04:37:18 +00:00
training_gpu_count ,
2023-03-07 20:38:31 +00:00
training_keep_x_past_datasets ,
2023-02-19 05:05:30 +00:00
] ,
2023-02-28 01:01:50 +00:00
outputs = [
training_output ,
] ,
)
# Redraw the loss graph whenever training_output changes.
# show_progress=False suppresses gradio's progress overlay on each refresh.
# update_training_dataplot is called with no inputs here but with
# training_configs by view_losses, so it presumably takes an optional
# config argument. Verify before changing its signature.
training_output . change (
fn = update_training_dataplot ,
inputs = None ,
outputs = [
training_loss_graph ,
] ,
show_progress = False ,
2023-02-18 02:07:22 +00:00
)
2023-03-02 01:35:12 +00:00
# Explicitly plot losses for the selected training config.
view_losses . click (
fn = update_training_dataplot ,
inputs = [
training_configs
] ,
outputs = [
training_loss_graph ,
] ,
)
2023-02-18 02:07:22 +00:00
# Stop the training run; stop_training's return value goes to training_output.
stop_training_button . click ( stop_training ,
inputs = None ,
outputs = training_output #console_output
)
2023-02-23 06:24:54 +00:00
# Re-attach to the training run (with the chosen verbosity); output goes to
# training_output.
reconnect_training_button . click ( reconnect_training ,
inputs = [
verbose_training ,
] ,
outputs = training_output #console_output
)
2023-02-18 02:07:22 +00:00
# --- Dataset preparation / training settings ---
# Prepare a dataset from dataset_settings; the result is shown in prepare_dataset_output.
prepare_dataset_button . click (
prepare_dataset_proxy ,
inputs = dataset_settings ,
outputs = prepare_dataset_output #console_output
)
2023-02-18 14:51:00 +00:00
# Refresh the dataset dropdown from get_dataset_list().
refresh_dataset_list . click (
lambda : gr . update ( choices = get_dataset_list ( ) ) ,
inputs = None ,
2023-03-01 01:17:38 +00:00
outputs = dataset_list_dropdown ,
2023-02-18 14:51:00 +00:00
)
2023-02-19 20:22:03 +00:00
# Optimize the current training settings. The proxy's return values fill
# training_settings[1:9] and then save_yaml_output.
# NOTE(review): the slice bounds are hard-coded. They must match the number and
# order of values optimize_training_settings_proxy returns. Verify both
# sides together when either changes.
optimize_yaml_button . click ( optimize_training_settings_proxy ,
inputs = training_settings ,
2023-03-01 01:17:38 +00:00
outputs = training_settings [ 1 : 9 ] + [ save_yaml_output ] #console_output
2023-02-19 20:22:03 +00:00
)
2023-02-23 23:22:23 +00:00
# Import settings for the chosen dataset into training_settings[:13] plus
# save_yaml_output. The same caveat applies: this slice must match what
# import_training_settings_proxy returns.
import_dataset_button . click ( import_training_settings_proxy ,
2023-03-01 01:17:38 +00:00
inputs = dataset_list_dropdown ,
2023-03-05 05:17:19 +00:00
outputs = training_settings [ : 13 ] + [ save_yaml_output ] #console_output
2023-02-23 23:22:23 +00:00
)
2023-02-18 14:51:00 +00:00
# Save the current training settings; the proxy's return value is shown in
# save_yaml_output.
save_yaml_button . click ( save_training_settings_proxy ,
2023-02-18 02:07:22 +00:00
inputs = training_settings ,
outputs = save_yaml_output #console_output
)
2023-03-03 21:13:48 +00:00
"""
def kill_process ( ) :
ui . close ( )
exit ( )
kill_button . click (
kill_process ,
inputs = None ,
outputs = None
)
"""
2023-02-17 00:08:27 +00:00
# On page load, populate the generation inputs via import_generate_settings,
# but only when ./config/generate.json exists (presumably the saved
# last-used settings; verify against import_generate_settings).
if os . path . isfile ( ' ./config/generate.json ' ) :
ui . load ( import_generate_settings , inputs = None , outputs = input_settings )
# When args.check_for_updates is set, also run check_for_updates on page load.
if args . check_for_updates :
ui . load ( check_for_updates )
2023-02-24 23:13:13 +00:00
# Stop button: calls cancel_generate. No gradio `cancels=` is passed for
# submit_event. run_generation checks for a "Kill signal detected" error,
# which suggests cancellation is signalled into the running generation
# instead. TODO confirm.
stop . click ( fn = cancel_generate , inputs = None , outputs = None )
2023-02-17 00:08:27 +00:00
# Enable the request queue with the configured concurrency, then return the
# built Blocks app.
ui . queue ( concurrency_count = args . concurrency_count )
webui = ui
return webui