2023-02-17 00:08:27 +00:00
import os
# Keep every model cache under ./models/ (resolved against the current working
# directory) unless the user already pointed the variable somewhere else.
for _env_name, _relative_dir in (
    ('XDG_CACHE_HOME', './models/'),
    ('TORTOISE_MODELS_DIR', './models/tortoise/'),
    ('TRANSFORMERS_CACHE', './models/transformers/'),
):
    if _env_name in os.environ:
        continue
    os.environ[_env_name] = os.path.realpath(os.path.join(os.getcwd(), _relative_dir))
del _env_name, _relative_dir
import argparse
import time
import json
import base64
import re
import urllib . request
2023-02-18 02:07:22 +00:00
import signal
2023-02-18 20:37:37 +00:00
import gc
2023-02-20 00:21:16 +00:00
import subprocess
2023-03-04 20:42:54 +00:00
import psutil
2023-02-23 06:24:54 +00:00
import yaml
2023-02-17 00:08:27 +00:00
2023-02-18 02:07:22 +00:00
import tqdm
2023-02-17 00:08:27 +00:00
import torch
import torchaudio
import music_tag
import gradio as gr
import gradio . utils
2023-02-28 01:01:50 +00:00
import pandas as pd
2023-02-17 00:08:27 +00:00
from datetime import datetime
2023-02-23 06:24:54 +00:00
from datetime import timedelta
2023-02-17 00:08:27 +00:00
2023-03-07 03:55:35 +00:00
from tortoise . api import TextToSpeech , MODELS , get_model_path , pad_or_truncate
2023-02-17 00:08:27 +00:00
from tortoise . utils . audio import load_audio , load_voice , load_voices , get_voice_dir
from tortoise . utils . text import split_and_recombine_text
from tortoise . utils . device import get_device_name , set_device_name
2023-02-17 19:06:05 +00:00
# Register the DVAE checkpoint URL in tortoise's download table.
MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"

# Selectable option lists.
WHISPER_MODELS = [
    "tiny",
    "base",
    "small",
    "medium",
    "large",
    "large-v2",
]
WHISPER_SPECIALIZED_MODELS = [
    "tiny.en",
    "base.en",
    "small.en",
    "medium.en",
]
WHISPER_BACKENDS = [
    "openai/whisper",
    "lightmare/whispercpp",
    "m-bain/whisperx",
]
VOCODERS = ['univnet', 'bigvgan_base_24khz_100band']  #, 'bigvgan_24khz_100band']

EPOCH_SCHEDULE = [9, 18, 25, 33]

# Process-wide state; each starts empty and is populated later at runtime.
args = None
tts = None
tts_loading = False
webui = None
voicefixer = None
whisper_model = None
training_state = None

current_voice = None
2023-02-17 00:08:27 +00:00
def generate(
    text,
    delimiter,
    emotion,
    prompt,
    voice,
    mic_audio,
    voice_latents_chunks,
    seed,
    candidates,
    num_autoregressive_samples,
    diffusion_iterations,
    temperature,
    diffusion_sampler,
    breathing_room,
    cvvp_weight,
    top_p,
    diffusion_temperature,
    length_penalty,
    repetition_penalty,
    cond_free_k,
    experimental_checkboxes,
    progress=None
):
    """Synthesize `text` with the selected voice and write the results to disk.

    The text is split into lines (by `delimiter`, or by
    `split_and_recombine_text` when the delimiter does not occur in the text).
    Each line is passed to `tts.tts`, every candidate is saved under
    `./results/<voice>/`, multi-line runs are additionally concatenated into a
    `_combined` file per candidate, and the optional voicefixer / pruning /
    metadata steps are applied according to the global `args`.

    A line may start with a JSON object followed by a space
    (`{"voice": "x", "temperature": 0.5} text`) to override the settings for
    that line only.

    Args:
        progress: optional progress reporter (called as `progress(0, desc=...)`
            and used through `progress.tqdm`). May be None.
        seed: 0 is treated as "no fixed seed". After the first line the seed
            reported by `tts.tts` is reused for the remaining lines.
        (all remaining parameters are forwarded to `tts.tts` through the
        settings dict built by `get_settings`, or recorded in the metadata
        built by `get_info`.)

    Returns:
        A tuple `(sample_voice, output_voices, stats)`:
        `sample_voice` is `(tts.input_sample_rate, ndarray)` when raw samples
        were loaded for the requested voice, otherwise None;
        `output_voices` is the list of final .wav paths;
        `stats` is `[[seed, "<elapsed seconds>"]]`.

    Raises:
        Exception: if the TTS model is still initializing, if `microphone` is
            selected without audio, or if a per-line JSON override is invalid.
    """
    global args
    global tts

    unload_whisper()
    unload_voicefixer()

    if not tts:
        # should check if it's loading or unloaded, and load it if it's unloaded
        if tts_loading:
            raise Exception("TTS is still initializing...")
        load_tts()
    if hasattr(tts, "loading") and tts.loading:
        raise Exception("TTS is still initializing...")

    do_gc()

    sample_voice = None

    if seed == 0:
        seed = None

    def track(iterable, desc):
        # Wrap an iterable in the progress bar only when a reporter was supplied.
        if progress is None:
            return iterable
        return progress.tqdm(iterable, desc=desc)

    def fetch_voice(voice):
        """Return (voice_samples, conditioning_latents, sample) for `voice`."""
        print(f"Loading voice: {voice}")
        sample = None
        if voice == "microphone":
            if mic_audio is None:
                raise Exception("Please provide audio from mic when choosing `microphone` as a voice input")
            voice_samples, conditioning_latents = [load_audio(mic_audio, tts.input_sample_rate)], None
        elif voice == "random":
            voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents()
        else:
            if progress is not None:
                progress(0, desc=f"Loading voice: {voice}")
            voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash)

        if voice_samples and len(voice_samples) > 0:
            if conditioning_latents is None:
                conditioning_latents = compute_latents(voice=voice, voice_samples=voice_samples, voice_latents_chunks=voice_latents_chunks)
            sample = torch.cat(voice_samples, dim=-1).squeeze().cpu()
            # the latents carry the voice from here on; drop the raw samples
            voice_samples = None

        return (voice_samples, conditioning_latents, sample)

    def get_settings(override=None):
        """Build the keyword arguments for `tts.tts`, applying per-line overrides."""
        nonlocal sample_voice

        settings = {
            'temperature': float(temperature),
            'top_p': float(top_p),
            'diffusion_temperature': float(diffusion_temperature),
            'length_penalty': float(length_penalty),
            'repetition_penalty': float(repetition_penalty),
            'cond_free_k': float(cond_free_k),

            'num_autoregressive_samples': num_autoregressive_samples,
            'sample_batch_size': args.sample_batch_size,
            'diffusion_iterations': diffusion_iterations,

            'voice_samples': None,
            'conditioning_latents': None,

            'use_deterministic_seed': seed,
            'return_deterministic_state': True,
            'k': candidates,
            'diffusion_sampler': diffusion_sampler,
            'breathing_room': breathing_room,
            'progress': progress,
            'half_p': "Half Precision" in experimental_checkboxes,
            'cond_free': "Conditioning-Free" in experimental_checkboxes,
            'cvvp_amount': cvvp_weight,
            'autoregressive_model': args.autoregressive_model,
        }

        # could be better to just do a ternary on everything above, but i am not a professional
        selected_voice = voice
        if override is not None:
            if 'voice' in override:
                selected_voice = override['voice']

            # only keys that already exist in the settings may be overridden
            for k in override:
                if k not in settings:
                    continue
                settings[k] = override[k]

        if settings['autoregressive_model'] is not None:
            if settings['autoregressive_model'] == "auto":
                settings['autoregressive_model'] = deduce_autoregressive_model(selected_voice)
            tts.load_autoregressive_model(settings['autoregressive_model'])

        settings['voice_samples'], settings['conditioning_latents'], fetched_sample = fetch_voice(voice=selected_voice)
        # Keep the sample of the voice the caller asked for so it can be
        # returned; previously it was thrown away and the result was always None.
        if fetched_sample is not None and selected_voice == voice:
            sample_voice = fetched_sample

        # clamp it down for the insane users who want this
        # it would be wiser to enforce the sample size to the batch size, but this is what the user wants
        sample_batch_size = args.sample_batch_size
        if not sample_batch_size:
            sample_batch_size = tts.autoregressive_batch_size

        if num_autoregressive_samples < sample_batch_size:
            settings['sample_batch_size'] = num_autoregressive_samples

        if settings['conditioning_latents'] is not None and len(settings['conditioning_latents']) == 2 and settings['cvvp_amount'] > 0:
            print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents.")
            settings['cvvp_amount'] = 0

        return settings

    if not delimiter:
        delimiter = "\n"
    elif delimiter == "\\n":
        delimiter = "\n"

    if delimiter and delimiter != "" and delimiter in text:
        texts = text.split(delimiter)
    else:
        texts = split_and_recombine_text(text)

    full_start_time = time.time()

    outdir = f"./results/{voice}/"
    os.makedirs(outdir, exist_ok=True)

    audio_cache = {}

    # Must exist even when no resampling is needed: it is tested further down.
    resampler = None
    if tts.output_sample_rate != args.output_sample_rate:
        resampler = torchaudio.transforms.Resample(
            tts.output_sample_rate,
            args.output_sample_rate,
            lowpass_filter_width=16,
            rolloff=0.85,
            resampling_method="kaiser_window",
            beta=8.555504641634386,
        )

    volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None

    # Continue numbering after the highest index already present in outdir.
    used_indices = set()
    for file in os.listdir(outdir):
        filename = os.path.basename(file)
        extension = os.path.splitext(filename)[1]
        if extension != ".json" and extension != ".wav":
            continue
        # voice names are user-chosen, so they must not be read as regex syntax
        match = re.findall(rf"^{re.escape(voice)}_(\d+)(?:.+?)?{re.escape(extension)}$", filename)
        if not match:
            # unrelated file that happens to live in the output directory
            continue
        used_indices.add(int(match[0]))

    idx = max(used_indices) + 1 if used_indices else 0
    idx = pad(idx, 4)

    def get_name(line=0, candidate=0, combined=False):
        """Compose the per-file suffix: index, line (or `combined`) and candidate."""
        name = f"{idx}"
        if combined:
            name = f"{name}_combined"
        elif len(texts) > 1:
            name = f"{name}_{line}"
        if candidates > 1:
            name = f"{name}_{candidate}"
        return name

    def get_info(voice, settings=None, latents=True):
        """Build the metadata dict stored next to / inside a generated file."""
        info = {
            'text': text,
            'delimiter': '\\n' if delimiter and delimiter == "\n" else delimiter,
            'emotion': emotion,
            'prompt': prompt,
            'voice': voice,
            'seed': seed,
            'candidates': candidates,
            'num_autoregressive_samples': num_autoregressive_samples,
            'diffusion_iterations': diffusion_iterations,
            'temperature': temperature,
            'diffusion_sampler': diffusion_sampler,
            'breathing_room': breathing_room,
            'cvvp_weight': cvvp_weight,
            'top_p': top_p,
            'diffusion_temperature': diffusion_temperature,
            'length_penalty': length_penalty,
            'repetition_penalty': repetition_penalty,
            'cond_free_k': cond_free_k,
            'experimentals': experimental_checkboxes,
            'time': time.time() - full_start_time,

            'datetime': datetime.now().isoformat(),
            'model': tts.autoregressive_model_path,
            'model_hash': tts.autoregressive_model_hash
        }

        if settings is not None:
            # per-line settings win over the values the UI passed in
            for k in settings:
                if k in info:
                    info[k] = settings[k]

            if 'half_p' in settings and 'cond_free' in settings:
                info['experimentals'] = []
                if settings['half_p']:
                    info['experimentals'].append("Half Precision")
                if settings['cond_free']:
                    info['experimentals'].append("Conditioning-Free")

        if latents and "latents" not in info:
            voice = info['voice']
            latents_path = f'{get_voice_dir()}/{voice}/cond_latents.pth'

            if voice == "random" or voice == "microphone":
                used_latents = settings.get('conditioning_latents') if settings is not None else None
                if used_latents is not None:
                    latents_dir = f'{get_voice_dir()}/{voice}/'
                    os.makedirs(latents_dir, exist_ok=True)
                    latents_path = f'{latents_dir}/cond_latents.pth'
                    # persist the latents that were actually used for this line
                    torch.save(used_latents, latents_path)
            else:
                if settings and "model_hash" in settings:
                    latents_path = f'{get_voice_dir()}/{voice}/cond_latents_{settings["model_hash"][:8]}.pth'
                else:
                    latents_path = f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth'

            if latents_path and os.path.exists(latents_path):
                try:
                    with open(latents_path, 'rb') as f:
                        info['latents'] = base64.b64encode(f.read()).decode("ascii")
                except OSError as e:
                    # embedding the latents is best-effort; the audio is still valid
                    print(f"Could not embed voice latents from '{latents_path}': {e}")

        return info

    for line, cut_text in enumerate(texts):
        if emotion == "Custom":
            if prompt and prompt.strip() != "":
                cut_text = f"[{prompt},] {cut_text}"
        elif emotion and emotion != "None":
            cut_text = f"[I am really {emotion.lower()},] {cut_text}"

        msg_prefix = f'[{str(line+1)}/{str(len(texts))}]'
        if progress is not None:
            progress.msg_prefix = msg_prefix
        print(f"{msg_prefix} Generating line: {cut_text}")
        start_time = time.time()

        # do setting editing
        match = re.findall(r'^(\{.+\}) (.+?)$', cut_text)
        override = None
        if match and len(match) > 0:
            match = match[0]
            try:
                override = json.loads(match[0])
                cut_text = match[1].strip()
            except Exception as e:
                raise Exception("Prompt settings editing requested, but received invalid JSON") from e

        settings = get_settings(override=override)
        gen, additionals = tts.tts(cut_text, **settings)

        # reuse the seed the model reports so the remaining lines stay consistent
        seed = additionals[0]
        run_time = time.time() - start_time
        print(f"Generating line took {run_time} seconds")

        if not isinstance(gen, list):
            gen = [gen]

        line_voice = override['voice'] if override and 'voice' in override else voice
        for j, g in enumerate(gen):
            audio = g.squeeze(0).cpu()
            name = get_name(line=line, candidate=j)

            settings['text'] = cut_text
            settings['time'] = run_time
            settings['datetime'] = datetime.now().isoformat()
            settings['model'] = tts.autoregressive_model_path
            settings['model_hash'] = tts.autoregressive_model_hash

            audio_cache[name] = {
                'audio': audio,
                'settings': get_info(voice=line_voice, settings=settings)
            }
            # save here in case some error happens mid-batch
            torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, tts.output_sample_rate)

        del gen
        do_gc()

    # Second pass: resample / adjust volume and overwrite the provisional files.
    for k in audio_cache:
        audio = audio_cache[k]['audio']

        if resampler is not None:
            audio = resampler(audio)
        if volume_adjust is not None:
            audio = volume_adjust(audio)

        audio_cache[k]['audio'] = audio
        torchaudio.save(f'{outdir}/{voice}_{k}.wav', audio, args.output_sample_rate)

    output_voices = []
    for candidate in range(candidates):
        if len(texts) > 1:
            audio_clips = []
            for line in range(len(texts)):
                name = get_name(line=line, candidate=candidate)
                audio = audio_cache[name]['audio']
                audio_clips.append(audio)

            name = get_name(candidate=candidate, combined=True)
            audio = torch.cat(audio_clips, dim=-1)
            torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, args.output_sample_rate)

            audio = audio.squeeze(0).cpu()
            audio_cache[name] = {
                'audio': audio,
                'settings': get_info(voice=voice),
                'output': True
            }
        else:
            name = get_name(candidate=candidate)
            audio_cache[name]['output'] = True

    if args.voice_fixer:
        if not voicefixer:
            if progress is not None:
                progress(0, "Loading voicefix...")
            load_voicefixer()

        fixed_cache = {}
        for name in track(audio_cache, "Running voicefix..."):
            del audio_cache[name]['audio']
            if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
                continue

            path = f'{outdir}/{voice}_{name}.wav'
            fixed = f'{outdir}/{voice}_{name}_fixed.wav'
            voicefixer.restore(
                input=path,
                output=fixed,
                cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
                #mode=mode,
            )

            fixed_cache[f'{name}_fixed'] = {
                'settings': audio_cache[name]['settings'],
                'output': True
            }
            # the fixed file replaces the original as the final output
            audio_cache[name]['output'] = False

        for name in fixed_cache:
            audio_cache[name] = fixed_cache[name]

    for name in audio_cache:
        if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
            if args.prune_nonfinal_outputs:
                audio_cache[name]['pruned'] = True
                os.remove(f'{outdir}/{voice}_{name}.wav')
            continue

        output_voices.append(f'{outdir}/{voice}_{name}.wav')

        if not args.embed_output_metadata:
            with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
                f.write(json.dumps(audio_cache[name]['settings'], indent='\t'))

    if args.embed_output_metadata:
        for name in track(audio_cache, "Embedding metadata..."):
            if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
                continue

            metadata = music_tag.load_file(f"{outdir}/{voice}_{name}.wav")
            metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
            metadata.save()

    if sample_voice is not None:
        sample_voice = (tts.input_sample_rate, sample_voice.numpy())

    info = get_info(voice=voice, latents=False)
    print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n")

    info['seed'] = seed
    if 'latents' in info:
        del info['latents']

    os.makedirs('./config/', exist_ok=True)
    with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
        f.write(json.dumps(info, indent='\t'))

    stats = [
        [seed, "{:.3f}".format(info['time'])]
    ]

    return (
        sample_voice,
        output_voices,
        stats,
    )
