@ -37,6 +37,8 @@ from tortoise.utils.text import split_and_recombine_text
from tortoise . utils . device import get_device_name , set_device_name
# Register the download URL of the DVAE checkpoint under the 'dvae.pth' key.
# MODELS itself is defined outside this excerpt.
MODELS [ ' dvae.pth ' ] = " https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth "
# Base Whisper model names.
# NOTE(review): presumably the selectable values for args.whisper_model - confirm against the UI code.
WHISPER_MODELS = [ " tiny " , " base " , " small " , " medium " , " large " ]
# Language-specific Whisper variants. load_whisper_model() switches to
# "<model>.<language>" when that combined name is listed here.
WHISPER_SPECIALIZED_MODELS = [ " tiny.en " , " base.en " , " small.en " , " medium.en " ]
# Module-level state, unset at import time. `args` is later read for settings
# such as whisper_model / whisper_cpp; `tts` is reset to None by unload_tts().
# Where either one gets populated is not visible in this excerpt.
args = None
tts = None
@ -663,6 +665,7 @@ class TrainingState():
# rip out iteration info
if not self . training_started :
if line . find ( ' Start training from epoch ' ) > = 0 :
self . it_time_start = time . time ( )
self . epoch_time_start = time . time ( )
self . training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations
should_return = True
@ -703,11 +706,12 @@ class TrainingState():
self . it_time_delta = self . it_time_end - self . it_time_start
self . it_time_start = time . time ( )
self . it_taken = self . it_taken + 1
try :
rate = f ' { " {:.3f} " . format ( self . it_time_delta ) } s/it ' if self . it_time_delta > = 1 else f ' { " {:.3f} " . format ( 1 / self . it_time_delta ) } it/s '
self . it_rate = rate
except Exception as e :
pass
if self . it_time_delta :
try :
rate = f ' { " {:.3f} " . format ( self . it_time_delta ) } s/it ' if self . it_time_delta > = 1 else f ' { " {:.3f} " . format ( 1 / self . it_time_delta ) } it/s '
self . it_rate = rate
except Exception as e :
pass
metric_step = [ f " { self . epoch } / { self . epochs } " , f " { self . it } / { self . its } " , f " { step } / { steps } " ]
metric_step = " , " . join ( metric_step )
@ -733,9 +737,23 @@ class TrainingState():
metric_loss = [ ]
if len ( self . losses ) > 0 :
metric_loss . append ( f ' Loss: { " {:3f} " . format ( self . losses [ - 1 ] [ " value " ] ) } ' )
if len ( self . losses ) > = 2 :
delta_loss = self . losses [ - 2 ] [ " value " ] - self . losses [ - 1 ] [ " value " ]
delta_step = self . losses [ - 2 ] [ " step " ] - self . losses [ - 1 ] [ " step " ]
inst_deriv = delta_loss / delta_step
est_loss = delta_loss + ( self . its - self . it ) * inst_deriv
metric_loss . append ( f ' Est. Final Loss: { " {:3f} " . format ( est_loss ) } ' )
print ( delta_loss , delta_step , inst_deriv , est_loss )
metric_loss = " , " . join ( metric_loss )
message = f ' [ { metric_step } ] [ { metric_rate } ] [ { metric_loss } ] [ETA: { eta_hhmmss } ] '
message = f ' [ { metric_step } ] [ { metric_rate } ] [ETA: { eta_hhmmss } ] [ { metric_loss } ] '
if lapsed :
self . epoch = self . epoch + 1
@ -764,6 +782,13 @@ class TrainingState():
self . buffer . append ( f ' [ { " {:.3f} " . format ( percent * 100 ) } %] { message } ' )
if line . find ( ' INFO: [epoch: ' ) > = 0 :
# to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point
if ' : nan ' in line :
should_return = True
print ( " ! NAN DETECTED ! " )
self . buffer . append ( " ! NAN DETECTED ! " )
# easily rip out our stats...
match = re . findall ( r ' \ b([a-z_0-9]+?) \ b: +?([0-9] \ .[0-9]+?e[+-] \ d+|[ \ d,]+) \ b ' , line )
if match and len ( match ) > 0 :
@ -824,13 +849,13 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, pro
if result :
yield result
if progress is not None and message :
progress ( percent , message )
if training_state :
training_state . process . stdout . close ( )
return_code = training_state . process . wait ( )
training_state = None
#if return_code:
# raise subprocess.CalledProcessError(return_code, cmd)
def get_training_losses ( ) :
global training_state
@ -866,6 +891,9 @@ def reconnect_training(verbose=False, progress=gr.Progress(track_tqdm=True)):
if result :
yield result
if progress is not None and message :
progress ( percent , message )
def stop_training ( ) :
global training_state
if training_state is None :
@ -910,10 +938,10 @@ def convert_to_halfp():
def whisper_transcribe ( file , language = None ) :
# shouldn't happen, but it's for safety
if not whisper_model :
load_whisper_model ( language = language if language else b ' en ' )
load_whisper_model ( language = language )
if not args . whisper_cpp :
return whisper_model . transcribe ( file , language = language if language else " English " )
return whisper_model . transcribe ( file , language = language )
res = whisper_model . transcribe ( file )
segments = whisper_model . extract_text_and_timestamps ( res )
@ -945,11 +973,8 @@ def prepare_dataset( files, outdir, language=None, progress=None ):
transcription = [ ]
for file in enumerate_progress ( files , desc = " Iterating through voice files " , progress = progress ) :
print ( f " Transcribing file: { file } " )
result = whisper_transcribe ( file , language = language ) # whisper_model.transcribe(file, language=language if language else "English")
result = whisper_transcribe ( file , language = language )
results [ os . path . basename ( file ) ] = result
print ( f " Transcribed file: { file } , { len ( result [ ' segments ' ] ) } found. " )
waveform , sampling_rate = torchaudio . load ( file )
@ -988,7 +1013,7 @@ EPOCH_SCHEDULE = [ 9, 18, 25, 33 ]
def schedule_learning_rate(iterations, schedule=EPOCH_SCHEDULE):
    """Scale every schedule entry by ``iterations`` and truncate to an int.

    Args:
        iterations: Multiplier applied to each entry of ``schedule``.
        schedule: Iterable of numeric entries; defaults to EPOCH_SCHEDULE.

    Returns:
        A list holding ``int(iterations * entry)`` for each entry, in order.
    """
    milestones = []
    for entry in schedule:
        milestones.append(int(iterations * entry))
    return milestones
