@ -17,6 +17,7 @@ import urllib.request
import signal
import signal
import gc
import gc
import subprocess
import subprocess
import yaml
import tqdm
import tqdm
import torch
import torch
@ -26,6 +27,7 @@ import gradio as gr
import gradio . utils
import gradio . utils
from datetime import datetime
from datetime import datetime
from datetime import timedelta
from tortoise . api import TextToSpeech , MODELS , get_model_path
from tortoise . api import TextToSpeech , MODELS , get_model_path
from tortoise . utils . audio import load_audio , load_voice , load_voices , get_voice_dir
from tortoise . utils . audio import load_audio , load_voice , load_voices , get_voice_dir
@ -42,7 +44,7 @@ tts_loading = False
webui = None
webui = None
voicefixer = None
voicefixer = None
whisper_model = None
whisper_model = None
training_ process = None
training_ state = None
def generate (
def generate (
@ -434,100 +436,128 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
return voice
return voice
def run_training ( config_path , verbose = False , buffer_size = 8 , progress = gr . Progress ( track_tqdm = True ) ) :
# superfluous, but it cleans up some things
global training_process
class TrainingState ( ) :
def __init__ ( self , config_path , buffer_size = 8 ) :
# I don't know if this is still necessary, as it was bitching at me for not doing this, despite it being in a separate process
self . cmd = [ ' train.bat ' , config_path ] if os . name == " nt " else [ ' bash ' , ' ./train.sh ' , config_path ]
torch . multiprocessing . freeze_support ( )
unload_tts ( )
# parse config to get its iteration
unload_whisper ( )
with open ( config_path , ' r ' ) as file :
unload_voicefixer ( )
self . config = yaml . safe_load ( file )
cmd = [ ' train.bat ' , config_path ] if os . name == " nt " else [ ' bash ' , ' ./train.sh ' , config_path ]
self . it = 0
print ( " Spawning process: " , " " . join ( cmd ) )
self . its = self . config [ ' train ' ] [ ' niter ' ]
training_process = subprocess . Popen ( cmd , stdout = subprocess . PIPE , stderr = subprocess . STDOUT , universal_newlines = True )
# parse config to get its iteration
self . checkpoint = 0
import yaml
self . checkpoints = int ( self . its / self . config [ ' logger ' ] [ ' save_checkpoint_freq ' ] )
with open ( config_path , ' r ' ) as file :
config = yaml . safe_load ( file )
it = 0
self . buffer = [ ]
its = config [ ' train ' ] [ ' niter ' ]
checkpoint = 0
self . open_state = False
checkpoints = its / config [ ' logger ' ] [ ' save_checkpoint_freq ' ]
self . training_started = False
buffer_size = 8
self . info = { }
open_state = False
self . status = " "
training_started = False
yield " " . join ( cmd )
self . it_rate = " "
self . it_time_start = 0
info = { }
self . it_time_end = 0
buffer = [ ]
self . eta = " ? "
infos = [ ]
yields = True
status = " "
it_rate = " "
print ( " Spawning process: " , " " . join ( self . cmd ) )
it_time_start = 0
self . process = subprocess . Popen ( self . cmd , stdout = subprocess . PIPE , stderr = subprocess . STDOUT , universal_newlines = True )
it_time_end = 0
for line in iter ( training_process . stdout . readline , " " ) :
def parse ( self , line , verbose = False , buffer_size = 8 , progress = None ) :
buffer . append ( f ' { line } ' )
self . buffer . append ( f ' { line } ' )
# rip out iteration info
# rip out iteration info
if not training_started :
if not self . training_started :
if line . find ( ' Start training from epoch ' ) > = 0 :
if line . find ( ' Start training from epoch ' ) > = 0 :
training_started = True
self . it_time_start = time . time ( )
self . training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations
match = re . findall ( r ' iter: ([ \ d,]+) ' , line )
match = re . findall ( r ' iter: ([ \ d,]+) ' , line )
if match and len ( match ) > 0 :
if match and len ( match ) > 0 :
it = int ( match [ 0 ] . replace ( " , " , " " ) )
self . it = int ( match [ 0 ] . replace ( " , " , " " ) )
elif progress is not None :
elif progress is not None :
if line . find ( ' 0 % | ' ) == 0 :
if line . find ( ' 0 % | ' ) == 0 :
open_state = True
self . open_state = True
elif line . find ( ' 100 % | ' ) == 0 and open_state :
elif line . find ( ' 100 % | ' ) == 0 and self . open_state :
open_state = False
self . open_state = False
it = it + 1
self . it = self . it + 1
it_time_end = time . time ( )
self . it_time_end = time . time ( )
it_time_delta = it_time_end - it_time_start
self . it_time_delta = self . it_time_end - self . it_time_start
it_time_start = time . time ( )
self . it_time_start = time . time ( )
it_rate = f ' [ { " {:.3f} " . format ( it_time_delta ) } s/it] ' if it_time_delta > = 1 else f ' [ { " {:.3f} " . format ( 1 / it_time_delta ) } it/s] ' # I doubt anyone will have it/s rates, but its here
self . it_rate = f ' [ { " {:.3f} " . format ( self . it_time_delta ) } s/it] ' if self . it_time_delta > = 1 else f ' [ { " {:.3f} " . format ( 1 / self . it_time_delta ) } it/s] ' # I doubt anyone will have it/s rates, but its here
self . eta = ( self . its - self . it ) * self . it_time_delta
self . eta_hhmmss = str ( timedelta ( seconds = int ( self . eta ) ) )
progress ( it / float ( its ) , f ' [ { it } / { its } ] { it_rate } Training... { status } ' )
progress ( self . it / float ( self . its ) , f ' [ { self . it } / { self . its } ] [ETA: { self . eta_hhmmss } ] { self . it_rate } Training... { self . status } ' )
if line . find ( ' INFO: [epoch: ' ) > = 0 :
if line . find ( ' INFO: [epoch: ' ) > = 0 :
# easily rip out our stats...
# easily rip out our stats...
match = re . findall ( r ' \ b([a-z_0-9]+?) \ b: ([0-9] \ .[0-9]+?e[+-] \ d+) \ b ' , line )
match = re . findall ( r ' \ b([a-z_0-9]+?) \ b: ([0-9] \ .[0-9]+?e[+-] \ d+) \ b ' , line )
if match and len ( match ) > 0 :
if match and len ( match ) > 0 :
for k , v in match :
for k , v in match :
info [ k ] = float ( v )
self . info [ k ] = float ( v )
# ...and returns our loss rate
# ...and returns our loss rate
# it would be nice for losses to be shown at every step
# it would be nice for losses to be shown at every step
if ' loss_gpt_total ' in info :
if ' loss_gpt_total ' in self . info :
status = f " Total loss at step { int ( info [ ' step ' ] ) } : { info [ ' loss_gpt_total ' ] } "
# self.info['step'] returns the steps, not iterations, so we won't even bother ripping the reported step count, as iteration count won't get ripped from the regex
self . status = f " Total loss at iteration { self . it } : { self . info [ ' loss_gpt_total ' ] } "
elif line . find ( ' Saving models and training states ' ) > = 0 :
elif line . find ( ' Saving models and training states ' ) > = 0 :
checkpoint = checkpoint + 1
self . checkpoint = self . checkpoint + 1
progress ( checkpoint / float ( checkpoints ) , f ' [ { checkpoint } / { checkpoints } ] Saving checkpoint... ' )
progress ( self . checkpoint / float ( self . checkpoints ) , f ' [ { self . checkpoint } / { self . checkpoints } ] Saving checkpoint... ' )
print ( f " [Training] [ { datetime . now ( ) . isoformat ( ) } ] { line [ : - 1 ] } " )
if verbose or not self . training_started :
return " " . join ( self . buffer [ - buffer_size : ] )
if verbose or not training_started :
def run_training ( config_path , verbose = False , buffer_size = 8 , progress = gr . Progress ( track_tqdm = True ) ) :
yield " " . join ( buffer [ - buffer_size : ] )
global training_state
if training_state and training_state . process :
return " Training already in progress "
# I don't know if this is still necessary, as it was bitching at me for not doing this, despite it being in a separate process
torch . multiprocessing . freeze_support ( )
unload_tts ( )
unload_whisper ( )
unload_voicefixer ( )
training_state = TrainingState ( config_path = config_path , buffer_size = buffer_size )
training_process . stdout . close ( )
for line in iter ( training_state . process . stdout . readline , " " ) :
return_code = training_process . wait ( )
print ( f " [Training] [ { datetime . now ( ) . isoformat ( ) } ] { line [ : - 1 ] } " )
training_process = None
res = training_state . parse ( line = line , verbose = verbose , buffer_size = buffer_size , progress = progress )
if res :
yield res
training_state . process . stdout . close ( )
return_code = training_state . process . wait ( )
output = " " . join ( training_state . buffer [ - buffer_size : ] )
training_state = None
#if return_code:
#if return_code:
# raise subprocess.CalledProcessError(return_code, cmd)
# raise subprocess.CalledProcessError(return_code, cmd)
return " " . join ( buffer [ - buffer_size : ] )
return output
def reconnect_training ( config_path , verbose = False , buffer_size = 8 , progress = gr . Progress ( track_tqdm = True ) ) :
global training_state
if not training_state or not training_state . process :
return " Training not in progress "
for line in iter ( training_state . process . stdout . readline , " " ) :
res = training_state . parse ( line = line , verbose = verbose , buffer_size = buffer_size , progress = progress )
if res :
yield res
output = " " . join ( training_state . buffer [ - buffer_size : ] )
return output
def stop_training ( ) :
def stop_training ( ) :
global training_process
global training_process