2023-02-20 00:21:16 +00:00
def cancel_generate():
    """Raise tortoise's module-level STOP_SIGNAL flag."""
    from tortoise import api as tortoise_api

    setattr(tortoise_api, "STOP_SIGNAL", True)
2023-02-17 20:10:27 +00:00
2023-03-02 00:46:52 +00:00
def hash_file(path, algo="md5", buffer_size=0):
    """Return the hexadecimal digest of the file at `path`.

    Args:
        path: file to hash.
        algo: "md5" or "sha1".
        buffer_size: when > 0 the file is read in chunks of this many bytes;
            otherwise it is read in one go.

    Raises:
        Exception: for an unknown algorithm or a path that does not exist.
    """
    import hashlib

    constructors = {
        "md5": hashlib.md5,
        "sha1": hashlib.sha1,
    }
    make_digest = constructors.get(algo)
    if make_digest is None:
        raise Exception(f'Unknown hash algorithm specified: {algo}')
    digest = make_digest()

    if not os.path.exists(path):
        raise Exception(f'Path not found: {path}')

    with open(path, 'rb') as handle:
        if buffer_size > 0:
            # read() returns b"" at end of file, which ends the iteration
            for chunk in iter(lambda: handle.read(buffer_size), b""):
                digest.update(chunk)
        else:
            digest.update(handle.read())

    return str(digest.hexdigest())
2023-03-03 21:13:48 +00:00
def update_baseline_for_latents_chunks(voice):
    """Suggest a latents chunk count for `voice` and remember it as the current voice.

    Returns 1 when the voice directory is missing, 0 when a training dataset
    (`./training/<voice>/train.txt`) exists, and otherwise a value derived from
    the durations of the voice's .wav files.
    """
    global current_voice
    current_voice = voice

    path = f'{get_voice_dir()}/{voice}/'
    if not os.path.isdir(path):
        return 1

    if os.path.exists(f'./training/{voice}/train.txt'):
        return 0  # 0 will leverage using the LJspeech dataset for computing latents

    durations = []
    for entry in os.listdir(path):
        if not entry.endswith(".wav"):
            continue
        meta = torchaudio.info(f'{path}/{entry}')
        durations.append(meta.num_channels * meta.num_frames / meta.sample_rate)

    clip_count = len(durations)
    total_duration = sum(durations)

    # brain too fried to figure out a better way
    chunk_seconds = args.autocalculate_voice_chunk_duration_size
    if chunk_seconds == 0:
        if clip_count > 0:
            return int(total_duration / clip_count)
        return 1
    if total_duration > 0:
        return int(total_duration / chunk_seconds)
    return 1
2023-03-07 03:55:35 +00:00
def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, progress=None):
    """Compute conditioning latents for a voice and save them when it is named.

    Args:
        voice: voice name. When given and `voice_latents_chunks` is 0, the
            clips listed in `./training/<voice>/train.txt` are used if that
            file exists.
        voice_samples: already-loaded samples. Used as-is unless the training
            dataset replaces them; loaded with `load_voice` only when missing.
        voice_latents_chunks: forwarded as `slices`; replaced by the number of
            dataset clips when the dataset is used.
        progress: forwarded to `tts.get_conditioning_latents`.

    Returns:
        The conditioning latents, or None when no samples are available.

    Raises:
        Exception: if the TTS model is still initializing.
    """
    global tts
    global args

    unload_whisper()
    unload_voicefixer()

    if not tts:
        if tts_loading:
            raise Exception("TTS is still initializing...")
        load_tts()

    if hasattr(tts, "loading") and tts.loading:
        raise Exception("TTS is still initializing...")

    if args.autoregressive_model == "auto":
        tts.load_autoregressive_model(deduce_autoregressive_model(voice))

    if voice:
        load_from_dataset = voice_latents_chunks == 0

        if load_from_dataset:
            dataset_path = f'./training/{voice}/train.txt'
            if not os.path.exists(dataset_path):
                load_from_dataset = False
            else:
                with open(dataset_path, 'r', encoding="utf-8") as f:
                    # first `|`-separated field is the clip path; blank lines
                    # would otherwise turn into a bogus file name
                    clips = [line.split("|")[0].strip() for line in f if line.strip()]

                if not clips:
                    load_from_dataset = False
                else:
                    print("Leveraging LJSpeech dataset for computing latents")

                    waveforms = [load_audio(f'./training/{voice}/{clip}', 22050) for clip in clips]
                    max_length = max(waveform.shape[-1] for waveform in waveforms)
                    voice_samples = [pad_or_truncate(waveform, max_length) for waveform in waveforms]

                    voice_latents_chunks = len(voice_samples)

        # Only hit the disk when the caller did not hand samples in; reloading
        # unconditionally discarded e.g. freshly recorded microphone audio.
        if not load_from_dataset and voice_samples is None:
            voice_samples, _ = load_voice(voice, load_latents=False)

    if voice_samples is None:
        return None

    conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, slices=voice_latents_chunks, force_cpu=args.force_cpu_for_conditioning_latents, progress=progress)

    if len(conditioning_latents) == 4:
        # drop the fourth element before saving
        conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)

    if voice:
        # without a voice name there is no directory to save into
        outfile = f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth'
        torch.save(conditioning_latents, outfile)
        print(f'Saved voice latents: {outfile}')

    return conditioning_latents
2023-02-17 19:06:05 +00:00
2023-02-23 06:24:54 +00:00
# superfluous, but it cleans up some things
class TrainingState ( ) :
2023-03-03 04:37:18 +00:00
def __init__(self, config_path, keep_x_past_datasets=0, start=True, gpus=1):
    """Read the training YAML, reset all progress counters and optionally start training.

    Args:
        config_path: path to the training configuration (YAML).
        keep_x_past_datasets: when > 0, passed to `cleanup_old` as `keep`.
        start: spawn the training process right away.
        gpus: forwarded to `spawn_process`.
    """
    # parse config to get its iteration
    with open(config_path, 'r') as config_file:
        self.config = yaml.safe_load(config_file)

    self.killed = False

    # dataset
    self.dataset_dir = f"./training/{self.config['name']}/"
    train_cfg = self.config['datasets']['train']
    self.batch_size = train_cfg['batch_size']
    self.dataset_path = train_cfg['path']
    with open(self.dataset_path, 'r', encoding="utf-8") as dataset_file:
        self.dataset_size = sum(1 for _ in dataset_file)

    # iteration / epoch / checkpoint counters
    self.it, self.its = 0, self.config['train']['niter']
    self.epoch, self.epochs = 0, int(self.its * self.batch_size / self.dataset_size)
    self.checkpoint = 0
    self.checkpoints = int(self.its / self.config['logger']['save_checkpoint_freq'])

    # output parsing state
    self.buffer = []
    self.open_state = False
    self.training_started = False
    self.info = {}

    # timing
    self.epoch_rate = ""
    self.epoch_time_start = self.epoch_time_end = 0
    self.epoch_time_deltas = 0
    self.epoch_taken = 0

    self.it_rate = ""
    self.it_time_start = self.it_time_end = 0
    self.it_time_deltas = 0
    self.it_taken = 0
    self.last_step = 0

    self.eta = "?"
    self.eta_hhmmss = "?"

    # statistics
    self.nan_detected = False
    self.last_info_check_at = 0
    self.statistics = []
    self.losses = []
    self.metrics = dict(step="", rate="", loss="")
    self.loss_milestones = [1.0, 0.15, 0.05]

    self.load_losses()

    if keep_x_past_datasets > 0:
        self.cleanup_old(keep=keep_x_past_datasets)
    if start:
        self.spawn_process(config_path=config_path, gpus=gpus)