def optimize_training_settings ( epochs , learning_rate , text_ce_lr_weight , learning_rate_schedule , batch_size , gradient_accumulation_size , print_rate , save_rate , resume_path , half_p , bnb , source_model, voice ) :
def optimize_training_settings ( epochs , learning_rate , text_ce_lr_weight , learning_rate_schedule , batch_size , gradient_accumulation_size , print_rate , save_rate , resume_path , half_p , bnb , workers, source_model, voice ) :
name = f " { voice } -finetune "
dataset_name = f " { voice } -train "
dataset_path = f " ./training/ { voice } /train.txt "
@ -1065,7 +1090,7 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni
messages
)
def save_training_settings ( iterations = None , learning_rate = None , text_ce_lr_weight = None , learning_rate_schedule = None , batch_size = None , gradient_accumulation_size = None , print_rate = None , save_rate = None , name = None , dataset_name = None , dataset_path = None , validation_name = None , validation_path = None , output_name = None , resume_path = None , half_p = None , bnb = None , source_model= None ) :
def save_training_settings ( iterations = None , learning_rate = None , text_ce_lr_weight = None , learning_rate_schedule = None , batch_size = None , gradient_accumulation_size = None , print_rate = None , save_rate = None , name = None , dataset_name = None , dataset_path = None , validation_name = None , validation_path = None , output_name = None , resume_path = None , half_p = None , bnb = None , workers= None , source_model= None ) :
if not source_model :
source_model = f " ./models/tortoise/autoregressive { ' _half ' if half_p else ' ' } .pth "
@ -1090,6 +1115,8 @@ def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weig
' float16 ' : ' true ' if half_p else ' false ' ,
' bitsandbytes ' : ' true ' if bnb else ' false ' ,
' workers ' : workers if workers else 2 ,
}
if resume_path :
@ -1581,9 +1608,9 @@ def unload_tts():
global tts
if tts :
print ( " Unloading TTS " )
del tts
tts = None
print ( " Unloaded TTS " )
do_gc ( )
def reload_tts ( model = None ) :
@ -1656,55 +1683,44 @@ def unload_voicefixer():
global voicefixer
if voicefixer :
print ( " Unloading Voicefixer " )
del voicefixer
voicefixer = None
print ( " Unloaded Voicefixer " )
do_gc ( )
def load_whisper_model(language=None, model_name=None, progress=None):
    """Load a Whisper model into the module-level ``whisper_model``.

    Args:
        language: Optional language code (e.g. ``'en'``). When
            ``"<model_name>.<language>"`` is listed in
            WHISPER_SPECIALIZED_MODELS, that specialized model is loaded
            instead of the base one. ``bytes`` values are decoded as ASCII.
        model_name: Whisper model to load. When given, it is also stored in
            ``args.whisper_model`` and persisted via save_args_settings();
            when omitted, ``args.whisper_model`` is used.
        progress: Forwarded unchanged to notify_progress().
    """
    global whisper_model

    # Older call sites passed the language as bytes (b'en'). Normalize it so
    # the specialized-model lookup and the .encode() below both see a str.
    if isinstance(language, bytes):
        language = language.decode('ascii')

    if model_name:
        # An explicit choice becomes the persisted default.
        args.whisper_model = model_name
        save_args_settings()
    else:
        model_name = args.whisper_model

    # Only the base name is persisted above; the language-specific variant is
    # resolved on each load.
    specialized_name = f'{model_name}.{language}'
    if language and specialized_name in WHISPER_SPECIALIZED_MODELS:
        model_name = specialized_name
        print(f"Loading specialized model for language: {language}")

    notify_progress(f"Loading Whisper model: {model_name}", progress)

    if args.whisper_cpp:
        # Imported lazily so whispercpp is only needed when this backend is on.
        from whispercpp import Whisper

        # whispercpp is handed 'auto' when no language was given.
        # NOTE(review): presumably this requests auto-detection - confirm
        # against the whispercpp bindings.
        whisper_model = Whisper(
            model_name,
            models_dir='./models/',
            language=(language or 'auto').encode('ascii'),
        )
    else:
        import whisper

        whisper_model = whisper.load_model(model_name)

    print("Loaded Whisper model")
