@@ -34,8 +34,6 @@ from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_
 from tortoise.utils.text import split_and_recombine_text
 from tortoise.utils.device import get_device_name, set_device_name
-import whisper

 MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"

 args = None
@@ -46,7 +44,6 @@ voicefixer = None
 whisper_model = None
 training_state = None

 def generate(
     text,
     delimiter,
@@ -501,9 +498,12 @@ class TrainingState():
             match = re.findall(r'iter: ([\d,]+)', line)
             if match and len(match) > 0:
                 self.it = int(match[0].replace(",", ""))
+                self.checkpoints = int((self.its - self.it) / self.config['logger']['save_checkpoint_freq'])
         else:
             lapsed = False
+            message = None
             if line.find('%|') > 0:
                 match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line)
                 if match and len(match) > 0:
@@ -516,8 +516,6 @@ class TrainingState():
                     until = match[5]
                     rate = match[6]
-                    epoch_percent = self.it / float(self.its) # self.epoch / float(self.epochs)

                     last_step = self.last_step
                     self.last_step = step
                     if last_step < step:
@@ -530,11 +528,13 @@ class TrainingState():
                         self.it_time_delta = self.it_time_end - self.it_time_start
                         self.it_time_start = time.time()
                         try:
-                            rate = f'[{"{:.3f}".format(self.it_time_delta)}s/it]' if self.it_time_delta >= 1 else f'[{"{:.3f}".format(1/self.it_time_delta)}it/s]'
+                            rate = f'{"{:.3f}".format(self.it_time_delta)}s/it' if self.it_time_delta >= 1 else f'{"{:.3f}".format(1/self.it_time_delta)}it/s'
                             self.it_rate = rate
                         except Exception as e:
                             pass

+                    message = f'[{self.epoch}/{self.epochs}, {self.it}/{self.its}, {step}/{steps}] [ETA: {self.eta_hhmmss}] [{self.epoch_rate}, {self.it_rate}] {self.status}'

                     """
                     # I wanted frequently updated ETA, but I can't wrap my noggin around getting it to work on an empty belly
                     # will fix later
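
The bracket-stripping pair above leaves the rate logic itself untouched: an iteration slower than one second is reported as seconds per iteration, anything faster as iterations per second. A minimal sketch of that formatting, with format_rate as a hypothetical helper rather than anything in the patch:

def format_rate(delta_seconds, unit="it"):
    # Slow loops read better as seconds-per-iteration,
    # fast loops as iterations-per-second.
    if delta_seconds >= 1:
        return f'{delta_seconds:.3f}s/{unit}'
    return f'{1 / delta_seconds:.3f}{unit}/s'

assert format_rate(2.5) == '2.500s/it'
assert format_rate(0.25) == '4.000it/s'
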
@@ -550,13 +550,6 @@ class TrainingState():
                         pass
                     """

-                    message = f'[{self.epoch}/{self.epochs}] [{self.it}/{self.its}] [ETA: {self.eta_hhmmss}] {self.epoch_rate} / {self.it_rate} {self.status}'
-                    if progress is not None:
-                        progress(epoch_percent, message)
-                    # print(f'{"{:.3f}".format(percent*100)}% {message}')
-                    self.buffer.append(f'[{"{:.3f}".format(epoch_percent*100)}% / {"{:.3f}".format(percent*100)}%] {message}')

         if lapsed:
             self.epoch = self.epoch + 1
             self.it = int(self.epoch * (self.dataset_size / self.batch_size))
@@ -564,7 +557,7 @@ class TrainingState():
             self.epoch_time_end = time.time()
             self.epoch_time_delta = self.epoch_time_end - self.epoch_time_start
             self.epoch_time_start = time.time()
-            self.epoch_rate = f'[{"{:.3f}".format(self.epoch_time_delta)}s/epoch]' if self.epoch_time_delta >= 1 else f'[{"{:.3f}".format(1/self.epoch_time_delta)}epoch/s]' # I doubt anyone will have it/s rates, but its here
+            self.epoch_rate = f'{"{:.3f}".format(self.epoch_time_delta)}s/epoch' if self.epoch_time_delta >= 1 else f'{"{:.3f}".format(1/self.epoch_time_delta)}epoch/s' # I doubt anyone will have it/s rates, but its here
             #self.eta = (self.epochs - self.epoch) * self.epoch_time_delta
             self.epoch_time_deltas = self.epoch_time_deltas + self.epoch_time_delta
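
The bookkeeping above accumulates epoch_time_deltas so an ETA can be derived from the average epoch duration. A hedged sketch of that arithmetic, assuming this is how eta_hhmmss gets produced (the helper name is illustrative, and time.gmtime wraps at 24 hours):

import time

def estimate_eta_hhmmss(epoch_time_deltas, epochs_done, epochs_total):
    # Mean seconds per epoch so far, projected over the remaining epochs.
    average = epoch_time_deltas / epochs_done
    eta_seconds = (epochs_total - epochs_done) * average
    return time.strftime('%H:%M:%S', time.gmtime(eta_seconds))

print(estimate_eta_hhmmss(300.0, 2, 10))  # 150s/epoch, 8 epochs left -> 00:20:00
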
@@ -576,14 +569,12 @@ class TrainingState():
             except Exception as e:
                 pass

-        message = f'[{self.epoch}/{self.epochs}] [{self.it}/{self.its}] [ETA: {self.eta_hhmmss}] {self.epoch_rate} / {self.it_rate} {self.status}'
-        percent = self.epoch / float(self.epochs)
-        if progress is not None:
-            progress(percent, message)
-        print(f'{"{:.3f}".format(percent*100)}% {message}')
-        self.buffer.append(f'[{"{:.3f}".format(percent*100)}%] {message}')
+        if message:
+            percent = self.it / float(self.its) # self.epoch / float(self.epochs)
+            if progress is not None:
+                progress(percent, message)
+            self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')

         if line.find('INFO: [epoch:') >= 0:
             # easily rip out our stats...
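
Taken together with the earlier message = None addition, this hunk centralizes reporting: each parser branch only builds a message, and one tail block computes the percentage and emits it. A toy version of that flow, with all names illustrative:

def parse_line(line, state, progress=None):
    message = None
    if line.find('%|') > 0:
        message = 'step update'    # tqdm-style progress line
    elif line.find('INFO: [epoch:') >= 0:
        message = 'epoch update'   # trainer log line

    # One reporting path instead of one per branch.
    if message:
        percent = state['it'] / float(state['its'])
        if progress is not None:
            progress(percent, message)
        state['buffer'].append(f'{percent * 100:.3f}% {message}')

state = {'it': 5, 'its': 100, 'buffer': []}
parse_line(' 5%|#         | 5/100 [00:10<03:10, 2.00s/it]', state)
print(state['buffer'])  # ['5.000% step update']
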
@@ -677,12 +668,36 @@ def convert_to_halfp():
     torch.save(model, outfile)
     print(f'Converted model to half precision: {outfile}')

+def whisper_transcribe( file, language=None ):
+    # shouldn't happen, but it's for safety
+    if not whisper_model:
+        load_whisper_model(language=language if language else b'en')
+
+    if not args.whisper_cpp:
+        return whisper_model.transcribe(file, language=language if language else "English")
+
+    res = whisper_model.transcribe(file)
+    segments = whisper_model.extract_text_and_timestamps(res)
+
+    result = {
+        'segments': []
+    }
+    for segment in segments:
+        reparsed = {
+            'start': segment[0],
+            'end': segment[1],
+            'text': segment[2],
+        }
+        result['segments'].append(reparsed)
+
+    return result
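
Whichever backend is active, callers now receive the same shape: a dict with a segments list whose entries carry start, end, and text. A usage sketch, assuming the module's args are initialized and a transcribable file exists at this hypothetical path:

result = whisper_transcribe('./voices/sample/1.wav', language='en')
for segment in result['segments']:
    # Same keys whether openai-whisper or whispercpp produced them.
    print(f"{segment['start']} --> {segment['end']}: {segment['text']}")
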

 def prepare_dataset( files, outdir, language=None, progress=None ):
     unload_tts()

     global whisper_model
     if whisper_model is None:
-        load_whisper_model()
+        load_whisper_model(language=language)

     os.makedirs(outdir, exist_ok=True)
@@ -693,7 +708,7 @@ def prepare_dataset( files, outdir, language=None, progress=None ):
     for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
         print(f"Transcribing file: {file}")

-        result = whisper_model.transcribe(file, language=language if language else "English")
+        result = whisper_transcribe(file, language=language) # whisper_model.transcribe(file, language=language if language else "English")
         results[os.path.basename(file)] = result

         print(f"Transcribed file: {file}, {len(result['segments'])} found.")
@@ -1037,12 +1052,14 @@ def setup_args():
         'defer-tts-load': False,
         'device-override': None,
         'prune-nonfinal-outputs': True,
-        'whisper-model': "base",
-        'autoregressive-model': None,
         'concurrency-count': 2,
         'output-sample-rate': 44100,
         'output-volume': 1,
+        'autoregressive-model': None,
+        'whisper-model': "base",
+        'whisper-cpp': False,
         'training-default-halfp': False,
         'training-default-bnb': True,
     }
@@ -1067,13 +1084,15 @@ def setup_args():
     parser.add_argument("--defer-tts-load", default=default_arguments['defer-tts-load'], action='store_true', help="Defers loading TTS model")
     parser.add_argument("--prune-nonfinal-outputs", default=default_arguments['prune-nonfinal-outputs'], action='store_true', help="Deletes non-final output files on completing a generation")
     parser.add_argument("--device-override", default=default_arguments['device-override'], help="A device string to override pass through Torch")
-    parser.add_argument("--whisper-model", default=default_arguments['whisper-model'], help="Specifies which whisper model to use for transcription.")
-    parser.add_argument("--autoregressive-model", default=default_arguments['autoregressive-model'], help="Specifies which autoregressive model to use for sampling.")
     parser.add_argument("--sample-batch-size", default=default_arguments['sample-batch-size'], type=int, help="Sets how many batches to use during the autoregressive samples pass")
     parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once")
     parser.add_argument("--output-sample-rate", type=int, default=default_arguments['output-sample-rate'], help="Sample rate to resample the output to (from 24KHz)")
     parser.add_argument("--output-volume", type=float, default=default_arguments['output-volume'], help="Adjusts volume of output")
+    parser.add_argument("--autoregressive-model", default=default_arguments['autoregressive-model'], help="Specifies which autoregressive model to use for sampling.")
+    parser.add_argument("--whisper-model", default=default_arguments['whisper-model'], help="Specifies which whisper model to use for transcription.")
+    parser.add_argument("--whisper-cpp", default=default_arguments['whisper-cpp'], action='store_true', help="Leverages lightmare/whispercpp for transcription")
    parser.add_argument("--training-default-halfp", action='store_true', default=default_arguments['training-default-halfp'], help="Training default: halfp")
    parser.add_argument("--training-default-bnb", action='store_true', default=default_arguments['training-default-bnb'], help="Training default: bnb")
@@ -1103,7 +1122,7 @@ def setup_args():
     return args

-def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, prune_nonfinal_outputs, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume, training_default_halfp, training_default_bnb ):
+def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, prune_nonfinal_outputs, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume, autoregressive_model, whisper_model, whisper_cpp, training_default_halfp, training_default_bnb ):
     global args

     args.listen = listen
@@ -1123,6 +1142,11 @@ def update_args( listen, share, check_for_updates, models_from_local_only, low_v
     args.concurrency_count = concurrency_count
     args.output_sample_rate = output_sample_rate
     args.output_volume = output_volume
+    args.autoregressive_model = autoregressive_model
+    args.whisper_model = whisper_model
+    args.whisper_cpp = whisper_cpp
     args.training_default_halfp = training_default_halfp
     args.training_default_bnb = training_default_bnb
@@ -1140,8 +1164,6 @@ def save_args_settings():
         'defer-tts-load': args.defer_tts_load,
         'prune-nonfinal-outputs': args.prune_nonfinal_outputs,
         'device-override': args.device_override,
-        'whisper-model': args.whisper_model,
-        'autoregressive-model': args.autoregressive_model,
         'sample-batch-size': args.sample_batch_size,
         'embed-output-metadata': args.embed_output_metadata,
         'latents-lean-and-mean': args.latents_lean_and_mean,
@@ -1151,6 +1173,10 @@ def save_args_settings():
         'output-sample-rate': args.output_sample_rate,
         'output-volume': args.output_volume,
+        'autoregressive-model': args.autoregressive_model,
+        'whisper-model': args.whisper_model,
+        'whisper-cpp': args.whisper_cpp,
         'training-default-halfp': args.training_default_halfp,
         'training-default-bnb': args.training_default_bnb,
     }
@@ -1292,9 +1318,7 @@ def update_autoregressive_model(autoregressive_model_path):
     if not tts:
         if tts_loading:
             raise Exception("TTS is still initializing...")
-        return
+        load_tts( model=autoregressive_model_path )
+        return # redundant to proceed onward

     print(f"Loading model: {autoregressive_model_path}")
@@ -1348,7 +1372,7 @@ def unload_voicefixer():
     do_gc()

-def load_whisper_model(name=None, progress=None):
+def load_whisper_model(name=None, progress=None, language=b'en'):
     global whisper_model
     if not name:
@@ -1358,6 +1382,11 @@ def load_whisper_model(name=None, progress=None):
         save_args_settings()

     notify_progress(f"Loading Whisper model: {args.whisper_model}", progress)
-    whisper_model = whisper.load_model(args.whisper_model)
+    if args.whisper_cpp:
+        from whispercpp import Whisper
+        whisper_model = Whisper(name, models_dir='./models/', language=language)
+    else:
+        import whisper
+        whisper_model = whisper.load_model(args.whisper_model)

     print("Loaded Whisper model")
@@ -1372,10 +1401,13 @@ def unload_whisper():
     do_gc()

+"""
 def update_whisper_model(name, progress=None):
     if not name:
         return

+    args.whisper_model = name
+    save_args_settings()

     global whisper_model
     if whisper_model:
@@ -1384,3 +1416,4 @@ def update_whisper_model(name, progress=None):
     else:
         args.whisper_model = name
         save_args_settings()
+"""