def spawn_process(self, config_path, gpus=1):
    """Launch the platform's training script and keep the handle in `self.process`.

    On Windows (`os.name == "nt"`) `train.bat <config>` is run, otherwise
    `./train.sh <gpus> <config>`. stderr is merged into the piped stdout.
    """
    if os.name == "nt":
        command = ['train.bat', config_path]
    else:
        command = ['./train.sh', str(int(gpus)), config_path]
    self.cmd = command

    print("Spawning process: ", " ".join(self.cmd))
    self.process = subprocess.Popen(
        self.cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
    )
2023-02-20 22:56:39 +00:00
2023-03-02 00:46:52 +00:00
def load_losses ( self , update = False ) :
2023-03-01 19:32:11 +00:00
if not os . path . isdir ( f ' { self . dataset_dir } /tb_logger/ ' ) :
2023-02-28 01:01:50 +00:00
return
2023-03-01 19:32:11 +00:00
try :
from tensorboard . backend . event_processing import event_accumulator
use_tensorboard = True
except Exception as e :
use_tensorboard = False
2023-03-02 00:46:52 +00:00
keys = [ ' loss_text_ce ' , ' loss_mel_ce ' , ' loss_gpt_total ' ]
infos = { }
highest_step = self . last_info_check_at
2023-03-04 15:55:06 +00:00
if not update :
2023-03-05 06:45:07 +00:00
self . statistics = [ ]
2023-03-04 15:55:06 +00:00
2023-03-01 19:32:11 +00:00
if use_tensorboard :
logs = sorted ( [ f ' { self . dataset_dir } /tb_logger/ { d } ' for d in os . listdir ( f ' { self . dataset_dir } /tb_logger/ ' ) if d [ : 6 ] == " events " ] )
2023-03-02 00:46:52 +00:00
if update :
logs = [ logs [ - 1 ] ]
2023-03-01 19:32:11 +00:00
for log in logs :
try :
ea = event_accumulator . EventAccumulator ( log , size_guidance = { event_accumulator . SCALARS : 0 } )
ea . Reload ( )
for key in keys :
scalar = ea . Scalars ( key )
for s in scalar :
2023-03-02 00:46:52 +00:00
if update and s . step < = self . last_info_check_at :
continue
highest_step = max ( highest_step , s . step )
2023-03-05 06:45:07 +00:00
self . statistics . append ( { " step " : s . step , " value " : s . value , " type " : key } )
if key == ' loss_gpt_total ' :
self . losses . append ( { " step " : s . step , " value " : s . value , " type " : key } )
2023-03-01 19:32:11 +00:00
except Exception as e :
pass
2023-02-28 01:01:50 +00:00
2023-03-01 19:32:11 +00:00
else :
logs = sorted ( [ f ' { self . dataset_dir } / { d } ' for d in os . listdir ( self . dataset_dir ) if d [ - 4 : ] == " .log " ] )
2023-03-02 00:46:52 +00:00
if update :
logs = [ logs [ - 1 ] ]
2023-03-01 19:32:11 +00:00
for log in logs :
with open ( log , ' r ' , encoding = " utf-8 " ) as f :
lines = f . readlines ( )
for line in lines :
if line . find ( ' INFO: [epoch: ' ) > = 0 :
# easily rip out our stats...
match = re . findall ( r ' \ b([a-z_0-9]+?) \ b: +?([0-9] \ .[0-9]+?e[+-] \ d+|[ \ d,]+) \ b ' , line )
if not match or len ( match ) == 0 :
continue
info = { }
for k , v in match :
info [ k ] = float ( v . replace ( " , " , " " ) )
if ' iter ' in info :
it = info [ ' iter ' ]
infos [ it ] = info
for k in infos :
if ' loss_gpt_total ' in infos [ k ] :
2023-03-02 00:46:52 +00:00
for key in keys :
if update and int ( k ) < = self . last_info_check_at :
continue
highest_step = max ( highest_step , s . step )
2023-03-05 06:45:07 +00:00
self . statistics . append ( { " step " : int ( k ) , " value " : infos [ k ] [ key ] , " type " : key } )
if key == " loss_gpt_total " :
self . losses . append ( { " step " : int ( k ) , " value " : infos [ k ] [ key ] , " type " : key } )
2023-03-02 00:46:52 +00:00
self . last_info_check_at = highest_step
2023-02-28 01:01:50 +00:00
def cleanup_old ( self , keep = 2 ) :
if keep < = 0 :
return
2023-03-01 01:17:38 +00:00
if not os . path . isdir ( self . dataset_dir ) :
return
2023-02-28 01:01:50 +00:00
models = sorted ( [ int ( d [ : - 8 ] ) for d in os . listdir ( f ' { self . dataset_dir } /models/ ' ) if d [ - 8 : ] == " _gpt.pth " ] )
states = sorted ( [ int ( d [ : - 6 ] ) for d in os . listdir ( f ' { self . dataset_dir } /training_state/ ' ) if d [ - 6 : ] == " .state " ] )
remove_models = models [ : - 2 ]
remove_states = states [ : - 2 ]
for d in remove_models :
path = f ' { self . dataset_dir } /models/ { d } _gpt.pth '
print ( " Removing " , path )
os . remove ( path )
for d in remove_states :
path = f ' { self . dataset_dir } /training_state/ { d } .state '
print ( " Removing " , path )
os . remove ( path )
2023-03-04 17:37:08 +00:00
def parse ( self , line , verbose = False , keep_x_past_datasets = 0 , buffer_size = 8 , progress = None ) :
# Consume one line of the training process' output and update the tracked state.
#
# Parameters:
#   line                 - raw text line read from the trainer's output.
#   verbose              - when true, a result is returned for every line seen before training has started.
#   keep_x_past_datasets - forwarded to cleanup_old(keep=...) whenever a checkpoint save is reported.
#   buffer_size          - number of most recent entries retained in self.buffer.
#   progress             - optional callable, invoked as progress(percent, message).
#
# Returns a (result, percent, message) tuple; result stays None unless the line was flagged as worth returning.
#
# NOTE(review): the indentation of this listing has been lost, so any nesting implied by the
# comments below is inferred from statement order -- confirm against a properly formatted copy.
2023-02-25 16:44:25 +00:00
self . buffer . append ( f ' { line } ' )
2023-02-19 05:05:30 +00:00
2023-03-01 01:17:38 +00:00
should_return = False
2023-03-04 17:37:08 +00:00
percent = 0
message = None
2023-03-01 01:17:38 +00:00
2023-02-19 05:05:30 +00:00
# rip out iteration info
2023-02-23 06:24:54 +00:00
# before training has started, watch for the trainer's start banner and seed epoch/iteration from it
if not self . training_started :
2023-02-19 05:05:30 +00:00
if line . find ( ' Start training from epoch ' ) > = 0 :
2023-03-05 05:17:19 +00:00
self . it_time_start = time . time ( )
2023-02-23 15:31:43 +00:00
self . epoch_time_start = time . time ( )
2023-02-23 06:24:54 +00:00
self . training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations
2023-03-02 00:46:52 +00:00
should_return = True
2023-02-23 15:31:43 +00:00
match = re . findall ( r ' epoch: ([ \ d,]+) ' , line )
if match and len ( match ) > 0 :
self . epoch = int ( match [ 0 ] . replace ( " , " , " " ) )
2023-02-22 01:17:09 +00:00
match = re . findall ( r ' iter: ([ \ d,]+) ' , line )
if match and len ( match ) > 0 :
2023-02-23 06:24:54 +00:00
self . it = int ( match [ 0 ] . replace ( " , " , " " ) )
2023-02-27 19:20:06 +00:00
# remaining iterations divided by the configured checkpoint frequency
self . checkpoints = int ( ( self . its - self . it ) / self . config [ ' logger ' ] [ ' save_checkpoint_freq ' ] )
2023-02-23 23:22:23 +00:00
else :
2023-02-25 16:44:25 +00:00
lapsed = False
2023-02-20 22:56:39 +00:00
2023-02-27 19:20:06 +00:00
message = None
2023-03-05 07:37:27 +00:00
if line . find ( ' INFO: [epoch: ' ) > = 0 :
2023-03-05 20:42:45 +00:00
info_line = line . split ( " INFO: " ) [ - 1 ]
2023-03-05 07:37:27 +00:00
# to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point
2023-03-05 20:42:45 +00:00
if ' : nan ' in info_line :
2023-03-05 23:55:27 +00:00
self . nan_detected = True
2023-03-05 07:37:27 +00:00
# easily rip out our stats...
2023-03-05 23:55:27 +00:00
# every matched key/value pair is stored in self.info as a float (thousands separators removed)
match = re . findall ( r ' \ b([a-z_0-9]+?) \ b: *?([0-9] \ .[0-9]+?e[+-] \ d+|[ \ d,]+) \ b ' , info_line )
2023-03-05 07:37:27 +00:00
if match and len ( match ) > 0 :
for k , v in match :
self . info [ k ] = float ( v . replace ( " , " , " " ) )
self . load_losses ( update = True )
should_return = True
2023-03-05 20:42:45 +00:00
if ' epoch ' in self . info :
self . epoch = int ( self . info [ ' epoch ' ] )
if ' iter ' in self . info :
self . it = int ( self . info [ ' iter ' ] )
2023-03-05 07:37:27 +00:00
elif line . find ( ' Saving models and training states ' ) > = 0 :
self . checkpoint = self . checkpoint + 1
# NOTE(review): divides by self.checkpoints, which is computed with int() above and could be 0 -- confirm it cannot be
percent = self . checkpoint / float ( self . checkpoints )
message = f ' [ { self . checkpoint } / { self . checkpoints } ] Saving checkpoint... '
if progress is not None :
progress ( percent , message )
print ( f ' { " {:.3f} " . format ( percent * 100 ) } % { message } ' )
self . buffer . append ( f ' { " {:.3f} " . format ( percent * 100 ) } % { message } ' )
self . cleanup_old ( keep = keep_x_past_datasets )
2023-03-05 20:42:45 +00:00
# progress-bar style lines: "<pct>%|<bar>| <step>/<steps> [<elapsed><<until>, <rate>]"
if line . find ( ' % | ' ) > 0 :
2023-02-25 16:44:25 +00:00
match = re . findall ( r ' ( \ d+) % \ |(.+?) \ | ( \ d+| \ ?) \ /( \ d+| \ ?) \ [(.+?)<(.+?), +(.+?) \ ] ' , line )
2023-02-25 13:55:25 +00:00
if match and len ( match ) > 0 :
match = match [ 0 ]
2023-03-04 17:37:08 +00:00
# NOTE(review): per_cent, progressbar, elapsed and until are assigned but not read again in this method
per_cent = int ( match [ 0 ] ) / 100.0
2023-02-25 13:55:25 +00:00
progressbar = match [ 1 ]
# NOTE(review): the pattern above also accepts "?" for step/steps, which int() cannot convert -- verify such lines cannot reach here
step = int ( match [ 2 ] )
steps = int ( match [ 3 ] )
elapsed = match [ 4 ]
until = match [ 5 ]
rate = match [ 6 ]
2023-02-25 16:44:25 +00:00
last_step = self . last_step
self . last_step = step
if last_step < step :
self . it = self . it + ( step - last_step )
# the same final step reported twice in a row is treated as the end of an epoch
if last_step == step and step == steps :
lapsed = True
self . it_time_end = time . time ( )
self . it_time_delta = self . it_time_end - self . it_time_start
self . it_time_start = time . time ( )
2023-03-04 17:37:08 +00:00
self . it_taken = self . it_taken + 1
2023-03-05 05:17:19 +00:00
if self . it_time_delta :
try :
2023-03-05 14:19:41 +00:00
# shown as s/it when an iteration took at least one second, it/s otherwise
rate = f ' { " {:.3f} " . format ( self . it_time_delta ) } s/it ' if self . it_time_delta > = 1 or self . it_time_delta == 0 else f ' { " {:.3f} " . format ( 1 / self . it_time_delta ) } it/s '
2023-03-05 05:17:19 +00:00
self . it_rate = rate
except Exception as e :
pass
2023-03-04 17:37:08 +00:00
2023-03-05 18:53:12 +00:00
self . metrics [ ' step ' ] = [ f " { self . epoch } / { self . epochs } " ]
if self . epochs != self . its :
2023-03-05 20:30:27 +00:00
self . metrics [ ' step ' ] . append ( f " { self . it } / { self . its } " )
2023-03-05 18:53:12 +00:00
if steps > 1 :
2023-03-05 20:30:27 +00:00
self . metrics [ ' step ' ] . append ( f " { step } / { steps } " )
2023-03-05 07:37:27 +00:00
self . metrics [ ' step ' ] = " , " . join ( self . metrics [ ' step ' ] )
2023-02-25 13:55:25 +00:00
2023-02-25 16:44:25 +00:00
# epoch rollover: advance the epoch, resynchronise the iteration counter and refresh the epoch-based ETA
if lapsed :
self . epoch = self . epoch + 1
self . it = int ( self . epoch * ( self . dataset_size / self . batch_size ) )
self . epoch_time_end = time . time ( )
self . epoch_time_delta = self . epoch_time_end - self . epoch_time_start
self . epoch_time_start = time . time ( )
2023-03-05 14:19:41 +00:00
try :
self . epoch_rate = f ' { " {:.3f} " . format ( self . epoch_time_delta ) } s/epoch ' if self . epoch_time_delta > = 1 or self . epoch_time_delta == 0 else f ' { " {:.3f} " . format ( 1 / self . epoch_time_delta ) } epoch/s ' # I doubt anyone will have it/s rates, but its here
except Exception as e :
pass
2023-02-25 16:44:25 +00:00
#self.eta = (self.epochs - self.epoch) * self.epoch_time_delta
# ETA = epochs left times the mean epoch duration observed so far
self . epoch_time_deltas = self . epoch_time_deltas + self . epoch_time_delta
self . epoch_taken = self . epoch_taken + 1
self . eta = ( self . epochs - self . epoch ) * ( self . epoch_time_deltas / self . epoch_taken )
try :
eta = str ( timedelta ( seconds = int ( self . eta ) ) )
self . eta_hhmmss = eta
except Exception as e :
pass
2023-02-20 22:56:39 +00:00
2023-03-05 07:37:27 +00:00
self . metrics [ ' rate ' ] = [ ]
if self . epoch_rate :
self . metrics [ ' rate ' ] . append ( self . epoch_rate )
2023-03-05 18:53:12 +00:00
if self . it_rate and self . epoch_rate != self . it_rate :
2023-03-05 07:37:27 +00:00
self . metrics [ ' rate ' ] . append ( self . it_rate )
self . metrics [ ' rate ' ] = " , " . join ( self . metrics [ ' rate ' ] )
2023-03-05 05:17:19 +00:00
2023-03-05 07:37:27 +00:00
eta_hhmmss = " ? "
if self . eta_hhmmss :
eta_hhmmss = self . eta_hhmmss
else :
try :
# NOTE(review): self.it_time_deltas is only read here and never updated in this method -- confirm it is maintained elsewhere
eta = ( self . its - self . it ) * ( self . it_time_deltas / self . it_taken )
eta = str ( timedelta ( seconds = int ( eta ) ) )
eta_hhmmss = eta
except Exception as e :
pass
self . metrics [ ' loss ' ] = [ ]
2023-03-05 18:53:12 +00:00
if ' learning_rate_gpt_0 ' in self . info :
2023-03-05 20:30:27 +00:00
self . metrics [ ' loss ' ] . append ( f ' LR: { " {:.3e} " . format ( self . info [ " learning_rate_gpt_0 " ] ) } ' )
2023-03-05 18:53:12 +00:00
2023-03-05 07:37:27 +00:00
if len ( self . losses ) > 0 :
2023-03-05 13:24:07 +00:00
self . metrics [ ' loss ' ] . append ( f ' Loss: { " {:.3f} " . format ( self . losses [ - 1 ] [ " value " ] ) } ' )
2023-03-05 07:37:27 +00:00
if len ( self . losses ) > = 2 :
2023-03-05 14:39:24 +00:00
# """riemann sum""" but not really as this is for derivatives and not integrals
deriv = 0
accum_length = len ( self . losses ) / / 2 # i *guess* this is fine when you think about it
2023-03-05 19:58:15 +00:00
loss_value = self . losses [ - 1 ] [ " value " ]
2023-03-05 20:13:39 +00:00
2023-03-05 14:39:24 +00:00
# NOTE(review): the indices below walk entries accum_length-1 down to 0, i.e. the oldest half of self.losses;
# on the last pass the second index evaluates to -1, which selects the newest entry -- confirm this is intended
for i in range ( accum_length ) :
2023-03-05 19:58:15 +00:00
d1_loss = self . losses [ accum_length - i - 1 ] [ " value " ]
d2_loss = self . losses [ accum_length - i - 2 ] [ " value " ]
2023-03-05 14:39:24 +00:00
dloss = ( d2_loss - d1_loss )
2023-03-05 19:58:15 +00:00
d1_step = self . losses [ accum_length - i - 1 ] [ " step " ]
d2_step = self . losses [ accum_length - i - 2 ] [ " step " ]
2023-03-05 14:39:24 +00:00
dstep = ( d2_step - d1_step )
if dstep == 0 :
continue
2023-03-05 07:37:27 +00:00
inst_deriv = dloss / dstep
2023-03-05 14:39:24 +00:00
deriv + = inst_deriv
deriv = deriv / accum_length
2023-03-05 07:37:27 +00:00
2023-03-05 14:39:24 +00:00
if deriv != 0 : # dloss < 0:
2023-03-05 07:37:27 +00:00
# first entry of self.loss_milestones that the latest loss is still above
next_milestone = None
for milestone in self . loss_milestones :
2023-03-05 20:42:45 +00:00
if loss_value > milestone :
2023-03-05 07:37:27 +00:00
next_milestone = milestone
break
2023-03-05 20:42:45 +00:00
2023-03-05 07:37:27 +00:00
if next_milestone :
# tfw can do simple calculus but not basic algebra in my head
2023-03-05 19:58:15 +00:00
est_its = ( next_milestone - loss_value ) / deriv
2023-03-05 14:39:24 +00:00
if est_its > = 0 :
self . metrics [ ' loss ' ] . append ( f ' Est. milestone { next_milestone } in: { int ( est_its ) } its ' )
2023-03-05 07:37:27 +00:00
else :
2023-03-05 19:58:15 +00:00
# NOTE(review): uses inst_deriv (the last slope computed in the loop) rather than the averaged deriv -- confirm
est_loss = inst_deriv * ( self . its - self . it ) + loss_value
2023-03-05 14:39:24 +00:00
if est_loss > = 0 :
self . metrics [ ' loss ' ] . append ( f ' Est. final loss: { " {:.3f} " . format ( est_loss ) } ' )
2023-03-02 00:46:52 +00:00
2023-03-05 07:37:27 +00:00
self . metrics [ ' loss ' ] = " , " . join ( self . metrics [ ' loss ' ] )
2023-03-02 00:46:52 +00:00
2023-03-05 18:53:12 +00:00
message = f " [ { self . metrics [ ' step ' ] } ] [ { self . metrics [ ' rate ' ] } ] [ETA: { eta_hhmmss } ] \n [ { self . metrics [ ' loss ' ] } ] "
2023-03-05 23:55:27 +00:00
if self . nan_detected :
message = f " [!NaN DETECTED!] { message } "
2023-02-25 16:44:25 +00:00
2023-03-05 07:37:27 +00:00
# a non-empty message is reported with the overall iteration progress and mirrored into the buffer
if message :
percent = self . it / float ( self . its ) # self.epoch / float(self.epochs)
2023-02-23 23:22:23 +00:00
if progress is not None :
progress ( percent , message )
2023-03-05 07:37:27 +00:00
self . buffer . append ( f ' [ { " {:.3f} " . format ( percent * 100 ) } %] { message } ' )
2023-02-28 01:01:50 +00:00
2023-03-01 01:17:38 +00:00
if verbose and not self . training_started :
should_return = True
2023-02-25 16:44:25 +00:00
# keep only the newest buffer_size entries
self . buffer = self . buffer [ - buffer_size : ]
2023-03-04 17:37:08 +00:00
result = None
2023-03-01 01:17:38 +00:00
if should_return :
2023-03-04 17:37:08 +00:00
# before training starts the joined buffer is returned, afterwards only the status message
result = " " . join ( self . buffer ) if not self . training_started else message
return (
result ,
percent ,
message ,
)
2023-02-19 07:05:11 +00:00
2023-03-04 17:37:08 +00:00
def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
    """Start a training run and stream its parsed output.

    This is a generator: it yields every non-empty result produced by
    `TrainingState.parse` for the lines read from the training process.

    Args:
        config_path: path of the training configuration handed to `TrainingState`.
        verbose: forwarded to `TrainingState.parse`.
        gpus: forwarded to `TrainingState`.
        keep_x_past_datasets: forwarded to `TrainingState` and to `parse`.
        progress: progress callback, invoked as `progress(percent, message)`.

    Yields:
        Status text for the UI.
    """
    global training_state
    if training_state and training_state.process:
        # this function is a generator, so the message has to be yielded;
        # a value given to `return` would never reach the caller
        yield "Training already in progress"
        return

    # ensure we have the dvae.pth
    get_model_path('dvae.pth')

    # I don't know if this is still necessary, as it was bitching at me for not doing this, despite it being in a separate process
    torch.multiprocessing.freeze_support()

    unload_tts()
    unload_whisper()
    unload_voicefixer()

    training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets, gpus=gpus)

    for line in iter(training_state.process.stdout.readline, ""):
        # stop_training() flags the state as killed and then resets the global to None
        if training_state is None or training_state.killed:
            return

        result, percent, message = training_state.parse(line=line, verbose=verbose, keep_x_past_datasets=keep_x_past_datasets, progress=progress)
        print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
        if result:
            yield result

        if progress is not None and message:
            progress(percent, message)

    if training_state:
        training_state.process.stdout.close()
        training_state.process.wait()
        training_state = None