# Drop the module-level Whisper model when one is loaded, and call do_gc().
# NOTE(review): indentation was lost in this excerpt, so it cannot be confirmed
# from here whether do_gc() runs unconditionally or only after an unload.
def unload_whisper ( ) :
global whisper_model
if whisper_model :
print ( " Unloading Whisper " )
# `del` removes the global binding; the assignment right after re-creates it
# as None so later `if whisper_model` checks keep working.
del whisper_model
whisper_model = None
print ( " Unloaded Whisper " )
do_gc ( )
# NOTE(review): update_whisper_model below is disabled by wrapping it in a bare
# string literal, so none of it executes. If it is ever re-enabled, note that
# `load_whisper_model(name)` passes the model name positionally; with the
# load_whisper_model(language=None, model_name=None, progress=None) signature
# that argument would bind to `language`. Pass model_name=name instead.
"""
def update_whisper_model ( name , progress = None ) :
if not name :
return
args . whisper_model = name
save_args_settings ( )
global whisper_model
if whisper_model :
unload_whisper ( )
load_whisper_model ( name )
else :
args . whisper_model = name
save_args_settings ( )
"""
# NOTE(review): this trailing call has lost its indentation; it may be the tail
# of a function body rather than a module-level statement - confirm against
# the original file.
do_gc ( )