2023-02-18 02:07:22 +00:00
2023-03-02 01:35:12 +00:00
def update_training_dataplot(config_path=None):
    """Build a line-plot update from the collected training statistics.

    When a training state is active its losses are reloaded and plotted.
    Otherwise, if `config_path` is given, a temporary non-started
    `TrainingState` is created just to read the statistics and is discarded
    again afterwards.

    Returns:
        The `gr.LinePlot.update(...)` value, or None when there is nothing to plot.
    """
    global training_state

    def build_plot(state):
        # one place for the plot layout shared by both code paths
        return gr.LinePlot.update(
            value=pd.DataFrame(state.statistics),
            x_lim=[0, state.its],
            x="step",
            y="value",
            title="Training Metrics",
            color="type",
            tooltip=['step', 'value', 'type'],
            width=600,
            height=350,
        )

    if training_state:
        if not training_state.statistics:
            return None
        training_state.load_losses()
        return build_plot(training_state)

    if not config_path:
        return None

    # no active run: load the statistics through a throwaway state
    training_state = TrainingState(config_path=config_path, start=False)
    plot = build_plot(training_state) if training_state.statistics else None
    del training_state
    training_state = None
    return plot
2023-02-28 01:01:50 +00:00
2023-03-04 17:37:08 +00:00
def reconnect_training(verbose=False, progress=gr.Progress(track_tqdm=True)):
    """Re-attach to an already running training process and stream its output.

    This is a generator: it yields every non-empty result produced by
    `TrainingState.parse` for the lines read from the training process.

    Args:
        verbose: forwarded to `TrainingState.parse`.
        progress: progress callback, invoked as `progress(percent, message)`.

    Yields:
        Status text for the UI.
    """
    global training_state
    if not training_state or not training_state.process:
        # generator: the message must be yielded to be seen by the caller
        yield "Training not in progress"
        return

    # this function has no `keep_x_past_datasets` argument of its own; reuse the value the
    # running state was created with when it exposes one, otherwise fall back to parse()'s default
    keep_x_past_datasets = getattr(training_state, 'keep_x_past_datasets', 0)

    for line in iter(training_state.process.stdout.readline, ""):
        result, percent, message = training_state.parse(line=line, verbose=verbose, keep_x_past_datasets=keep_x_past_datasets, progress=progress)
        print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
        if result:
            yield result

        if progress is not None and message:
            progress(percent, message)
2023-02-18 02:07:22 +00:00
def stop_training():
    """Kill the running training process and any leftover `./src/train.py` processes.

    Returns:
        A status string: either that no training was running, or the
        return code of the terminated process.
    """
    global training_state
    if training_state is None or training_state.process is None:
        return "No training in progress"
    print("Killing training process...")
    training_state.killed = True

    children = []
    # wrapped in a try/catch in case for some reason this fails outside of Linux
    try:
        children = [
            p.info
            for p in psutil.process_iter(attrs=['pid', 'name', 'cmdline'])
            # cmdline is None for processes we are not allowed to inspect
            if './src/train.py' in (p.info['cmdline'] or [])
        ]
    except Exception as e:
        print(f"Could not enumerate training subprocesses: {e}")

    training_state.process.stdout.close()
    training_state.process.terminate()
    training_state.process.kill()
    return_code = training_state.process.wait()

    # SIGKILL does not exist on Windows
    kill_signal = getattr(signal, 'SIGKILL', signal.SIGTERM)
    for p in children:
        try:
            os.kill(p['pid'], kill_signal)
        except OSError as e:
            # usually the process already exited along with its parent
            print(f"Could not kill process {p['pid']}: {e}")

    training_state = None
    print("Killed training process.")
    return f"Training cancelled: {return_code}"
2023-02-17 19:06:05 +00:00
2023-02-21 21:50:05 +00:00
def get_halfp_model_path():
    """Return the file path used for the half-precision autoregressive model.

    It is the path of `autoregressive.pth` with `.pth` replaced by `_half.pth`.
    """
    return get_model_path('autoregressive.pth').replace(".pth", "_half.pth")
2023-02-21 17:35:30 +00:00
def convert_to_halfp():
    """Write a half-precision copy of the autoregressive model next to the original."""
    source_path = get_model_path('autoregressive.pth')
    print(f'Converting model to half precision: {source_path}')

    weights = torch.load(source_path)
    for name in weights:
        # replace each entry in place with its .half() version
        weights[name] = weights[name].half()

    target_path = get_halfp_model_path()
    torch.save(weights, target_path)
    print(f'Converted model to half precision: {target_path}')
2023-02-21 17:35:30 +00:00
2023-02-27 19:20:06 +00:00
def whisper_transcribe(file, language=None):
    """Transcribe `file` with the whisper backend selected in `args.whisper_backend`.

    Args:
        file: audio file handed to the backend's `transcribe`.
        language: language hint; only passed on to the "openai/whisper" backend.

    Returns:
        The backend's result, normalised to carry a `segments` list whose
        entries have `start`, `end` and `text`. Returns None when
        `args.whisper_backend` is not one of the handled backends.
    """
    # shouldn't happen, but it's for safety
    if not whisper_model:
        load_whisper_model(language=language)

    if args.whisper_backend == "openai/whisper":
        if not language:
            language = None

        return whisper_model.transcribe(file, language=language)

    if args.whisper_backend == "lightmare/whispercpp":
        res = whisper_model.transcribe(file)
        segments = whisper_model.extract_text_and_timestamps(res)

        # reshape (start, end, text) tuples into the dict layout the callers consume;
        # the timestamps are divided by 100 to match the other backends
        return {
            'segments': [
                {
                    'start': segment[0] / 100.0,
                    'end': segment[1] / 100.0,
                    'text': segment[2],
                }
                for segment in segments
            ]
        }

    # credit to https://git.ecker.tech/yqxtqymn for the busywork of getting this added
    if args.whisper_backend == "m-bain/whisperx":
        import whisperx

        device = "cuda" if get_device_name() == "cuda" else "cpu"
        result = whisper_model.transcribe(file)
        model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
        result_aligned = whisperx.align(result["segments"], model_a, metadata, file, device)

        for segment in result_aligned['segments']:
            # drop the per-word / per-character alignment data; tolerate it being absent
            segment.pop('word-segments', None)
            segment.pop('char-segments', None)

        result['segments'] = result_aligned['segments']
        return result

    # unknown backend: nothing to return
    return None
2023-02-27 19:20:06 +00:00
2023-03-06 10:47:06 +00:00
def prepare_dataset(files, outdir, language=None, skip_existings=False, progress=None):
    """Transcribe voice files and slice them into a training dataset.

    Every file is transcribed, cut into one WAV per transcribed segment inside
    `outdir`, and listed in `outdir/train.txt` as `<slice name>|<text>`. The raw
    transcription results are dumped to `outdir/whisper.json`.

    Args:
        files: paths of the source audio files.
        outdir: directory receiving the slices, `train.txt` and `whisper.json`.
        language: language hint forwarded to the transcription.
        skip_existings: when true, files that already have slices listed in an
            existing `train.txt` are skipped and `train.txt` is only appended to.
        progress: progress callback forwarded to `enumerate_progress`.

    Returns:
        A summary string containing the newly written `train.txt` lines.
    """
    unload_tts()

    global whisper_model
    if whisper_model is None:
        load_whisper_model(language=language)

    os.makedirs(outdir, exist_ok=True)

    train_path = f'{outdir}/train.txt'
    results = {}
    transcription = []
    files = sorted(files)

    # source file names that already have slices listed in train.txt
    previous_files = set()
    if skip_existings and os.path.exists(train_path):
        with open(train_path, 'r', encoding="utf-8") as f:
            parsed_list = f.readlines()

        for line in parsed_list:
            match = re.findall(r"^(.+?)_\d+\.wav$", line.split("|")[0])
            if not match:
                continue
            previous_files.add(f'{match[0]}.wav')

    for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
        basename = os.path.basename(file)
        if basename in previous_files:
            print(f"Skipping already parsed file: {basename}")
            continue

        result = whisper_transcribe(file, language=language)
        results[basename] = result
        print(f"Transcribed file: {file}, {len(result['segments'])} found.")

        waveform, sampling_rate = torchaudio.load(file)

        idx = 0
        for segment in result['segments']:  # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
            start = int(segment['start'] * sampling_rate)
            end = int(segment['end'] * sampling_rate)

            sliced_waveform = waveform[:, start:end]
            sliced_name = basename.replace(".wav", f"_{pad(idx, 4)}.wav")

            # a slice without a single negative sample is treated as broken and skipped
            if not torch.any(sliced_waveform < 0):
                print(f"Error with {sliced_name}, skipping...")
                continue

            torchaudio.save(f"{outdir}/{sliced_name}", sliced_waveform, sampling_rate)

            idx = idx + 1
            line = f"{sliced_name}|{segment['text'].strip()}"
            transcription.append(line)

            # written right away so progress survives an interruption; only put a newline
            # in front when there is earlier content, so a fresh file has no blank first line
            has_content = os.path.exists(train_path) and os.path.getsize(train_path) > 0
            with open(train_path, 'a', encoding="utf-8") as f:
                f.write(f'\n{line}' if has_content else line)

        do_gc()

    # NOTE(review): when files were skipped, this only contains the newly transcribed ones
    with open(f'{outdir}/whisper.json', 'w', encoding="utf-8") as f:
        f.write(json.dumps(results, indent='\t'))

    unload_whisper()

    joined = "\n".join(transcription)
    if not skip_existings:
        # a full run replaces whatever train.txt held before
        with open(train_path, 'w', encoding="utf-8") as f:
            f.write(joined)

    return f"Processed dataset to: {outdir}\n{joined}"
2023-02-18 20:37:37 +00:00
2023-02-19 20:22:03 +00:00
def calc_iterations(epochs, lines, batch_size):
    """Return the number of training iterations needed for `epochs` passes.

    That is `epochs * lines / batch_size`, truncated to an int.
    """
    total_samples = epochs * lines
    return int(total_samples / float(batch_size))
2023-03-04 04:41:56 +00:00
def schedule_learning_rate(iterations, schedule=EPOCH_SCHEDULE):
    """Scale every entry of `schedule` by `iterations` and truncate it to an int."""
    steps = []
    for factor in schedule:
        steps.append(int(iterations * factor))
    return steps
2023-02-19 20:22:03 +00:00
2023-03-05 05:17:19 +00:00
def optimize_training_settings(epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, resume_path, half_p, bnb, workers, source_model, voice):
    """Validate and adjust training settings against the size of the voice's dataset.

    The dataset size is the number of lines in `./training/<voice>/train.txt`.

    Adjustments made:
      * batch size is clamped to the dataset size and, when it does not divide
        the dataset evenly, lowered to the dataset size split over one more step;
      * gradient accumulation size is kept to at most half of the batch size and
        adjusted when it does not divide the batch size;
      * print/save rates are clamped to the number of epochs;
      * a resume path that does not exist is dropped;
      * half precision is disabled when BitsAndBytes is requested too, otherwise
        the half-precision model is created when missing.

    Returns:
        A tuple `(learning_rate, text_ce_lr_weight, learning_rate_schedule,
        batch_size, gradient_accumulation_size, print_rate, save_rate,
        resume_path, messages)` where `messages` lists what was changed.
    """
    dataset_path = f"./training/{voice}/train.txt"

    with open(dataset_path, 'r', encoding="utf-8") as f:
        lines = len(f.readlines())

    messages = []

    if batch_size > lines:
        batch_size = lines
        messages.append(f"Batch size is larger than your dataset, clamping batch size to: {batch_size}")

    # the dataset size has to be divided by the batch size, not the other way around;
    # otherwise evenly dividing batch sizes are needlessly lowered
    if lines % batch_size != 0:
        nearest_slice = int(lines / batch_size) + 1
        batch_size = int(lines / nearest_slice)
        messages.append(f"Batch size not neatly divisible by dataset size, adjusting batch size to: {batch_size} ({nearest_slice} steps per epoch)")

    if gradient_accumulation_size == 0:
        gradient_accumulation_size = 1

    if batch_size / gradient_accumulation_size < 2:
        gradient_accumulation_size = int(batch_size / 2)
        if gradient_accumulation_size == 0:
            gradient_accumulation_size = 1

        messages.append(f"Gradient accumulation size is too large for a given batch size, clamping gradient accumulation size to: {gradient_accumulation_size}")
    elif batch_size % gradient_accumulation_size != 0:
        gradient_accumulation_size = int(batch_size / gradient_accumulation_size)
        if gradient_accumulation_size == 0:
            gradient_accumulation_size = 1

        messages.append(f"Batch size is not evenly divisible by the gradient accumulation size, adjusting gradient accumulation size to: {gradient_accumulation_size}")

    iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size)

    if epochs < print_rate:
        print_rate = epochs
        messages.append(f"Print rate is too small for the given iteration step, clamping print rate to: {print_rate}")

    if epochs < save_rate:
        save_rate = epochs
        messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {save_rate}")

    if resume_path and not os.path.exists(resume_path):
        resume_path = None
        messages.append("Resume path specified, but does not exist. Disabling...")

    if bnb:
        messages.append("BitsAndBytes requested. Please note this is ! EXPERIMENTAL !")

    if half_p:
        if bnb:
            half_p = False
            messages.append("Half Precision requested, but BitsAndBytes is also requested. Due to redundancies, disabling half precision...")
        else:
            messages.append("Half Precision requested. Please note this is ! EXPERIMENTAL !")
            if not os.path.exists(get_halfp_model_path()):
                convert_to_halfp()

    # guard the summary against a zero epoch count
    steps_per_epoch = int(iterations / epochs) if epochs else 0
    messages.append(f"For {epochs} epochs with {lines} lines in batches of {batch_size}, iterating for {iterations} steps ({steps_per_epoch} steps per epoch)")

    return (
        learning_rate,
        text_ce_lr_weight,
        learning_rate_schedule,
        batch_size,
        gradient_accumulation_size,
        print_rate,
        save_rate,
        resume_path,
        messages
    )


def calc_iterations(epochs, lines, batch_size):
    iterations = int(epochs * lines / float(batch_size))
    return iterations
2023-03-05 05:17:19 +00:00
def save_training_settings(iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None, batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None, bnb=None, workers=None, source_model=None):
    """Render `./models/.template.yaml` with the given settings into `./training/`.

    Every `${key}` placeholder in the template is substituted with the matching
    setting; settings that were not supplied fall back to built-in defaults.
    When `resume_path` is given the pretrained-model line is commented out,
    otherwise the resume-state line is. With `half_p` set, the half-precision
    model is created first when it does not exist yet.

    Returns:
        A message naming the written file (`./training/<output_name>`, where
        `output_name` defaults to `<name>.yaml`).
    """
    if not source_model:
        source_model = f"./models/tortoise/autoregressive{'_half' if half_p else ''}.pth"

    settings = {
        "iterations": iterations or 500,
        "batch_size": batch_size or 64,
        "learning_rate": learning_rate or 1e-5,
        "gen_lr_steps": learning_rate_schedule or EPOCH_SCHEDULE,
        "gradient_accumulation_size": gradient_accumulation_size or 4,
        "print_rate": print_rate or 1,
        "save_rate": save_rate or 50,
        "name": name or "finetune",
        "dataset_name": dataset_name or "finetune",
        "dataset_path": dataset_path or "./training/finetune/train.txt",
        "validation_name": validation_name or "finetune",
        "validation_path": validation_path or "./training/finetune/train.txt",

        "text_ce_lr_weight": text_ce_lr_weight or 0.01,

        'resume_state': f"resume_state: '{resume_path}'",
        'pretrain_model_gpt': f"pretrain_model_gpt: '{source_model}'",

        'float16': 'true' if half_p else 'false',
        'bitsandbytes': 'true' if bnb else 'false',

        'workers': workers or 2,
    }

    # only one of "resume from a state" / "start from a pretrained model" stays active
    if resume_path:
        settings['pretrain_model_gpt'] = f"# {settings['pretrain_model_gpt']}"
    else:
        settings['resume_state'] = f"# resume_state: './training/{name if name else 'finetune'}/training_state/#.state'"

    if half_p and not os.path.exists(get_halfp_model_path()):
        convert_to_halfp()

    if not output_name:
        output_name = f'{settings["name"]}.yaml'

    with open('./models/.template.yaml', 'r', encoding="utf-8") as f:
        rendered = f.read()

    # i could just load and edit the YAML directly, but this is easier, as I don't need to bother with path traversals
    for key, value in settings.items():
        if value is None:
            continue
        rendered = rendered.replace(f"${{{key}}}", str(value))

    outfile = f'./training/{output_name}'
    with open(outfile, 'w', encoding="utf-8") as f:
        f.write(rendered)

    return f"Training settings saved to: {outfile}"
def import_voices(files, saveAs=None, progress=None):
    """Import voice samples or conditioning latents into a voice's folder.

    For each file, embedded generation settings are read first. A file that
    carries latents has them written to `cond_latents.pth`; any other file must
    be a `.wav` and is copied into the voice folder, optionally resampled to
    44.1 kHz and run through voicefixer when `args.voice_fixer` is set.

    Args:
        files: a single uploaded file or a list of them.
        saveAs: name of the voice to import into; when omitted it is taken from
            the `voice` entry of the file's embedded settings.
        progress: progress callback forwarded to `enumerate_progress`.

    Raises:
        Exception: when no voice name can be determined, or a sample is not a WAV.
    """
    global args

    if not isinstance(files, list):
        files = [files]

    for file in enumerate_progress(files, desc="Importing voice files", progress=progress):
        j, latents = read_generate_settings(file, read_latents=True)

        if j is not None and saveAs is None:
            saveAs = j['voice']
        if saveAs is None or saveAs == "":
            raise Exception("Specify a voice name")

        outdir = f'{get_voice_dir()}/{saveAs}'
        os.makedirs(outdir, exist_ok=True)

        if latents:
            latents_path = f'{outdir}/cond_latents.pth'
            # log the destination, not the raw latent bytes
            print(f"Importing latents to {latents_path}")
            with open(latents_path, 'wb') as f:
                f.write(latents)
            print(f"Imported latents to {latents_path}")
        else:
            filename = file.name
            if not filename.endswith(".wav"):
                raise Exception("Please convert to a WAV first")

            path = f"{outdir}/{os.path.basename(filename)}"
            print(f"Importing voice to {path}")

            waveform, sampling_rate = torchaudio.load(filename)

            if args.voice_fixer:
                if not voicefixer:
                    load_voicefixer()

                # resample to best bandwidth since voicefixer will do it anyways through librosa
                if sampling_rate != 44100:
                    print(f"Resampling imported voice sample: {path}")
                    resampler = torchaudio.transforms.Resample(
                        sampling_rate,
                        44100,
                        lowpass_filter_width=16,
                        rolloff=0.85,
                        resampling_method="kaiser_window",
                        beta=8.555504641634386,
                    )
                    waveform = resampler(waveform)
                    sampling_rate = 44100

                torchaudio.save(path, waveform, sampling_rate)

                print(f"Running 'voicefixer' on voice sample: {path}")
                voicefixer.restore(
                    input=path,
                    output=path,
                    cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
                    #mode=mode,
                )
            else:
                torchaudio.save(path, waveform, sampling_rate)

            print(f"Imported voice to {path}")
2023-02-17 03:05:27 +00:00
2023-02-21 21:50:05 +00:00
def get_voice_list(dir=get_voice_dir(), append_defaults=False):
    """List the names of the non-empty voice folders inside ``dir``.

    ``dir`` is created if it does not exist. Only sub-directories that
    contain at least one entry are reported, sorted by name. When
    ``append_defaults`` is true the pseudo-voices "random" and "microphone"
    are added after the sorted names.
    """
    os.makedirs(dir, exist_ok=True)

    voices = []
    for entry in os.listdir(dir):
        candidate = os.path.join(dir, entry)
        if not os.path.isdir(candidate):
            continue
        # An empty folder holds no samples, so it is not offered as a voice.
        if len(os.listdir(candidate)) == 0:
            continue
        voices.append(entry)
    voices.sort()

    if append_defaults:
        voices.extend(["random", "microphone"])
    return voices
2023-02-17 03:05:27 +00:00
2023-02-28 15:36:06 +00:00
def get_autoregressive_models(dir="./models/finetunes/", prefixed=False):
    """Collect the autoregressive model choices known to the application.

    The result is, in order: the optional "auto" sentinel, the stock model
    returned by ``get_model_path('autoregressive.pth')``, the half-precision
    model if its file exists, every ``*.pth`` file in ``dir`` (created if
    missing), and every ``<step>_gpt.pth`` checkpoint found under
    ``./training/<name>/models/`` ordered by numeric step.

    "auto" is only offered when at least one finetune or checkpoint exists.

    When ``prefixed`` is true, every real path is rendered as
    ``[<first 8 chars of hash_file(path)>] <path>``. The "auto" sentinel is
    not a file and is therefore left unprefixed.
    """
    os.makedirs(dir, exist_ok=True)

    base = [get_model_path('autoregressive.pth')]
    halfp = get_halfp_model_path()
    if os.path.exists(halfp):
        base.append(halfp)

    additionals = sorted(f'{dir}/{d}' for d in os.listdir(dir) if d[-4:] == ".pth")

    found = []
    # A fresh checkout may not have a training folder yet; treat it as empty
    # instead of letting os.listdir raise.
    trainings = os.listdir('./training/') if os.path.isdir('./training/') else []
    for training in trainings:
        models_dir = f'./training/{training}/models/'
        if not os.path.isdir(models_dir):
            continue
        # Checkpoints are named "<step>_gpt.pth"; sort numerically by step and
        # ignore files whose prefix is not a number.
        steps = sorted(
            int(d[:-8])
            for d in os.listdir(models_dir)
            if d[-8:] == "_gpt.pth" and d[:-8].isdigit()
        )
        found += [f'./training/{training}/models/{step}_gpt.pth' for step in steps]

    if len(found) > 0 or len(additionals) > 0:
        base = ["auto"] + base

    res = base + additionals + found

    if prefixed:
        # "auto" is a selection keyword rather than a file on disk, so it
        # cannot be hashed.
        res = [
            path if path == "auto" else f'[{hash_file(path)[:8]}] {path}'
            for path in res
        ]

    return res
2023-02-20 00:21:16 +00:00
def get_dataset_list(dir="./training/"):
    """Return the sorted names of sub-directories of ``dir`` that contain a ``train.txt``."""
    datasets = []
    for entry in os.listdir(dir):
        folder = os.path.join(dir, entry)
        if not os.path.isdir(folder):
            continue
        contents = os.listdir(folder)
        if len(contents) > 0 and "train.txt" in contents:
            datasets.append(entry)
    return sorted(datasets)
def get_training_list(dir="./training/"):
    """Return the sorted paths of every ``<dir>/<name>/train.yaml`` that exists.

    Only sub-directories of ``dir`` that contain a ``train.yaml`` entry are
    reported. The returned paths are rooted at ``dir`` itself; with the
    default argument they have the form ``./training/<name>/train.yaml``.
    """
    # Build results from the directory that was actually scanned rather than
    # a hard-coded "./training/" prefix. Trailing separators are stripped so
    # the default still yields "./training/<name>/train.yaml".
    root = dir.rstrip("/\\")

    configs = []
    for entry in os.listdir(dir):
        folder = os.path.join(dir, entry)
        if os.path.isdir(folder) and "train.yaml" in os.listdir(folder):
            configs.append(f'{root}/{entry}/train.yaml')
    return sorted(configs)
def do_gc():
    """Run Python's garbage collector, then ask torch to empty its CUDA cache.

    Emptying the CUDA cache is best-effort: any exception it raises is
    swallowed.
    """
    gc.collect()

    try:
        torch.cuda.empty_cache()
    except Exception:
        # Deliberately ignored; cleanup must never abort the caller.
        return
def pad(num, zeroes):
    """Render ``num`` as text, left-padded with zeros to ``zeroes + 1`` characters."""
    width = zeroes + 1
    text = str(num)
    return text.zfill(width)
2023-02-17 03:05:27 +00:00
def curl(url):
    """Fetch ``url`` and return its body parsed as JSON.

    The request is sent with a ``User-Agent: Python`` header. Any failure
    (connection error, undecodable body, invalid JSON, ...) is printed and
    reported to the caller as ``None`` instead of being raised.
    """
    try:
        req = urllib.request.Request(url, headers={'User-Agent': 'Python'})
        # The context manager closes the connection even when reading or
        # parsing fails part-way through.
        with urllib.request.urlopen(req) as conn:
            data = conn.read()
        return json.loads(data.decode())
    except Exception as e:
        print(e)
        return None
def check_for_updates():
    """Compare the commit recorded in ``./.git/FETCH_HEAD`` with the remote's.

    The local commit id, host, owner and repository name are parsed from the
    start of FETCH_HEAD. The remote is then queried through
    ``/api/v1/repos/<owner>/<repo>/branches/`` and the commit id of the first
    returned branch is compared with the local one.

    Returns True when the two ids differ, False when they match or when the
    check cannot be carried out (each such case prints a reason).
    """
    fetch_head = './.git/FETCH_HEAD'
    if not os.path.isfile(fetch_head):
        print("Cannot check for updates: not from a git repo")
        return False

    with open(fetch_head, 'r', encoding="utf-8") as f:
        head = f.read()

    matches = re.findall(r"^([a-f0-9]+).+?https:\/\/(.+?)\/(.+?)\/(.+?)\n", head)
    if not matches:
        print("Cannot check for updates: cannot parse FETCH_HEAD")
        return False

    local, host, owner, repo = matches[0]

    # this only works for gitea instances
    res = curl(f"https://{host}/api/v1/repos/{owner}/{repo}/branches/")
    if res is None or len(res) == 0:
        print("Cannot check for updates: cannot fetch from remote")
        return False

    remote = res[0]["commit"]["id"]
    if remote == local:
        return False

    print(f"New version found: {local[:8]} => {remote[:8]}")
    return True
2023-02-20 00:21:16 +00:00
def enumerate_progress(iterable, desc=None, progress=None, verbose=None):
    """Wrap ``iterable`` in a progress-reporting iterator.

    If ``verbose`` is truthy and ``desc`` is given, ``desc`` is printed first.
    Without a ``progress`` object the iterable is wrapped in a console tqdm
    bar, disabled unless ``verbose`` is truthy. Otherwise ``progress.tqdm`` is
    used with ``track_tqdm=True``, and ``desc`` is prefixed with
    ``progress.msg_prefix`` when that attribute exists.
    """
    if verbose and desc is not None:
        print(desc)

    if progress is None:
        # This file does `import tqdm`, which binds the *module*; calling it
        # directly raises TypeError. Resolve the progress-bar class from the
        # module, while still working if the name is ever bound to the class.
        progress_bar = getattr(tqdm, "tqdm", tqdm)
        return progress_bar(iterable, disable=not verbose)

    if hasattr(progress, 'msg_prefix'):
        desc = f'{progress.msg_prefix} {desc}'
    return progress.tqdm(iterable, desc=desc, track_tqdm=True)
2023-02-17 03:05:27 +00:00
2023-02-20 00:21:16 +00:00
def notify_progress(message, progress=None, verbose=True):
    """Print ``message`` when ``verbose`` is set and forward it to ``progress`` if one is given.

    ``progress`` is invoked as ``progress(0, desc=message)``.
    """
    if verbose:
        print(message)

    if progress is not None:
        progress(0, desc=message)
2023-02-18 14:10:26 +00:00
2023-02-20 00:21:16 +00:00
def get_args():
    """Return the module-level ``args`` object."""
    # Reading a module global needs no `global` declaration.
    return args
2023-02-18 14:10:26 +00:00
2023-02-20 00:21:16 +00:00
def setup_args():
    """Build the module-level ``args`` namespace and return it.

    Values are layered in this order: the built-in defaults below, then any
    overrides found in ``./config/exec.json``, then the command line.

    After parsing, the function also:
      * derives ``args.embed_output_metadata`` from ``--no-embed-output-metadata``;
      * forwards ``args.device_override`` to ``set_device_name`` when one is set;
      * splits ``args.listen`` into ``listen_host`` / ``listen_port`` /
        ``listen_path`` (a port of 0 is stored as ``None``).
    """
    global args

    default_arguments = {
        'share': False,
        'listen': None,
        'check-for-updates': False,
        'models-from-local-only': False,
        'low-vram': False,
        'sample-batch-size': None,
        'embed-output-metadata': True,
        'latents-lean-and-mean': True,
        'voice-fixer': False, # getting tired of long initialization times in a Colab for downloading a large dataset for it
        'voice-fixer-use-cuda': True,
        'force-cpu-for-conditioning-latents': False,
        'defer-tts-load': False,
        'device-override': None,
        'prune-nonfinal-outputs': True,
        'vocoder-model': VOCODERS[-1],
        'concurrency-count': 2,
        'autocalculate-voice-chunk-duration-size': 0,
        'output-sample-rate': 44100,
        'output-volume': 1,

        'autoregressive-model': None,
        'whisper-backend': 'openai/whisper',
        'whisper-model': "base",

        'training-default-halfp': False,
        'training-default-bnb': True,
    }

    if os.path.isfile('./config/exec.json'):
        with open('./config/exec.json', 'r', encoding="utf-8") as f:
            try:
                overrides = json.load(f)
                for k in overrides:
                    default_arguments[k] = overrides[k]
            except Exception as e:
                # A broken settings file must not prevent startup; report it
                # and continue with whatever defaults were gathered so far.
                print(e)

    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action='store_true', default=default_arguments['share'], help="Lets Gradio return a public URL to use anywhere")
    parser.add_argument("--listen", default=default_arguments['listen'], help="Path for Gradio to listen on")
    parser.add_argument("--check-for-updates", action='store_true', default=default_arguments['check-for-updates'], help="Checks for update on startup")
    parser.add_argument("--models-from-local-only", action='store_true', default=default_arguments['models-from-local-only'], help="Only loads models from disk, does not check for updates for models")
    parser.add_argument("--low-vram", action='store_true', default=default_arguments['low-vram'], help="Disables some optimizations that increases VRAM usage")
    # store_true: passing the flag must turn the "no embed" switch ON. With
    # store_false the flag either did nothing or re-enabled embedding.
    parser.add_argument("--no-embed-output-metadata", action='store_true', default=not default_arguments['embed-output-metadata'], help="Disables embedding output metadata into resulting WAV files for easily fetching its settings used with the web UI (data is stored in the lyrics metadata tag)")
    parser.add_argument("--latents-lean-and-mean", action='store_true', default=default_arguments['latents-lean-and-mean'], help="Exports the bare essentials for latents.")
    parser.add_argument("--voice-fixer", action='store_true', default=default_arguments['voice-fixer'], help="Uses python module 'voicefixer' to improve audio quality, if available.")
    parser.add_argument("--voice-fixer-use-cuda", action='store_true', default=default_arguments['voice-fixer-use-cuda'], help="Hints to voicefixer to use CUDA, if available.")
    parser.add_argument("--force-cpu-for-conditioning-latents", default=default_arguments['force-cpu-for-conditioning-latents'], action='store_true', help="Forces computing conditional latents to be done on the CPU (if you constantyl OOM on low chunk counts)")
    parser.add_argument("--defer-tts-load", default=default_arguments['defer-tts-load'], action='store_true', help="Defers loading TTS model")
    parser.add_argument("--prune-nonfinal-outputs", default=default_arguments['prune-nonfinal-outputs'], action='store_true', help="Deletes non-final output files on completing a generation")
    # Value-taking option: a store_true action would replace the model name
    # with the boolean True.
    parser.add_argument("--vocoder-model", default=default_arguments['vocoder-model'], help="Specifies with vocoder to use")
    parser.add_argument("--device-override", default=default_arguments['device-override'], help="A device string to override pass through Torch")
    parser.add_argument("--sample-batch-size", default=default_arguments['sample-batch-size'], type=int, help="Sets how many batches to use during the autoregressive samples pass")
    parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once")
    parser.add_argument("--autocalculate-voice-chunk-duration-size", type=float, default=default_arguments['autocalculate-voice-chunk-duration-size'], help="Number of seconds to suggest voice chunk size for (for example, 100 seconds of audio at 10 seconds per chunk will suggest 10 chunks)")
    parser.add_argument("--output-sample-rate", type=int, default=default_arguments['output-sample-rate'], help="Sample rate to resample the output to (from 24KHz)")
    parser.add_argument("--output-volume", type=float, default=default_arguments['output-volume'], help="Adjusts volume of output")

    parser.add_argument("--autoregressive-model", default=default_arguments['autoregressive-model'], help="Specifies which autoregressive model to use for sampling.")
    # Value-taking option as well (see --vocoder-model).
    parser.add_argument("--whisper-backend", default=default_arguments['whisper-backend'], help="Picks which whisper backend to use (openai/whisper, lightmare/whispercpp, m-bain/whisperx)")
    parser.add_argument("--whisper-model", default=default_arguments['whisper-model'], help="Specifies which whisper model to use for transcription.")

    parser.add_argument("--training-default-halfp", action='store_true', default=default_arguments['training-default-halfp'], help="Training default: halfp")
    parser.add_argument("--training-default-bnb", action='store_true', default=default_arguments['training-default-bnb'], help="Training default: bnb")

    parser.add_argument("--os", default="unix", help="Specifies which OS, easily")
    args = parser.parse_args()

    args.embed_output_metadata = not args.no_embed_output_metadata

    # Only push an override when one was actually supplied (the previous
    # condition was inverted, so a configured override was never applied).
    if args.device_override:
        set_device_name(args.device_override)

    args.listen_host = None
    args.listen_port = None
    args.listen_path = None
    if args.listen:
        try:
            # Accepted shape: "[host:port][/path]".
            match = re.findall(r"^(?:(.+?):(\d+))?(\/.*?)?$", args.listen)[0]

            args.listen_host = match[0] if match[0] != "" else "127.0.0.1"
            args.listen_port = match[1] if match[1] != "" else None
            args.listen_path = match[2] if match[2] != "" else "/"
        except (IndexError, TypeError):
            # Unparseable or non-string value: leave host/port/path as None.
            pass

    if args.listen_port is not None:
        args.listen_port = int(args.listen_port)
        if args.listen_port == 0:
            args.listen_port = None

    return args
2023-02-18 14:10:26 +00:00
2023-03-07 02:45:22 +00:00
def update_args(listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, prune_nonfinal_outputs, device_override, sample_batch_size, concurrency_count, autocalculate_voice_chunk_duration_size, output_volume, autoregressive_model, vocoder_model, whisper_backend, whisper_model, training_default_halfp, training_default_bnb):
    """Copy the given values onto the module-level ``args`` and persist them.

    Each parameter is stored on the ``args`` attribute of the same name, then
    ``save_args_settings()`` writes the result to disk. The output sample
    rate is not a parameter of this function; it is pinned to a fixed value.
    """
    global args

    args.listen = listen
    args.share = share
    args.check_for_updates = check_for_updates
    args.models_from_local_only = models_from_local_only
    args.low_vram = low_vram
    args.force_cpu_for_conditioning_latents = force_cpu_for_conditioning_latents
    args.defer_tts_load = defer_tts_load
    args.prune_nonfinal_outputs = prune_nonfinal_outputs
    args.device_override = device_override
    args.sample_batch_size = sample_batch_size
    args.embed_output_metadata = embed_output_metadata
    args.latents_lean_and_mean = latents_lean_and_mean
    args.voice_fixer = voice_fixer
    args.voice_fixer_use_cuda = voice_fixer_use_cuda
    args.concurrency_count = concurrency_count
    # 44100 Hz matches the 'output-sample-rate' default used by setup_args();
    # the previous literal 44000 was inconsistent with it and is not a
    # standard audio rate.
    # NOTE(review): confirm 44000 was a typo rather than intentional.
    args.output_sample_rate = 44100
    args.autocalculate_voice_chunk_duration_size = autocalculate_voice_chunk_duration_size
    args.output_volume = output_volume

    args.autoregressive_model = autoregressive_model
    args.vocoder_model = vocoder_model
    args.whisper_backend = whisper_backend
    args.whisper_model = whisper_model

    args.training_default_halfp = training_default_halfp
    args.training_default_bnb = training_default_bnb

    save_args_settings()
def save_args_settings():
    """Write the persistent subset of ``args`` to ``./config/exec.json``.

    Keys use the dashed command-line spelling; each value is read from the
    ``args`` attribute with the same name spelled with underscores. A falsy
    ``args.listen`` is stored as ``None``. The ``./config/`` folder is created
    when missing and the file is written as tab-indented JSON.
    """
    global args

    # Order matters only for the layout of the written file.
    persisted_keys = (
        'share',
        'low-vram',
        'check-for-updates',
        'models-from-local-only',
        'force-cpu-for-conditioning-latents',
        'defer-tts-load',
        'prune-nonfinal-outputs',
        'device-override',
        'sample-batch-size',
        'embed-output-metadata',
        'latents-lean-and-mean',
        'voice-fixer',
        'voice-fixer-use-cuda',
        'concurrency-count',
        'output-sample-rate',
        'autocalculate-voice-chunk-duration-size',
        'output-volume',
        'autoregressive-model',
        'vocoder-model',
        'whisper-backend',
        'whisper-model',
        'training-default-halfp',
        'training-default-bnb',
    )

    settings = {'listen': args.listen if args.listen else None}
    for key in persisted_keys:
        settings[key] = getattr(args, key.replace('-', '_'))

    os.makedirs('./config/', exist_ok=True)
    with open('./config/exec.json', 'w', encoding="utf-8") as f:
        f.write(json.dumps(settings, indent='\t'))
2023-02-20 00:21:16 +00:00
def import_generate_settings(file="./config/generate.json"):
    """Load generation settings from ``file`` as a fixed-order tuple.

    Settings are read through ``read_generate_settings`` without latents.
    Returns ``None`` when nothing could be read; otherwise a 21-element tuple
    in which a missing key is replaced by the fallback shown below. Slots 6
    and 7 are always ``None``.
    """
    settings, _ = read_generate_settings(file, read_latents=False)
    if settings is None:
        return None

    def pick(key, fallback=None):
        # Stored value if the key is present, otherwise the fallback.
        return settings[key] if key in settings else fallback

    return (
        pick('text'),
        pick('delimiter'),
        pick('emotion'),
        pick('prompt'),
        pick('voice'),
        None,
        None,
        pick('seed'),
        pick('candidates'),
        pick('num_autoregressive_samples'),
        pick('diffusion_iterations'),
        pick('temperature', 0.8),
        pick('diffusion_sampler', "DDIM"),
        pick('breathing_room', 8),
        pick('cvvp_weight', 0.0),
        pick('top_p', 0.8),
        pick('diffusion_temperature', 1.0),
        pick('length_penalty', 1.0),
        pick('repetition_penalty', 2.0),
        pick('cond_free_k', 2.0),
        pick('experimentals'),
    )
def reset_generation_settings():
    """Overwrite ``./config/generate.json`` with an empty object and reload it.

    Returns whatever ``import_generate_settings()`` yields for the now-empty
    settings file.
    """
    # Create the folder first (as save_args_settings does) so a reset also
    # works before any settings have been saved.
    os.makedirs('./config/', exist_ok=True)
    with open('./config/generate.json', 'w', encoding="utf-8") as f:
        f.write(json.dumps({}, indent='\t'))
    return import_generate_settings()
2023-02-21 03:00:45 +00:00
def read_generate_settings(file, read_latents=True):
    """Read generation settings from a WAV's metadata or from a JSON file.

    Args:
        file: a path, an object exposing the path as ``.name``, a one-element
            list of either, or ``None``.
        read_latents: when true, a ``latents`` entry is base64-decoded and
            returned as bytes; the entry is removed from the settings either
            way.

    Returns:
        ``(settings, latents)``. ``settings`` is the parsed JSON, or ``None``
        if nothing could be read. ``latents`` is bytes or ``None``.

    For ``.wav`` files the JSON is taken from the ``lyrics`` tag via
    ``music_tag``; for ``.json`` files it is the file content. A numeric
    ``time`` entry is reformatted as a string with three decimals.
    """
    j = None
    latents = None

    if isinstance(file, list) and len(file) == 1:
        file = file[0]

    try:
        if file is not None:
            if hasattr(file, 'name'):
                file = file.name

            if file[-4:] == ".wav":
                metadata = music_tag.load_file(file)
                if 'lyrics' in metadata:
                    j = json.loads(str(metadata['lyrics']))
            elif file[-5:] == ".json":
                # Settings files are written as UTF-8 elsewhere in this
                # module; read them back the same way instead of relying on
                # the platform's default encoding.
                with open(file, 'r', encoding="utf-8") as f:
                    j = json.load(f)
    except Exception:
        # Best-effort read: an unreadable or malformed file is reported to
        # the caller as "no settings" (None).
        pass

    if j is not None:
        if 'latents' in j:
            if read_latents:
                latents = base64.b64decode(j['latents'])
            del j['latents']

        # Only numbers can be formatted with ":.3f"; a value that is already
        # text is passed through untouched instead of raising.
        if "time" in j and isinstance(j["time"], (int, float)):
            j["time"] = "{:.3f}".format(j["time"])

    return (
        j,
        latents,
    )
2023-03-06 21:48:34 +00:00
def version_check_tts(min_version):
    """Report whether the loaded TTS object is at least ``min_version``.

    Args:
        min_version: sequence of version components, e.g. ``(2, 4, 4)``.

    Returns:
        False when the TTS object has no ``version`` attribute, otherwise
        whether ``tts.version`` is greater than or equal to ``min_version``.

    Raises:
        Exception: if no TTS object is loaded.
    """
    global tts
    if not tts:
        raise Exception("TTS is not initialized")

    if not hasattr(tts, 'version'):
        return False

    # Lexicographic tuple comparison. The previous per-component checks were
    # inverted (they passed when the requirement exceeded the installed
    # version) and compared later components even when an earlier one had
    # already decided the ordering.
    return tuple(tts.version) >= tuple(min_version)
2023-03-07 04:34:39 +00:00
def load_tts(restart=False, autoregressive_model=None):
    """Create the module-level ``tts`` object and return it.

    Args:
        restart: unload any existing TTS object first.
        autoregressive_model: model to load. When given it is also stored in
            ``args.autoregressive_model``; otherwise the stored value is used.
            The value "auto" is resolved via ``deduce_autoregressive_model()``.

    ``tts_loading`` is set while construction is in progress so that other
    functions in this module can tell "not loaded" from "still loading".
    """
    global args
    global tts
    # Without this declaration the flag below would be a function-local
    # variable that the rest of the module never sees.
    global tts_loading

    if restart:
        unload_tts()

    if autoregressive_model:
        args.autoregressive_model = autoregressive_model
    else:
        autoregressive_model = args.autoregressive_model

    if autoregressive_model == "auto":
        autoregressive_model = deduce_autoregressive_model()

    print(f"Loading TorToiSe... (AR: {autoregressive_model}, vocoder: {args.vocoder_model})")

    tts_loading = True
    try:
        try:
            tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=autoregressive_model, vocoder_model=args.vocoder_model)
        except Exception:
            # Fallback: construct with the minimal argument set, then load
            # the requested model separately.
            tts = TextToSpeech(minor_optimizations=not args.low_vram)
            # NOTE(review): no module-level load_autoregressive_model is
            # visible in this part of the file; update_autoregressive_model
            # calls tts.load_autoregressive_model(...) instead - confirm
            # which one is intended here.
            load_autoregressive_model(autoregressive_model)
    finally:
        # Always clear the flag, otherwise a failed load would leave the
        # module reporting "still initializing" forever.
        tts_loading = False

    get_model_path('dvae.pth')
    print("Loaded TorToiSe, ready for generation.")
    return tts

# Alias: setup_tortoise is another name for load_tts.
setup_tortoise = load_tts
def unload_tts():
    """Drop the module-level TTS object, if any, then run ``do_gc()``."""
    global tts

    was_loaded = bool(tts)
    if was_loaded:
        # Rebinding the global releases this module's reference to the model.
        tts = None
        print("Unloaded TTS")

    do_gc()
2023-02-21 03:00:45 +00:00
def reload_tts(model=None):
    """Unload and re-create the TTS object, optionally with a different autoregressive model."""
    # load_tts names this parameter `autoregressive_model`; passing it as
    # `model=` raised TypeError.
    load_tts(restart=True, autoregressive_model=model)
2023-02-20 00:21:16 +00:00
2023-03-07 04:34:39 +00:00
def get_current_voice():
    """Return the active voice name, or ``None`` if it cannot be determined.

    The module-level ``current_voice`` wins when set; otherwise the ``voice``
    entry of ``./config/generate.json`` is used.
    """
    global current_voice
    if current_voice:
        return current_voice

    settings, _ = read_generate_settings("./config/generate.json", read_latents=False)
    # Test for the key on the settings mapping itself. The previous check
    # looked inside settings['voice'], which raised KeyError when the key was
    # missing and otherwise tested the voice name for the substring "voice".
    if settings and "voice" in settings:
        return settings["voice"]

    return None
def deduce_autoregressive_model(voice=None):
    """Pick the autoregressive model that best matches ``voice``.

    ``voice`` defaults to ``get_current_voice()``. Candidates, in order:
      1. ``./training/finetunes/<voice>.pth``
      2. ``./models/finetunes/<voice>.pth``
      3. the highest-numbered ``<step>_gpt.pth`` in
         ``./training/<voice>-finetune/models/``
      4. ``args.autoregressive_model`` unless it is "auto"
      5. the stock model from ``get_model_path('autoregressive.pth')``
    """
    if not voice:
        voice = get_current_voice()

    if voice:
        # ./models/finetunes/ is the folder get_autoregressive_models() scans
        # by default, so it is checked too; the original location keeps
        # priority so existing setups resolve exactly as before.
        for finetune in (f'./training/finetunes/{voice}.pth', f'./models/finetunes/{voice}.pth'):
            if os.path.exists(finetune):
                return finetune

        models_dir = f'./training/{voice}-finetune/models/'
        if os.path.isdir(models_dir):
            counts = sorted(
                int(d[:-8])
                for d in os.listdir(models_dir)
                if d[-8:] == "_gpt.pth" and d[:-8].isdigit()
            )
            # A training folder may exist before its first checkpoint is
            # written; fall through instead of indexing an empty list.
            if counts:
                return f'{models_dir}/{counts[-1]}_gpt.pth'

    if args.autoregressive_model != "auto":
        return args.autoregressive_model

    return get_model_path('autoregressive.pth')
2023-02-20 00:21:16 +00:00
def update_autoregressive_model(autoregressive_model_path):
    """Store the selected autoregressive model and load it into the running TTS.

    Args:
        autoregressive_model_path: a model path, optionally in the
            ``[<8 hex chars>] <path>`` form, or the keyword "auto".

    Returns:
        The path that was loaded, or ``None`` when nothing was loaded
        (invalid selection, TTS not loaded, or model already active).

    Raises:
        Exception: if the TTS object is still initializing.
    """
    global tts

    # Strip the "[hash] " prefix added by get_autoregressive_models(prefixed=True).
    if autoregressive_model_path:
        match = re.findall(r'^\[[a-fA-F0-9]{8}\] (.+?)$', autoregressive_model_path)
        if match:
            autoregressive_model_path = match[0]

    # "auto" is a keyword, not a file, so it is exempt from the existence
    # check; rejecting it here made the "auto" handling below unreachable.
    is_auto = autoregressive_model_path == "auto"
    if not is_auto and (not autoregressive_model_path or not os.path.exists(autoregressive_model_path)):
        print(f"Invalid model: {autoregressive_model_path}")
        return

    args.autoregressive_model = autoregressive_model_path
    save_args_settings()
    print(f'Stored autoregressive model to settings: {autoregressive_model_path}')

    if not tts:
        if tts_loading:
            raise Exception("TTS is still initializing...")
        return

    if hasattr(tts, "loading") and tts.loading:
        raise Exception("TTS is still initializing...")

    if autoregressive_model_path == "auto":
        autoregressive_model_path = deduce_autoregressive_model()

    # Nothing to do when the requested model is already active.
    if autoregressive_model_path == tts.autoregressive_model_path:
        return

    tts.load_autoregressive_model(autoregressive_model_path)

    do_gc()

    return autoregressive_model_path
2023-02-20 00:21:16 +00:00
2023-03-07 02:45:22 +00:00
def update_vocoder_model(vocoder_model):
    """Store the selected vocoder and load it into the running TTS object.

    The choice is saved to settings first. If no TTS object is loaded the
    function returns ``None`` (or raises while one is still initializing);
    otherwise the vocoder is loaded and its name returned.
    """
    global tts

    args.vocoder_model = vocoder_model
    save_args_settings()
    print(f'Stored vocoder model to settings: {vocoder_model}')

    if not tts:
        if tts_loading:
            raise Exception("TTS is still initializing...")
        return

    still_loading = getattr(tts, "loading", False)
    if still_loading:
        raise Exception("TTS is still initializing...")

    print(f"Loading model: {vocoder_model}")
    tts.load_vocoder_model(vocoder_model)
    print(f"Loaded model: {tts.vocoder_model}")

    do_gc()
    return vocoder_model
2023-02-20 00:21:16 +00:00
def load_voicefixer(restart=False):
    """Instantiate VoiceFixer into the module-level ``voicefixer``.

    With ``restart`` set, an existing instance is unloaded first. Import or
    construction failures are printed rather than raised.
    """
    global voicefixer

    if restart:
        unload_voicefixer()

    try:
        print("Loading Voicefixer")
        # Imported lazily so the module works without the package installed.
        from voicefixer import VoiceFixer as _VoiceFixer
        voicefixer = _VoiceFixer()
        print("Loaded Voicefixer")
    except Exception as error:
        print(f"Error occurred while tring to initialize voicefixer: {error}")
def unload_voicefixer():
    """Drop the module-level VoiceFixer instance, if any, then run ``do_gc()``."""
    global voicefixer

    if voicefixer:
        # Rebinding the global releases this module's reference.
        voicefixer = None
        print("Unloaded Voicefixer")

    do_gc()
2023-03-05 05:17:19 +00:00
def load_whisper_model(language=None, model_name=None, progress=None):
    """Load a Whisper model into the module-level ``whisper_model``.

    Args:
        language: optional language code. If ``<model>.<language>`` is one of
            ``WHISPER_SPECIALIZED_MODELS``, that specialized model is loaded.
        model_name: model to load. When given it is also saved to settings;
            otherwise ``args.whisper_model`` is used.
        progress: optional progress callback passed to ``notify_progress``.

    Raises:
        Exception: if ``args.whisper_backend`` is not in ``WHISPER_BACKENDS``,
            or if "large-v2" is requested on a backend other than
            m-bain/whisperx.
    """
    global whisper_model

    if args.whisper_backend not in WHISPER_BACKENDS:
        raise Exception(f"unavailable backend: {args.whisper_backend}")

    # Resolve the effective model before validating it, so a "large-v2" that
    # comes from saved settings is checked as well. Nothing is persisted
    # until the choice is known to be valid.
    requested = model_name if model_name else args.whisper_model
    if args.whisper_backend != "m-bain/whisperx" and requested == "large-v2":
        raise Exception("large-v2 is only available for m-bain/whisperx backend")

    if model_name:
        args.whisper_model = model_name
        save_args_settings()
    model_name = requested

    if language and f'{model_name}.{language}' in WHISPER_SPECIALIZED_MODELS:
        model_name = f'{model_name}.{language}'
        print(f"Loading specialized model for language: {language}")

    notify_progress(f"Loading Whisper model: {model_name}", progress)

    # Backends are imported lazily so only the selected one has to be installed.
    if args.whisper_backend == "openai/whisper":
        import whisper
        whisper_model = whisper.load_model(model_name)
    elif args.whisper_backend == "lightmare/whispercpp":
        from whispercpp import Whisper
        if not language:
            language = 'auto'

        b_lang = language.encode('ascii')
        whisper_model = Whisper(model_name, models_dir='./models/', language=b_lang)
    elif args.whisper_backend == "m-bain/whisperx":
        import whisperx
        device = "cuda" if get_device_name() == "cuda" else "cpu"
        whisper_model = whisperx.load_model(model_name, device)

    print("Loaded Whisper model")
2023-02-20 00:21:16 +00:00
def unload_whisper():
    """Drop the module-level Whisper model, if any, then run ``do_gc()``."""
    global whisper_model

    model_present = bool(whisper_model)
    if model_present:
        # Rebinding the global releases this module's reference to the model.
        whisper_model = None
        print("Unloaded Whisper")

    do_gc()