2023-02-17 00:08:27 +00:00
import os
# Default the model/cache directories to ./models/ under the current working
# directory unless the caller already set them. This runs before the library
# imports below, presumably so those libraries pick up these locations — confirm.
for _env_name, _env_subdir in (
    ('XDG_CACHE_HOME', './models/'),
    ('TORTOISE_MODELS_DIR', './models/tortoise/'),
    ('TRANSFORMERS_CACHE', './models/transformers/'),
):
    if _env_name not in os.environ:
        os.environ[_env_name] = os.path.realpath(os.path.join(os.getcwd(), _env_subdir))
del _env_name, _env_subdir
import argparse
import time
2023-03-17 05:33:49 +00:00
import math
2023-02-17 00:08:27 +00:00
import json
import base64
import re
import urllib . request
2023-02-18 02:07:22 +00:00
import signal
2023-02-18 20:37:37 +00:00
import gc
2023-02-20 00:21:16 +00:00
import subprocess
2023-03-04 20:42:54 +00:00
import psutil
2023-02-23 06:24:54 +00:00
import yaml
2023-03-09 04:33:12 +00:00
import hashlib
2023-03-16 04:25:33 +00:00
import string
2023-03-31 03:26:00 +00:00
import random
2023-02-17 00:08:27 +00:00
2023-03-16 20:48:48 +00:00
from tqdm import tqdm
2023-02-17 00:08:27 +00:00
import torch
import torchaudio
import music_tag
import gradio as gr
import gradio . utils
2023-02-28 01:01:50 +00:00
import pandas as pd
2023-04-26 04:48:09 +00:00
import numpy as np
2023-02-17 00:08:27 +00:00
2023-04-13 21:10:38 +00:00
from glob import glob
2023-02-17 00:08:27 +00:00
from datetime import datetime
2023-02-23 06:24:54 +00:00
from datetime import timedelta
2023-02-17 00:08:27 +00:00
2023-03-31 03:26:00 +00:00
from tortoise . api import TextToSpeech as TorToise_TTS , MODELS , get_model_path , pad_or_truncate
2023-03-11 21:17:11 +00:00
from tortoise . utils . audio import load_audio , load_voice , load_voices , get_voice_dir , get_voices
2023-02-17 00:08:27 +00:00
from tortoise . utils . text import split_and_recombine_text
2023-03-15 01:20:15 +00:00
from tortoise . utils . device import get_device_name , set_device_name , get_device_count , get_device_vram , get_device_batch_size , do_gc
2023-02-17 00:08:27 +00:00
2023-03-13 19:07:23 +00:00
from whisper . normalizers . english import EnglishTextNormalizer
2023-03-16 04:25:33 +00:00
from whisper . normalizers . basic import BasicTextNormalizer
from whisper . tokenizer import LANGUAGES
2023-02-17 19:06:05 +00:00
# Download URL for the DVAE checkpoint, pinned to a specific revision of the
# jbetker/tortoise-tts-v2 repository and registered in tortoise's MODELS table
# (file name -> URL).
_DVAE_URL = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
MODELS['dvae.pth'] = _DVAE_URL
2023-03-06 05:21:33 +00:00
2023-03-11 16:40:34 +00:00
# --- Transcription (whisper) ---------------------------------------------------
WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"]
# the ".en"-suffixed model variants
WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"]
WHISPER_BACKENDS = [
    "openai/whisper",
    "lightmare/whispercpp",
    "m-bain/whisperx",
]

# --- Synthesis -------------------------------------------------------------------
VOCODERS = [
    'univnet',
    'bigvgan_base_24khz_100band',
    'bigvgan_24khz_100band',
]
# Selectable TTS backends; optional backends are appended below once their
# imports succeed.
TTSES = ['tortoise']

# Set while a generation loop is running.
INFERENCING = False
GENERATE_SETTINGS_ARGS = None

# Availability flags for the optional backends, flipped by their import blocks.
VALLE_ENABLED = False
BARK_ENABLED = False

# --- Training --------------------------------------------------------------------
# display label -> scheduler name (presumably the training config's scheduler type)
LEARNING_RATE_SCHEMES = {
    "Multistep": "MultiStepLR",
    "Cos. Annealing": "CosineAnnealingLR_Restart",
}
LEARNING_RATE_SCHEDULE = [2, 4, 9, 18, 25, 33, 50]

# Clip-length bounds, presumably in seconds — confirm against the dataset code.
MIN_TRAINING_DURATION = 0.6
MAX_TRAINING_DURATION = 11.6097505669

# --- Audio -----------------------------------------------------------------------
# Cache of resampler transforms keyed by "<input_rate>:<output_rate>".
RESAMPLERS = {}
2023-03-14 05:02:14 +00:00
# Optional VALL-E backend: enabled and offered in TTSES only if every one of its
# modules imports cleanly; any import failure leaves it disabled.
try:
    from vall_e.emb.qnt import encode as valle_quantize
    from vall_e.emb.g2p import encode as valle_phonemize

    from vall_e.inference import TTS as VALLE_TTS
    import soundfile
except Exception:
    # NOTE: errors are deliberately swallowed. Re-raising when the user picked
    # this backend (args.tts_backend == "vall-e") is not possible here because
    # `args` is not assigned until later in this module.
    pass
else:
    VALLE_ENABLED = True
    TTSES.append('vall-e')
2023-04-26 04:48:09 +00:00
# Optional Bark backend: BARK_ENABLED is set only if Bark, EnCodec's helpers and
# scipy's WAV writer all import cleanly; any import failure leaves it disabled.
try:
    from bark.generation import SAMPLE_RATE as BARK_SAMPLE_RATE, ALLOWED_PROMPTS, preload_models, codec_decode, generate_coarse, generate_fine, generate_text_semantic, load_codec_model
    from bark.api import generate_audio as bark_generate_audio
    from encodec.utils import convert_audio

    from scipy.io.wavfile import write as write_wav
except Exception:
    # NOTE: errors are deliberately swallowed. Re-raising when the user picked
    # this backend (args.tts_backend == "bark") is not possible here because
    # `args` is not assigned until later in this module.
    pass
else:
    BARK_ENABLED = True
if BARK_ENABLED:
    TTSES.append('bark')

    class Bark_TTS():
        """Wrapper around Bark used as the ``tts`` object by the Bark generation path.

        Voices are turned into Bark history prompts (``.npz`` files) on first use,
        built from the voice's whisper-transcribed training clips.
        """

        PROMPTS_DIR = './modules/bark/bark/assets/prompts/'

        def __init__(self, small=False):
            """Preload all Bark models on the GPU.

            Args:
                small: use the small variants of the text/coarse/fine models.
            """
            self.input_sample_rate = BARK_SAMPLE_RATE
            # The rate of the audio this backend *produces* (inference() returns
            # audio at BARK_SAMPLE_RATE). The generation code resamples from this
            # rate to args.output_sample_rate. Setting it to args.output_sample_rate
            # would make that resample a no-op and label Bark-rate audio with the
            # wrong rate.
            self.output_sample_rate = BARK_SAMPLE_RATE
            preload_models(
                text_use_gpu=True,
                coarse_use_gpu=True,
                fine_use_gpu=True,
                codec_use_gpu=True,
                text_use_small=small,
                coarse_use_small=small,
                fine_use_small=small,
                force_reload=False,
            )

        @staticmethod
        def _prompt_name(voice):
            # Nested voice names ("a/b") are flattened to "a_b" for the prompt name/file.
            return voice.replace("/", "_")

        @classmethod
        def _prompt_path(cls, voice):
            # Where create_voice() writes, and inference() looks for, a voice's prompt.
            return cls.PROMPTS_DIR + cls._prompt_name(voice) + '.npz'

        def create_voice(self, voice, device='cuda'):
            """Build and save a Bark history prompt for ``voice``.

            Picks a random transcribed segment from ``./training/<voice>/whisper.json``,
            encodes the matching clip from ``./training/<voice>/audio/`` with the codec
            model, generates semantic tokens for the segment's text, and saves
            fine/coarse/semantic prompts to ``_prompt_path(voice)``.

            Raises:
                FileNotFoundError: if the voice has no whisper.json transcription.
                ValueError: if the transcription lists no segments.
            """
            transcription_json = f'./training/{voice}/whisper.json'
            if not os.path.exists(transcription_json):
                raise FileNotFoundError(f"Transcription for voice not found: {voice}")

            with open(transcription_json, 'r', encoding="utf-8") as f:
                transcriptions = json.load(f)

            # (segment file name, duration, text) for every transcribed segment
            candidates = []
            for file, result in transcriptions.items():
                for segment in result['segments']:
                    candidates.append((
                        file.replace(".wav", f"_{pad(segment['id'], 4)}.wav"),
                        segment['end'] - segment['start'],
                        segment['text'],
                    ))
            if not candidates:
                raise ValueError(f"No transcribed segments found for voice: {voice}")

            # shortest first; kept so a given RNG state picks the same clip as before
            candidates.sort(key=lambda x: x[1])
            candidate = random.choice(candidates)
            audio_filepath = f'./training/{voice}/audio/{candidate[0]}'
            text = candidate[-1]

            print("Using as reference:", audio_filepath, text)

            # Load and pre-process the audio waveform
            model = load_codec_model(use_gpu=True)
            wav, sr = torchaudio.load(audio_filepath)
            wav = convert_audio(wav, sr, model.sample_rate, model.channels)
            wav = wav.unsqueeze(0).to(device)

            # Extract discrete codes from EnCodec
            with torch.no_grad():
                encoded_frames = model.encode(wav)
            codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze().cpu().numpy()  # [n_q, T]

            # get seconds of audio
            seconds = wav.shape[-1] / model.sample_rate
            # generate semantic tokens
            semantic_tokens = generate_text_semantic(text, max_gen_duration_s=seconds, top_k=50, top_p=.95, temp=0.7)

            np.savez(self._prompt_path(voice), fine_prompt=codes, coarse_prompt=codes[:2, :], semantic_prompt=semantic_tokens)

        def inference(self, text, voice, text_temp=0.7, waveform_temp=0.7):
            """Generate ``text`` in ``voice``.

            The voice's history prompt is created on first use. The existence
            check uses the same flattened name the prompt is saved under.

            Returns:
                tuple: ``(audio, BARK_SAMPLE_RATE)`` where ``audio`` is whatever
                ``bark.api.generate_audio`` returns.
            """
            if not os.path.exists(self._prompt_path(voice)):
                self.create_voice(voice)

            prompt_name = self._prompt_name(voice)
            ALLOWED_PROMPTS.add(prompt_name)  # set.add is a no-op when already present

            return (bark_generate_audio(text, history_prompt=prompt_name, text_temp=text_temp, waveform_temp=waveform_temp), BARK_SAMPLE_RATE)
2023-02-17 00:08:27 +00:00
# Mutable module state shared by the functions in this file; everything starts
# empty and is filled in at runtime.
args = None          # presumably the parsed argparse namespace — confirm at the assignment site
tts = None           # active TTS backend instance
tts_loading = False  # True while the TTS backend is being loaded

webui = None
voicefixer = None

# whisper transcription models / helpers
whisper_model = whisper_vad = whisper_diarize = whisper_align_model = None

training_state = None
current_voice = None
2023-04-26 04:48:09 +00:00
def cleanup_voice_name(name):
    """Return the last ``/``-separated component of a voice name.

    >>> cleanup_voice_name("finetunes/my_voice")
    'my_voice'
    """
    _head, _sep, tail = name.rpartition("/")
    return tail
2023-03-13 01:20:55 +00:00
def resample(waveform, input_rate, output_rate=44100):
    """Down-mix ``waveform`` to mono and resample it to ``output_rate``.

    ``waveform`` is averaged over dim 0 with ``keepdim=True``; for a
    ``(channels, samples)`` tensor (presumably what callers pass, e.g. from
    ``torchaudio.load``) the result is ``(1, samples')``. Resampler transforms
    are built once per rate pair and cached in ``RESAMPLERS``.

    Returns:
        tuple: ``(waveform, output_rate)``.
    """
    # mono-ize
    waveform = torch.mean(waveform, dim=0, keepdim=True)

    if input_rate == output_rate:
        return waveform, output_rate

    key = f'{input_rate}:{output_rate}'
    if key not in RESAMPLERS:
        RESAMPLERS[key] = torchaudio.transforms.Resample(
            input_rate,
            output_rate,
            lowpass_filter_width=16,
            rolloff=0.85,
            resampling_method="kaiser_window",
            beta=8.555504641634386,
        )

    return RESAMPLERS[key](waveform), output_rate
2023-03-09 00:26:47 +00:00
def generate(**kwargs):
    """Dispatch a generation request to the function for ``args.tts_backend``.

    All keyword arguments are forwarded unchanged.

    Raises:
        ValueError: if ``args.tts_backend`` is not "tortoise", "vall-e" or "bark".
    """
    backends = {
        "tortoise": generate_tortoise,
        "vall-e": generate_valle,
        "bark": generate_bark,
    }

    backend = args.tts_backend
    if backend not in backends:
        # Previously fell through and returned None, which surfaced later as an
        # unrelated unpacking error in the caller.
        raise ValueError(f"Unknown TTS backend: {backend!r} (expected one of {sorted(backends)})")

    return backends[backend](**kwargs)
def generate_bark(**kwargs):
    """Synthesize ``kwargs['text']`` with the Bark backend and save the results.

    Keys read from ``kwargs``: ``voice``, ``seed``, ``text``, ``delimiter``,
    ``temperature`` and ``candidates``; ``progress`` is optional.

    The text is split into lines: by ``delimiter`` when it occurs in the text,
    otherwise with ``split_and_recombine_text``. Each line is generated and written
    to ``{args.results_folder}/{voice}/``. Clips are then resampled to
    ``args.output_sample_rate``, optionally volume-adjusted, and concatenated when
    there are several lines. They are optionally run through voicefixer and tagged
    with their settings, either as a JSON sidecar or as embedded metadata.

    A line may start with a JSON object followed by a space, e.g.
    ``{"voice": "other"} Hello``, to override that line's settings.

    Returns:
        tuple: ``(sample_voice, output_voices, stats)``. ``sample_voice`` is always
        ``None`` for this backend, ``output_voices`` lists the final ``.wav``
        paths, and ``stats`` is ``[[seed, "<seconds, 3 decimals>"]]``.

    Raises:
        Exception: if the TTS backend is still initializing, or a line's settings
            prefix is not valid JSON.
    """
    global args
    global tts
    # Without this, the assignments below would only bind a local and the
    # module-level flag would never change.
    global INFERENCING

    parameters = {}
    parameters.update(kwargs)

    voice = parameters['voice']
    progress = parameters['progress'] if 'progress' in parameters else None
    if parameters['seed'] == 0:
        parameters['seed'] = None

    usedSeed = parameters['seed']

    unload_whisper()
    unload_voicefixer()

    if not tts:
        # should check if it's loading or unloaded, and load it if it's unloaded
        if tts_loading:
            raise Exception("TTS is still initializing...")
        if progress is not None:
            notify_progress("Initializing TTS...", progress=progress)
        load_tts()
    if hasattr(tts, "loading") and tts.loading:
        raise Exception("TTS is still initializing...")

    do_gc()

    sample_voice = None

    def get_settings(override=None):
        # Keyword arguments for tts.inference(); a line's JSON prefix may only
        # override keys that already exist here.
        settings = {
            'voice': parameters['voice'],
            'text_temp': float(parameters['temperature']),
            'waveform_temp': float(parameters['temperature']),
        }
        if override is not None:
            for k in override:
                if k in settings:
                    settings[k] = override[k]
        return settings

    delimiter = parameters['delimiter']
    if not delimiter or delimiter == "\\n":
        delimiter = "\n"
    parameters['delimiter'] = delimiter

    if delimiter in parameters['text']:
        texts = parameters['text'].split(delimiter)
    else:
        texts = split_and_recombine_text(parameters['text'])

    full_start_time = time.time()

    outdir = f"{args.results_folder}/{voice}/"
    os.makedirs(outdir, exist_ok=True)

    voice_name = cleanup_voice_name(voice)

    def get_path(name, extension=".wav"):
        # Output files are named "<voice basename>_<name><extension>".
        return f'{outdir}/{voice_name}_{name}{extension}'

    audio_cache = {}
    # Sample rate each clip was actually written at, keyed like audio_cache.
    sample_rates = {}

    volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None

    # Continue numbering after the highest index already used in outdir. Names are
    # escaped so voices containing regex metacharacters are matched literally.
    existing_indices = set()
    for filename in os.listdir(outdir):
        extension = os.path.splitext(filename)[1]
        if extension not in (".json", ".wav"):
            continue
        match = re.findall(rf"^{re.escape(voice_name)}_(\d+)(?:.+?)?{re.escape(extension)}$", filename)
        if match:
            existing_indices.add(int(match[0]))

    idx = max(existing_indices) + 1 if existing_indices else 0
    idx = pad(idx, 4)

    def get_name(line=0, candidate=0, combined=False):
        name = f"{idx}"
        if combined:
            name = f"{name}_combined"
        elif len(texts) > 1:
            name = f"{name}_{line}"
        if parameters['candidates'] > 1:
            name = f"{name}_{candidate}"
        return name

    def get_info(voice, settings=None, latents=True):
        # Settings record saved alongside each output: the request parameters,
        # timing, and any per-line settings that overlap them.
        info = {}
        info.update(parameters)

        info['time'] = time.time() - full_start_time
        info['datetime'] = datetime.now().isoformat()

        # the progress handle is not part of the saved record
        info.pop('progress', None)

        if info['delimiter'] == "\n":
            info['delimiter'] = "\\n"

        if settings is not None:
            for k in settings:
                if k in info:
                    info[k] = settings[k]
        return info

    INFERENCING = True
    try:
        for line, cut_text in enumerate(texts):
            tqdm_prefix = f'[{str(line+1)}/{str(len(texts))}]'
            print(f"{tqdm_prefix} Generating line: {cut_text}")
            start_time = time.time()

            # do setting editing
            match = re.findall(r'^(\{.+\}) (.+?)$', cut_text)
            override = None
            if match:
                match = match[0]
                try:
                    override = json.loads(match[0])
                    cut_text = match[1].strip()
                except ValueError as e:
                    raise Exception("Prompt settings editing requested, but received invalid JSON") from e

            settings = get_settings(override=override)
            gen = tts.inference(cut_text, **settings)

            run_time = time.time() - start_time
            print(f"Generating line took {run_time} seconds")

            if not isinstance(gen, list):
                gen = [gen]

            for j, g in enumerate(gen):
                wav, sr = g
                name = get_name(line=line, candidate=j)

                settings['text'] = cut_text
                settings['time'] = run_time
                settings['datetime'] = datetime.now().isoformat()

                # save here in case some error happens mid-batch
                write_wav(get_path(name), sr, wav)
                wav, sr = torchaudio.load(get_path(name))

                audio_cache[name] = {
                    'audio': wav,
                    'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=settings)
                }
                sample_rates[name] = sr

            del gen
            do_gc()
    finally:
        # reset even if generation fails, so the flag never stays stuck on
        INFERENCING = False

    for k in audio_cache:
        audio = audio_cache[k]['audio']

        # Resample from the rate the clip was actually written at, not from
        # tts.output_sample_rate, which may not describe these files.
        audio, _ = resample(audio, sample_rates[k], args.output_sample_rate)
        if volume_adjust is not None:
            audio = volume_adjust(audio)

        audio_cache[k]['audio'] = audio
        torchaudio.save(get_path(k), audio, args.output_sample_rate)

    output_voices = []
    # NOTE(review): Bark's inference() returns one clip per line, so
    # parameters['candidates'] > 1 will hit a missing cache key below — confirm
    # whether the UI prevents that for this backend.
    for candidate in range(parameters['candidates']):
        if len(texts) > 1:
            audio_clips = [audio_cache[get_name(line=line, candidate=candidate)]['audio'] for line in range(len(texts))]

            name = get_name(candidate=candidate, combined=True)
            audio = torch.cat(audio_clips, dim=-1)
            torchaudio.save(get_path(name), audio, args.output_sample_rate)

            audio = audio.squeeze(0).cpu()
            audio_cache[name] = {
                'audio': audio,
                'settings': get_info(voice=voice),
                'output': True
            }
        else:
            name = get_name(candidate=candidate)
            audio_cache[name]['output'] = True

    if args.voice_fixer:
        if not voicefixer:
            notify_progress("Loading voicefix...", progress=progress)
            load_voicefixer()

        try:
            fixed_cache = {}
            for name in tqdm(audio_cache, desc="Running voicefix..."):
                del audio_cache[name]['audio']
                if not audio_cache[name].get('output'):
                    continue

                fixed_name = f'{name}_fixed'
                voicefixer.restore(
                    input=get_path(name),
                    output=get_path(fixed_name),
                    cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
                )

                # the fixed file replaces the original as the final output
                fixed_cache[fixed_name] = {
                    'settings': audio_cache[name]['settings'],
                    'output': True
                }
                audio_cache[name]['output'] = False

            audio_cache.update(fixed_cache)
        except Exception as e:
            # best-effort: keep the un-fixed outputs if voicefixer fails
            print(e)
            print("\nFailed to run Voicefixer")

    for name in audio_cache:
        entry = audio_cache[name]
        if not entry.get('output'):
            if args.prune_nonfinal_outputs:
                entry['pruned'] = True
                os.remove(get_path(name))
            continue

        output_voices.append(get_path(name))

        if not args.embed_output_metadata:
            with open(get_path(name, ".json"), 'w', encoding="utf-8") as f:
                f.write(json.dumps(entry['settings'], indent='\t'))

    if args.embed_output_metadata:
        for name in tqdm(audio_cache, desc="Embedding metadata..."):
            if audio_cache[name].get('pruned'):
                continue

            metadata = music_tag.load_file(get_path(name))
            metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
            metadata.save()

    info = get_info(voice=voice, latents=False)
    print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n")

    info['seed'] = usedSeed
    info.pop('latents', None)

    os.makedirs('./config/', exist_ok=True)
    with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
        f.write(json.dumps(info, indent='\t'))

    stats = [
        [parameters['seed'], "{:.3f}".format(info['time'])]
    ]

    return (
        sample_voice,
        output_voices,
        stats,
    )
2023-03-31 03:26:00 +00:00
def generate_valle(**kwargs):
    """Synthesize ``kwargs['text']`` with the VALL-E backend and save the results.

    Keys read from ``kwargs``: ``voice``, ``seed``, ``text``, ``delimiter``,
    ``temperature``, ``num_autoregressive_samples`` and ``candidates``;
    ``progress`` is optional.

    Each line is generated against a randomly chosen reference clip of the voice.
    The reference comes from ``./training/<voice>/audio/`` if that directory
    exists, otherwise from ``./voices/<voice>/``. Clips are written to
    ``{args.results_folder}/{voice}/``, then resampled to
    ``args.output_sample_rate``, optionally volume-adjusted, and concatenated when
    there are several lines. They are optionally run through voicefixer and tagged
    with their settings.

    A line may start with a JSON object followed by a space, e.g.
    ``{"voice": "other"} Hello``, to override that line's settings or voice.

    Returns:
        tuple: ``(sample_voice, output_voices, stats)``. ``sample_voice`` is always
        ``None`` for this backend, ``output_voices`` lists the final ``.wav``
        paths, and ``stats`` is ``[[seed, "<seconds, 3 decimals>"]]``.

    Raises:
        Exception: if the TTS backend is still initializing, or a line's settings
            prefix is not valid JSON.
        FileNotFoundError: if a voice has no ``.wav`` reference clips.
    """
    global args
    global tts
    # Without this, the assignments below would only bind a local and the
    # module-level flag would never change.
    global INFERENCING

    parameters = {}
    parameters.update(kwargs)

    voice = parameters['voice']
    progress = parameters['progress'] if 'progress' in parameters else None
    if parameters['seed'] == 0:
        parameters['seed'] = None

    usedSeed = parameters['seed']

    unload_whisper()
    unload_voicefixer()

    if not tts:
        # should check if it's loading or unloaded, and load it if it's unloaded
        if tts_loading:
            raise Exception("TTS is still initializing...")
        if progress is not None:
            notify_progress("Initializing TTS...", progress=progress)
        load_tts()
    if hasattr(tts, "loading") and tts.loading:
        raise Exception("TTS is still initializing...")

    do_gc()

    sample_voice = None

    def fetch_voice(voice):
        # Pick one random reference clip for the voice: prefer its training audio,
        # fall back to ./voices/<voice>/.
        voice_dir = f'./training/{voice}/audio/'
        if not os.path.isdir(voice_dir):
            voice_dir = f'./voices/{voice}/'

        files = [f'{voice_dir}/{d}' for d in os.listdir(voice_dir) if d.endswith(".wav")]
        if not files:
            raise FileNotFoundError(f"No .wav reference clips found for voice: {voice} (looked in {voice_dir})")
        return random.choice(files)

    def get_settings(override=None):
        # Keyword arguments for tts.inference() plus the chosen 'reference' clip;
        # a line's JSON prefix may override existing keys and/or pick another voice.
        settings = {
            'ar_temp': float(parameters['temperature']),
            'nar_temp': float(parameters['temperature']),
            'max_ar_samples': parameters['num_autoregressive_samples'],
        }

        selected_voice = voice
        if override is not None:
            if 'voice' in override:
                selected_voice = override['voice']
            for k in override:
                if k in settings:
                    settings[k] = override[k]

        settings['reference'] = fetch_voice(voice=selected_voice)
        return settings

    delimiter = parameters['delimiter']
    if not delimiter or delimiter == "\\n":
        delimiter = "\n"
    parameters['delimiter'] = delimiter

    if delimiter in parameters['text']:
        texts = parameters['text'].split(delimiter)
    else:
        texts = split_and_recombine_text(parameters['text'])

    full_start_time = time.time()

    outdir = f"{args.results_folder}/{voice}/"
    os.makedirs(outdir, exist_ok=True)

    voice_name = cleanup_voice_name(voice)

    def get_path(name, extension=".wav"):
        # Output files are named "<voice basename>_<name><extension>".
        return f'{outdir}/{voice_name}_{name}{extension}'

    audio_cache = {}
    # Sample rate each clip was actually written at, keyed like audio_cache.
    sample_rates = {}

    volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None

    # Continue numbering after the highest index already used in outdir. Match on
    # the same cleaned-up name the files are written with (the raw voice name never
    # matches for nested voices, which restarted numbering and overwrote outputs),
    # escaped so regex metacharacters are matched literally.
    existing_indices = set()
    for filename in os.listdir(outdir):
        extension = os.path.splitext(filename)[1]
        if extension not in (".json", ".wav"):
            continue
        match = re.findall(rf"^{re.escape(voice_name)}_(\d+)(?:.+?)?{re.escape(extension)}$", filename)
        if match:
            existing_indices.add(int(match[0]))

    idx = max(existing_indices) + 1 if existing_indices else 0
    idx = pad(idx, 4)

    def get_name(line=0, candidate=0, combined=False):
        name = f"{idx}"
        if combined:
            name = f"{name}_combined"
        elif len(texts) > 1:
            name = f"{name}_{line}"
        if parameters['candidates'] > 1:
            name = f"{name}_{candidate}"
        return name

    def get_info(voice, settings=None, latents=True):
        # Settings record saved alongside each output: the request parameters,
        # timing, and any per-line settings that overlap them.
        info = {}
        info.update(parameters)

        info['time'] = time.time() - full_start_time
        info['datetime'] = datetime.now().isoformat()

        # the progress handle is not part of the saved record
        info.pop('progress', None)

        if info['delimiter'] == "\n":
            info['delimiter'] = "\\n"

        if settings is not None:
            for k in settings:
                if k in info:
                    info[k] = settings[k]
        return info

    INFERENCING = True
    try:
        for line, cut_text in enumerate(texts):
            tqdm_prefix = f'[{str(line+1)}/{str(len(texts))}]'
            print(f"{tqdm_prefix} Generating line: {cut_text}")
            start_time = time.time()

            # do setting editing
            match = re.findall(r'^(\{.+\}) (.+?)$', cut_text)
            override = None
            if match:
                match = match[0]
                try:
                    override = json.loads(match[0])
                    cut_text = match[1].strip()
                except ValueError as e:
                    raise Exception("Prompt settings editing requested, but received invalid JSON") from e

            settings = get_settings(override=override)
            reference = settings.pop('reference')
            gen = tts.inference(cut_text, reference, **settings)

            run_time = time.time() - start_time
            print(f"Generating line took {run_time} seconds")

            if not isinstance(gen, list):
                gen = [gen]

            for j, g in enumerate(gen):
                wav, sr = g
                name = get_name(line=line, candidate=j)

                settings['text'] = cut_text
                settings['time'] = run_time
                settings['datetime'] = datetime.now().isoformat()

                # save here in case some error happens mid-batch
                soundfile.write(get_path(name), wav.cpu()[0, 0], sr)
                wav, sr = torchaudio.load(get_path(name))

                audio_cache[name] = {
                    'audio': wav,
                    'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=settings)
                }
                sample_rates[name] = sr

            del gen
            do_gc()
    finally:
        # reset even if generation fails, so the flag never stays stuck on
        INFERENCING = False

    for k in audio_cache:
        audio = audio_cache[k]['audio']

        # resample from the rate the clip was actually written at
        audio, _ = resample(audio, sample_rates[k], args.output_sample_rate)
        if volume_adjust is not None:
            audio = volume_adjust(audio)

        audio_cache[k]['audio'] = audio
        torchaudio.save(get_path(k), audio, args.output_sample_rate)

    output_voices = []
    for candidate in range(parameters['candidates']):
        if len(texts) > 1:
            audio_clips = [audio_cache[get_name(line=line, candidate=candidate)]['audio'] for line in range(len(texts))]

            name = get_name(candidate=candidate, combined=True)
            audio = torch.cat(audio_clips, dim=-1)
            torchaudio.save(get_path(name), audio, args.output_sample_rate)

            audio = audio.squeeze(0).cpu()
            audio_cache[name] = {
                'audio': audio,
                'settings': get_info(voice=voice),
                'output': True
            }
        else:
            name = get_name(candidate=candidate)
            audio_cache[name]['output'] = True

    if args.voice_fixer:
        if not voicefixer:
            notify_progress("Loading voicefix...", progress=progress)
            load_voicefixer()

        try:
            fixed_cache = {}
            for name in tqdm(audio_cache, desc="Running voicefix..."):
                del audio_cache[name]['audio']
                if not audio_cache[name].get('output'):
                    continue

                fixed_name = f'{name}_fixed'
                voicefixer.restore(
                    input=get_path(name),
                    output=get_path(fixed_name),
                    cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
                )

                # the fixed file replaces the original as the final output
                fixed_cache[fixed_name] = {
                    'settings': audio_cache[name]['settings'],
                    'output': True
                }
                audio_cache[name]['output'] = False

            audio_cache.update(fixed_cache)
        except Exception as e:
            # best-effort: keep the un-fixed outputs if voicefixer fails
            print(e)
            print("\nFailed to run Voicefixer")

    for name in audio_cache:
        entry = audio_cache[name]
        if not entry.get('output'):
            if args.prune_nonfinal_outputs:
                entry['pruned'] = True
                os.remove(get_path(name))
            continue

        output_voices.append(get_path(name))

        if not args.embed_output_metadata:
            with open(get_path(name, ".json"), 'w', encoding="utf-8") as f:
                f.write(json.dumps(entry['settings'], indent='\t'))

    if args.embed_output_metadata:
        for name in tqdm(audio_cache, desc="Embedding metadata..."):
            if audio_cache[name].get('pruned'):
                continue

            metadata = music_tag.load_file(get_path(name))
            metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
            metadata.save()

    info = get_info(voice=voice, latents=False)
    print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n")

    info['seed'] = usedSeed
    info.pop('latents', None)

    os.makedirs('./config/', exist_ok=True)
    with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
        f.write(json.dumps(info, indent='\t'))

    stats = [
        [parameters['seed'], "{:.3f}".format(info['time'])]
    ]

    return (
        sample_voice,
        output_voices,
        stats,
    )
def generate_tortoise ( * * kwargs ) :
parameters = { }
parameters . update ( kwargs )
voice = parameters [ ' voice ' ]
progress = parameters [ ' progress ' ] if ' progress ' in parameters else None
if parameters [ ' seed ' ] == 0 :
parameters [ ' seed ' ] = None
usedSeed = parameters [ ' seed ' ]
global args
global tts
unload_whisper ( )
unload_voicefixer ( )
if not tts :
# should check if it's loading or unloaded, and load it if it's unloaded
if tts_loading :
raise Exception ( " TTS is still initializing... " )
load_tts ( )
if hasattr ( tts , " loading " ) and tts . loading :
raise Exception ( " TTS is still initializing... " )
do_gc ( )
voice_samples = None
conditioning_latents = None
2023-03-06 21:48:34 +00:00
sample_voice = None
2023-02-17 00:08:27 +00:00
2023-03-07 17:04:45 +00:00
voice_cache = { }
2023-03-07 05:35:21 +00:00
def fetch_voice ( voice ) :
2023-03-07 17:04:45 +00:00
cache_key = f ' { voice } : { tts . autoregressive_model_hash [ : 8 ] } '
if cache_key in voice_cache :
return voice_cache [ cache_key ]
2023-03-06 21:48:34 +00:00
2023-03-09 04:33:12 +00:00
print ( f " Loading voice: { voice } with model { tts . autoregressive_model_hash [ : 8 ] } " )
2023-03-07 05:35:21 +00:00
sample_voice = None
2023-03-06 21:48:34 +00:00
if voice == " microphone " :
2023-03-09 00:26:47 +00:00
if parameters [ ' mic_audio ' ] is None :
2023-03-06 21:48:34 +00:00
raise Exception ( " Please provide audio from mic when choosing `microphone` as a voice input " )
2023-03-09 00:26:47 +00:00
voice_samples , conditioning_latents = [ load_audio ( parameters [ ' mic_audio ' ] , tts . input_sample_rate ) ] , None
2023-03-06 21:48:34 +00:00
elif voice == " random " :
voice_samples , conditioning_latents = None , tts . get_random_conditioning_latents ( )
else :
2023-03-07 05:35:21 +00:00
if progress is not None :
2023-05-04 23:40:33 +00:00
notify_progress ( f " Loading voice: { voice } " , progress = progress )
2023-03-06 21:48:34 +00:00
2023-03-07 05:35:21 +00:00
voice_samples , conditioning_latents = load_voice ( voice , model_hash = tts . autoregressive_model_hash )
2023-03-06 21:48:34 +00:00
if voice_samples and len ( voice_samples ) > 0 :
2023-03-07 05:35:21 +00:00
if conditioning_latents is None :
2023-03-09 00:26:47 +00:00
conditioning_latents = compute_latents ( voice = voice , voice_samples = voice_samples , voice_latents_chunks = parameters [ ' voice_latents_chunks ' ] )
2023-03-07 05:35:21 +00:00
2023-03-06 21:48:34 +00:00
sample_voice = torch . cat ( voice_samples , dim = - 1 ) . squeeze ( ) . cpu ( )
voice_samples = None
2023-02-17 00:08:27 +00:00
2023-03-07 17:04:45 +00:00
voice_cache [ cache_key ] = ( voice_samples , conditioning_latents , sample_voice )
return voice_cache [ cache_key ]
2023-03-07 04:34:39 +00:00
2023-03-06 21:48:34 +00:00
def get_settings( override=None ):
    """Build the keyword arguments passed to ``tts.tts()`` for one line of text.

    Defaults come from ``parameters`` (the request) and ``args`` (the CLI/config).
    ``override`` is an optional dict parsed from a per-line ``{json}`` prefix:
      * its ``'voice'`` key selects a different voice for this line;
      * any other key that already exists in the settings dict replaces that value.

    Side effects: may load a different autoregressive/diffusion model or tokenizer
    into ``tts``, and resolves the selected voice through ``fetch_voice``.
    """
    settings = {
        'temperature': float(parameters['temperature']),

        'top_p': float(parameters['top_p']),
        'diffusion_temperature': float(parameters['diffusion_temperature']),
        'length_penalty': float(parameters['length_penalty']),
        'repetition_penalty': float(parameters['repetition_penalty']),
        'cond_free_k': float(parameters['cond_free_k']),

        'num_autoregressive_samples': parameters['num_autoregressive_samples'],
        'sample_batch_size': args.sample_batch_size,
        'diffusion_iterations': parameters['diffusion_iterations'],

        'voice_samples': None,
        'conditioning_latents': None,

        'use_deterministic_seed': parameters['seed'],
        'return_deterministic_state': True,
        'k': parameters['candidates'],
        'diffusion_sampler': parameters['diffusion_sampler'],
        'breathing_room': parameters['breathing_room'],
        # the progress callback is optional; don't KeyError when the caller omitted it
        'progress': parameters.get('progress'),
        'half_p': "Half Precision" in parameters['experimentals'],
        'cond_free': "Conditioning-Free" in parameters['experimentals'],
        'cvvp_amount': parameters['cvvp_weight'],

        'autoregressive_model': args.autoregressive_model,
        'diffusion_model': args.diffusion_model,
        'tokenizer_json': args.tokenizer_json,
    }

    selected_voice = voice
    if override is not None:
        if 'voice' in override:
            selected_voice = override['voice']
        # only known settings may be overridden; unknown keys are ignored
        for k in override:
            if k in settings:
                settings[k] = override[k]

    if settings['autoregressive_model'] is not None:
        if settings['autoregressive_model'] == "auto":
            settings['autoregressive_model'] = deduce_autoregressive_model(selected_voice)
        tts.load_autoregressive_model(settings['autoregressive_model'])

    if settings['diffusion_model'] is not None:
        if settings['diffusion_model'] == "auto":
            settings['diffusion_model'] = deduce_diffusion_model(selected_voice)
        tts.load_diffusion_model(settings['diffusion_model'])

    if settings['tokenizer_json'] is not None:
        tts.load_tokenizer_json(settings['tokenizer_json'])

    settings['voice_samples'], settings['conditioning_latents'], _ = fetch_voice(voice=selected_voice)

    # sample_batch_size already starts as args.sample_batch_size; do not reset it here,
    # otherwise a per-line override of it would be silently discarded.
    if not settings['sample_batch_size']:
        settings['sample_batch_size'] = tts.autoregressive_batch_size
    # clamp it down for the insane users who want this
    # it would be wiser to enforce the sample size to the batch size, but this is what the user wants
    if settings['num_autoregressive_samples'] < settings['sample_batch_size']:
        settings['sample_batch_size'] = settings['num_autoregressive_samples']

    # slim (2-element) latents lack the extra data CVVP weighting needs
    if settings['conditioning_latents'] is not None and len(settings['conditioning_latents']) == 2 and settings['cvvp_amount'] > 0:
        print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents with 'Slimmer voice latents' unchecked.")
        settings['cvvp_amount'] = 0

    return settings
2023-02-21 21:50:05 +00:00
2023-03-09 00:26:47 +00:00
# Normalize the delimiter: empty means newline, and a literal "\n" typed in the UI
# means a real newline.
if not parameters['delimiter']:
    parameters['delimiter'] = "\n"
elif parameters['delimiter'] == "\\n":
    parameters['delimiter'] = "\n"

# The delimiter is guaranteed non-empty after normalization, so only presence matters.
if parameters['delimiter'] in parameters['text']:
    texts = parameters['text'].split(parameters['delimiter'])
else:
    texts = split_and_recombine_text(parameters['text'])

full_start_time = time.time()

outdir = f"{args.results_folder}/{voice}/"
os.makedirs(outdir, exist_ok=True)

audio_cache = {}

volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None

# Pick the next free output index. Outputs are written as
# `<cleanup_voice_name(voice)>_<idx>...(.wav|.json)`, so match against that same prefix
# (escaped, since voice names may contain regex metacharacters).
file_prefix = re.escape(cleanup_voice_name(voice))
used_indices = set()
for filename in os.listdir(outdir):
    extension = os.path.splitext(filename)[1]
    if extension != ".json" and extension != ".wav":
        continue
    match = re.findall(rf"^{file_prefix}_(\d+)(?:.+?)?{re.escape(extension)}$", filename)
    if match:
        used_indices.add(int(match[0]))

idx = max(used_indices) + 1 if used_indices else 0
idx = pad(idx, 4)
2023-02-17 00:08:27 +00:00
def get_name( line=0, candidate=0, combined=False ):
    """Return the file stem for one output clip, relative to the run index ``idx``.

    Layout: ``<idx>[_combined | _<line>][_<candidate>]``. The line number is only
    included when there is more than one line of text and this is not the combined
    clip. The candidate number is only included when more than one candidate is requested.
    """
    parts = [f"{idx}"]
    if combined:
        parts.append("combined")
    elif len(texts) > 1:
        parts.append(f"{line}")
    if parameters['candidates'] > 1:
        parts.append(f"{candidate}")
    return "_".join(parts)
2023-03-06 23:07:16 +00:00
def get_info( voice, settings=None, latents=True ):
    """Assemble the metadata dict stored next to (or embedded in) an output clip.

    Starts from a shallow copy of ``parameters``, adds timing and model info, and
    drops the ``progress`` callback. A newline delimiter is stored escaped as ``"\\n"``.
    If ``settings`` is given, values for keys already present are taken from it, and
    ``experimentals`` is rebuilt from ``half_p``/``cond_free``.

    With ``latents=True``, the voice's ``cond_latents_<hash8>.pth`` is base64-embedded
    under ``'latents'`` when it exists. For the ``random``/``microphone`` voices, the
    latents used for this line are first saved to that path.
    """
    info = {}
    info.update(parameters)

    info['time'] = time.time() - full_start_time
    info['datetime'] = datetime.now().isoformat()

    info['model'] = tts.autoregressive_model_path
    info['model_hash'] = tts.autoregressive_model_hash

    # the progress callback is not metadata
    info.pop('progress', None)

    if info['delimiter'] == "\n":
        info['delimiter'] = "\\n"

    if settings is not None:
        for k in settings:
            if k in info:
                info[k] = settings[k]

        if 'half_p' in settings and 'cond_free' in settings:
            info['experimentals'] = []
            if settings['half_p']:
                info['experimentals'].append("Half Precision")
            if settings['cond_free']:
                info['experimentals'].append("Conditioning-Free")

    if latents and "latents" not in info:
        # NOTE(review): this replaces the `voice` argument with parameters['voice'], so a
        # per-line voice override passed in by the caller does not affect which latents are
        # embedded -- confirm whether that is intended.
        voice = info['voice']
        model_hash = settings["model_hash"][:8] if settings is not None and "model_hash" in settings else tts.autoregressive_model_hash[:8]

        voice_dir = f'{get_voice_dir()}/{voice}/'
        latents_path = f'{voice_dir}/cond_latents_{model_hash}.pth'

        if voice == "random" or voice == "microphone":
            if settings is not None and settings['conditioning_latents']:
                os.makedirs(voice_dir, exist_ok=True)
                # persist the latents actually used for this line, not an enclosing-scope
                # variable that is never assigned here
                torch.save(settings['conditioning_latents'], latents_path)

        if latents_path and os.path.exists(latents_path):
            try:
                with open(latents_path, 'rb') as f:
                    info['latents'] = base64.b64encode(f.read()).decode("ascii")
            except OSError:
                # best-effort: metadata is still produced without embedded latents
                pass

    return info
2023-03-14 17:42:42 +00:00
INFERENCING = True
2023-02-17 00:08:27 +00:00
for line , cut_text in enumerate ( texts ) :
2023-03-17 05:33:49 +00:00
if should_phonemize ( ) :
cut_text = phonemizer ( cut_text )
2023-03-09 00:26:47 +00:00
if parameters [ ' emotion ' ] == " Custom " :
if parameters [ ' prompt ' ] and parameters [ ' prompt ' ] . strip ( ) != " " :
cut_text = f " [ { parameters [ ' prompt ' ] } ,] { cut_text } "
elif parameters [ ' emotion ' ] != " None " and parameters [ ' emotion ' ] :
cut_text = f " [I am really { parameters [ ' emotion ' ] . lower ( ) } ,] { cut_text } "
2023-03-06 21:48:34 +00:00
2023-05-04 23:40:33 +00:00
tqdm_prefix = f ' [ { str ( line + 1 ) } / { str ( len ( texts ) ) } ] '
print ( f " { tqdm_prefix } Generating line: { cut_text } " )
2023-02-17 00:08:27 +00:00
start_time = time . time ( )
2023-03-06 21:48:34 +00:00
# do setting editing
match = re . findall ( r ' ^( \ { .+ \ }) (.+?)$ ' , cut_text )
2023-03-06 23:07:16 +00:00
override = None
2023-03-06 21:48:34 +00:00
if match and len ( match ) > 0 :
match = match [ 0 ]
try :
override = json . loads ( match [ 0 ] )
2023-03-07 05:35:21 +00:00
cut_text = match [ 1 ] . strip ( )
2023-03-06 21:48:34 +00:00
except Exception as e :
raise Exception ( " Prompt settings editing requested, but received invalid JSON " )
2023-03-07 05:35:21 +00:00
settings = get_settings ( override = override )
gen , additionals = tts . tts ( cut_text , * * settings )
2023-03-06 21:48:34 +00:00
2023-03-09 00:26:47 +00:00
parameters [ ' seed ' ] = additionals [ 0 ]
2023-02-17 00:08:27 +00:00
run_time = time . time ( ) - start_time
print ( f " Generating line took { run_time } seconds " )
2023-03-09 00:26:47 +00:00
2023-02-17 00:08:27 +00:00
if not isinstance ( gen , list ) :
gen = [ gen ]
for j , g in enumerate ( gen ) :
audio = g . squeeze ( 0 ) . cpu ( )
name = get_name ( line = line , candidate = j )
2023-03-06 23:07:16 +00:00
2023-03-07 05:35:21 +00:00
settings [ ' text ' ] = cut_text
settings [ ' time ' ] = run_time
2023-03-31 03:26:00 +00:00
settings [ ' datetime ' ] = datetime . now ( ) . isoformat ( )
if args . tts_backend == " tortoise " :
settings [ ' model ' ] = tts . autoregressive_model_path
settings [ ' model_hash ' ] = tts . autoregressive_model_hash
2023-03-06 23:07:16 +00:00
2023-02-17 00:08:27 +00:00
audio_cache [ name ] = {
' audio ' : audio ,
2023-03-07 05:35:21 +00:00
' settings ' : get_info ( voice = override [ ' voice ' ] if override and ' voice ' in override else voice , settings = settings )
2023-02-17 00:08:27 +00:00
}
# save here in case some error happens mid-batch
2023-04-26 04:48:09 +00:00
torchaudio . save ( f ' { outdir } / { cleanup_voice_name ( voice ) } _ { name } .wav ' , audio , tts . output_sample_rate )
2023-02-17 00:08:27 +00:00
2023-02-18 20:37:37 +00:00
del gen
do_gc ( )
2023-03-14 17:42:42 +00:00
INFERENCING = False
2023-02-18 20:37:37 +00:00
2023-02-17 00:08:27 +00:00
for k in audio_cache :
audio = audio_cache [ k ] [ ' audio ' ]
2023-03-13 01:20:55 +00:00
audio , _ = resample ( audio , tts . output_sample_rate , args . output_sample_rate )
2023-02-17 00:08:27 +00:00
if volume_adjust is not None :
audio = volume_adjust ( audio )
audio_cache [ k ] [ ' audio ' ] = audio
2023-04-26 04:48:09 +00:00
torchaudio . save ( f ' { outdir } / { cleanup_voice_name ( voice ) } _ { k } .wav ' , audio , args . output_sample_rate )
2023-03-06 23:07:16 +00:00
2023-02-17 00:08:27 +00:00
output_voices = [ ]
2023-03-09 00:26:47 +00:00
for candidate in range ( parameters [ ' candidates ' ] ) :
2023-02-17 00:08:27 +00:00
if len ( texts ) > 1 :
audio_clips = [ ]
for line in range ( len ( texts ) ) :
name = get_name ( line = line , candidate = candidate )
audio = audio_cache [ name ] [ ' audio ' ]
audio_clips . append ( audio )
name = get_name ( candidate = candidate , combined = True )
audio = torch . cat ( audio_clips , dim = - 1 )
2023-04-26 04:48:09 +00:00
torchaudio . save ( f ' { outdir } / { cleanup_voice_name ( voice ) } _ { name } .wav ' , audio , args . output_sample_rate )
2023-02-17 00:08:27 +00:00
audio = audio . squeeze ( 0 ) . cpu ( )
audio_cache [ name ] = {
' audio ' : audio ,
2023-03-06 23:07:16 +00:00
' settings ' : get_info ( voice = voice ) ,
2023-02-17 00:08:27 +00:00
' output ' : True
}
else :
name = get_name ( candidate = candidate )
audio_cache [ name ] [ ' output ' ] = True
2023-02-20 00:21:16 +00:00
if args . voice_fixer :
if not voicefixer :
2023-05-04 23:40:33 +00:00
notify_progress ( " Loading voicefix... " , progress = progress )
2023-02-20 00:21:16 +00:00
load_voicefixer ( )
2023-03-08 04:12:22 +00:00
try :
fixed_cache = { }
2023-05-04 23:40:33 +00:00
for name in tqdm ( audio_cache , desc = " Running voicefix... " ) :
2023-03-08 04:12:22 +00:00
del audio_cache [ name ] [ ' audio ' ]
if ' output ' not in audio_cache [ name ] or not audio_cache [ name ] [ ' output ' ] :
continue
2023-02-21 21:50:05 +00:00
2023-04-26 04:48:09 +00:00
path = f ' { outdir } / { cleanup_voice_name ( voice ) } _ { name } .wav '
fixed = f ' { outdir } / { cleanup_voice_name ( voice ) } _ { name } _fixed.wav '
2023-03-08 04:12:22 +00:00
voicefixer . restore (
input = path ,
output = fixed ,
cuda = get_device_name ( ) == " cuda " and args . voice_fixer_use_cuda ,
#mode=mode,
)
fixed_cache [ f ' { name } _fixed ' ] = {
' settings ' : audio_cache [ name ] [ ' settings ' ] ,
' output ' : True
}
audio_cache [ name ] [ ' output ' ] = False
2023-02-21 21:50:05 +00:00
2023-03-08 04:12:22 +00:00
for name in fixed_cache :
audio_cache [ name ] = fixed_cache [ name ]
except Exception as e :
print ( e )
print ( " \n Failed to run Voicefixer " )
2023-02-21 21:50:05 +00:00
for name in audio_cache :
if ' output ' not in audio_cache [ name ] or not audio_cache [ name ] [ ' output ' ] :
if args . prune_nonfinal_outputs :
audio_cache [ name ] [ ' pruned ' ] = True
2023-04-26 04:48:09 +00:00
os . remove ( f ' { outdir } / { cleanup_voice_name ( voice ) } _ { name } .wav ' )
2023-02-21 21:50:05 +00:00
continue
2023-04-26 04:48:09 +00:00
output_voices . append ( f ' { outdir } / { cleanup_voice_name ( voice ) } _ { name } .wav ' )
2023-02-21 21:50:05 +00:00
if not args . embed_output_metadata :
2023-04-26 04:48:09 +00:00
with open ( f ' { outdir } / { cleanup_voice_name ( voice ) } _ { name } .json ' , ' w ' , encoding = " utf-8 " ) as f :
2023-03-06 23:07:16 +00:00
f . write ( json . dumps ( audio_cache [ name ] [ ' settings ' ] , indent = ' \t ' ) )
2023-02-17 00:08:27 +00:00
if args . embed_output_metadata :
2023-05-04 23:40:33 +00:00
for name in tqdm ( audio_cache , desc = " Embedding metadata... " ) :
2023-02-21 21:50:05 +00:00
if ' pruned ' in audio_cache [ name ] and audio_cache [ name ] [ ' pruned ' ] :
continue
2023-04-26 04:48:09 +00:00
metadata = music_tag . load_file ( f " { outdir } / { cleanup_voice_name ( voice ) } _ { name } .wav " )
2023-03-06 23:07:16 +00:00
metadata [ ' lyrics ' ] = json . dumps ( audio_cache [ name ] [ ' settings ' ] )
2023-02-17 00:08:27 +00:00
metadata . save ( )
if sample_voice is not None :
sample_voice = ( tts . input_sample_rate , sample_voice . numpy ( ) )
2023-03-06 23:07:16 +00:00
info = get_info ( voice = voice , latents = False )
2023-02-17 00:08:27 +00:00
print ( f " Generation took { info [ ' time ' ] } seconds, saved to ' { output_voices [ 0 ] } ' \n " )
2023-03-09 00:26:47 +00:00
info [ ' seed ' ] = usedSeed
2023-02-17 00:08:27 +00:00
if ' latents ' in info :
del info [ ' latents ' ]
2023-02-17 20:10:27 +00:00
os . makedirs ( ' ./config/ ' , exist_ok = True )
2023-02-17 00:08:27 +00:00
with open ( f ' ./config/generate.json ' , ' w ' , encoding = " utf-8 " ) as f :
f . write ( json . dumps ( info , indent = ' \t ' ) )
stats = [
2023-03-09 00:26:47 +00:00
[ parameters [ ' seed ' ] , " {:.3f} " . format ( info [ ' time ' ] ) ]
2023-02-17 00:08:27 +00:00
]
return (
sample_voice ,
output_voices ,
stats ,
)
2023-02-20 00:21:16 +00:00
def cancel_generate():
    """Ask an in-progress generation to stop.

    Does nothing unless the module-level ``INFERENCING`` flag is set. Otherwise it sets
    ``tortoise.api.STOP_SIGNAL`` to ``True``, which the tortoise API is presumably
    polling during sampling. Always returns ``None``.
    """
    if INFERENCING:
        import tortoise.api

        tortoise.api.STOP_SIGNAL = True
2023-02-17 20:10:27 +00:00
2023-03-02 00:46:52 +00:00
def hash_file( path, algo="md5", buffer_size=0 ):
    """Return the hex digest of the file at ``path``.

    Args:
        path: file to hash.
        algo: any fixed-length algorithm in ``hashlib.algorithms_guaranteed``
            (e.g. ``"md5"``, ``"sha1"``, ``"sha256"``). The variable-length
            ``shake_*`` algorithms are rejected because ``hexdigest()`` needs a length.
        buffer_size: when > 0, read the file in chunks of this many bytes;
            otherwise read it in one go.

    Raises:
        Exception: unknown algorithm, or ``path`` does not exist.
    """
    if algo not in hashlib.algorithms_guaranteed or algo.startswith("shake_"):
        raise Exception(f'Unknown hash algorithm specified: {algo}')
    if not os.path.exists(path):
        raise Exception(f'Path not found: {path}')

    hasher = hashlib.new(algo)
    with open(path, 'rb') as f:
        if buffer_size > 0:
            for chunk in iter(lambda: f.read(buffer_size), b''):
                hasher.update(chunk)
        else:
            hasher.update(f.read())
    return hasher.hexdigest()
2023-03-03 21:13:48 +00:00
def update_baseline_for_latents_chunks( voice ):
    """Suggest a ``voice_latents_chunks`` value for ``voice`` and make it the current voice.

    Returns:
        1 if the voice folder does not exist;
        0 if ``./training/<voice>/train.txt`` exists (0 makes compute_latents use that dataset);
        otherwise a value derived from the total duration of the folder's ``.wav`` files.
        With ``args.autocalculate_voice_chunk_duration_size == 0`` it is the average clip
        length in seconds; otherwise it is the total seconds divided by that size.
        Falls back to 1 when there are no clips / no audio.
    """
    global current_voice
    current_voice = voice

    voice_path = f'{get_voice_dir()}/{voice}/'
    if not os.path.isdir(voice_path):
        return 1

    if os.path.exists(f'./training/{voice}/train.txt'):
        return 0  # 0 will leverage using the LJspeech dataset for computing latents

    wav_files = [name for name in os.listdir(voice_path) if name[-4:] == ".wav"]

    total_duration = 0
    for name in wav_files:
        clip_info = torchaudio.info(f'{voice_path}/{name}')
        total_duration += clip_info.num_frames / clip_info.sample_rate

    chunk_size = args.autocalculate_voice_chunk_duration_size
    if chunk_size == 0:
        return int(total_duration / len(wav_files)) if wav_files else 1
    return int(total_duration / chunk_size) if total_duration > 0 else 1
2023-03-07 03:55:35 +00:00
def compute_latents( voice=None, voice_samples=None, voice_latents_chunks=0, progress=None ):
    """Compute conditioning latents and save them as ``<voice dir>/<voice>/cond_latents_<hash8>.pth``.

    Sample source:
      * ``voice`` set and ``voice_latents_chunks == 0`` with a non-empty
        ``./training/<voice>/train.txt``: each line's first ``|``-separated field is a clip
        path under ``./training/<voice>/``. Clips are read with ``load_audio(..., 22050)``,
        padded/truncated to the longest clip, and one slice per clip is used.
      * ``voice`` set otherwise (no dataset, empty dataset, or non-zero chunks):
        ``load_voice(voice, load_latents=False)``.
      * ``voice`` falsy: the given ``voice_samples``.

    ``progress`` is accepted for API compatibility but not used here.

    Returns:
        The latents (a 4-tuple has its last element replaced by ``None``), or ``None``
        when no samples could be loaded.

    Raises:
        Exception: the TTS model is still initializing.
    """
    global tts
    global args

    unload_whisper()
    unload_voicefixer()

    if not tts:
        if tts_loading:
            raise Exception("TTS is still initializing...")
        load_tts()
    if hasattr(tts, "loading") and tts.loading:
        raise Exception("TTS is still initializing...")

    if args.autoregressive_model == "auto":
        tts.load_autoregressive_model(deduce_autoregressive_model(voice))

    if voice:
        load_from_dataset = voice_latents_chunks == 0
        if load_from_dataset:
            dataset_path = f'./training/{voice}/train.txt'
            if not os.path.exists(dataset_path):
                load_from_dataset = False
            else:
                with open(dataset_path, 'r', encoding="utf-8") as f:
                    # blank lines would otherwise be treated as (nonexistent) clip paths
                    lines = [entry for entry in f.readlines() if entry.strip()]

                print("Leveraging dataset for computing latents")

                voice_samples = []
                max_length = 0
                for line in lines:
                    filename = f'./training/{voice}/{line.split("|")[0]}'
                    waveform = load_audio(filename, 22050)
                    max_length = max(max_length, waveform.shape[-1])
                    voice_samples.append(waveform)

                for i in range(len(voice_samples)):
                    voice_samples[i] = pad_or_truncate(voice_samples[i], max_length)

                voice_latents_chunks = len(voice_samples)
                if voice_latents_chunks == 0:
                    print("Dataset is empty!")
                    # fall back to the voice folder instead of computing latents from no samples
                    load_from_dataset = False

        if not load_from_dataset:
            voice_samples, _ = load_voice(voice, load_latents=False)

    if voice_samples is None:
        return

    conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, slices=voice_latents_chunks, force_cpu=args.force_cpu_for_conditioning_latents)

    if len(conditioning_latents) == 4:
        conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)

    outfile = f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth'
    torch.save(conditioning_latents, outfile)
    print(f'Saved voice latents: {outfile}')

    return conditioning_latents
2023-02-17 19:06:05 +00:00
2023-02-23 06:24:54 +00:00
# superfluous, but it cleans up some things
class TrainingState ( ) :
2023-03-09 00:26:47 +00:00
def __init__( self, config_path, keep_x_past_checkpoints=0, start=True ):
    """Load a training run's configuration and optionally launch the trainer.

    Args:
        config_path: training YAML. Its directory must also contain ``train.json``
            (read keys: batch_size, save_rate, epochs, gpus) and ``train.txt``
            (the dataset; its line count is the dataset size).
        keep_x_past_checkpoints: when > 0, ``self.cleanup_old(keep=...)`` runs first.
        start: when True, ``self.spawn_process`` launches the trainer immediately.
    """
    self.killed = False
    self.training_dir = os.path.dirname(config_path)
    with open(config_path, 'r') as file:
        self.yaml_config = yaml.safe_load(file)

    # context manager so the JSON config handle is closed deterministically
    with open(f"{self.training_dir}/train.json", 'r', encoding="utf-8") as file:
        self.json_config = json.load(file)

    self.dataset_path = f"{self.training_dir}/train.txt"
    with open(self.dataset_path, 'r', encoding="utf-8") as f:
        self.dataset_size = len(f.readlines())

    self.batch_size = self.json_config["batch_size"]
    self.save_rate = self.json_config["save_rate"]

    self.epoch = 0
    self.epochs = self.json_config["epochs"]
    self.it = 0
    self.its = calc_iterations(self.epochs, self.dataset_size, self.batch_size)
    self.step = 0
    self.steps = int(self.its / self.dataset_size)
    self.checkpoint = 0
    self.checkpoints = int((self.its - self.it) / self.save_rate)

    self.gpus = self.json_config['gpus']

    self.buffer = []

    self.open_state = False
    self.training_started = False

    self.info = {}

    self.it_rate = ""
    self.it_rates = 0
    self.epoch_rate = ""

    self.eta = "?"
    self.eta_hhmmss = "?"

    self.nan_detected = False

    self.last_info_check_at = 0
    self.statistics = {
        'loss': [],
        'lr': [],
        'grad_norm': [],
    }
    self.losses = []
    self.metrics = {
        'step': "",
        'rate': "",
        'loss': "",
    }

    self.loss_milestones = [1.0, 0.15, 0.05]

    if args.tts_backend == "vall-e":
        self.valle_last_it = 0
        self.valle_steps = 0

    if keep_x_past_checkpoints > 0:
        self.cleanup_old(keep=keep_x_past_checkpoints)
    if start:
        self.spawn_process(config_path=config_path, gpus=self.gpus)
2023-03-03 04:37:18 +00:00
def spawn_process( self, config_path, gpus=1 ):
    """Launch the trainer for ``config_path`` and keep its handle on ``self.process``.

    The vall-e backend runs ``deepspeed --num_gpus=<gpus> --module vall_e.train yaml="<config_path>"``.
    Other backends run ``train.bat`` (Windows) or ``./train.sh`` with ``config_path``.
    stdin is a pipe, and stdout/stderr are merged into one text-mode pipe.
    """
    if args.tts_backend != "vall-e":
        launcher = 'train.bat' if os.name == "nt" else './train.sh'
        self.cmd = [launcher, config_path]
    else:
        # NOTE(review): no shell is involved, so the quotes around config_path reach the
        # trainer literally; presumably its CLI config parser strips them -- confirm.
        self.cmd = ['deepspeed', f'--num_gpus={gpus}', '--module', 'vall_e.train', f'yaml="{config_path}"']

    print("Spawning process: ", " ".join(self.cmd))
    self.process = subprocess.Popen(
        self.cmd,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
    )
2023-02-20 22:56:39 +00:00
2023-03-11 01:19:49 +00:00
def parse_metrics( self, data ):
    """Update progress counters, rates, ETA and plot statistics from one metrics record.

    ``data`` is either a metrics dict or a line of trainer output. A line is parsed only
    when it contains ``Training Metrics:`` / ``Validation Metrics:`` followed by a JSON
    object, which is then tagged with ``'mode'``. Any other string is ignored.

    Returns:
        The metrics dict (also stored as ``self.info``), or ``None`` for an ignored line.
    """
    if isinstance(data, str):
        # parse the string that was passed in
        if 'Training Metrics:' in data:
            data = json.loads(data.split("Training Metrics:")[-1])
            data['mode'] = "training"
        elif 'Validation Metrics:' in data:
            data = json.loads(data.split("Validation Metrics:")[-1])
            data['mode'] = "validation"
        else:
            return

    self.info = data
    if 'epoch' in self.info:
        self.epoch = int(self.info['epoch'])
    if 'it' in self.info:
        self.it = int(self.info['it'])
    if 'step' in self.info:
        self.step = int(self.info['step'])
    if 'steps' in self.info:
        self.steps = int(self.info['steps'])

    # some backends report the per-iteration time as elapsed_time
    if 'elapsed_time' in self.info:
        self.info['iteration_rate'] = self.info['elapsed_time']
        del self.info['elapsed_time']

    if 'iteration_rate' in self.info:
        it_rate = self.info['iteration_rate']
        # sub-second iterations read better as it/s, slower ones as s/it
        self.it_rate = f'{"{:.3f}".format(1/it_rate)}it/s' if 0 < it_rate and it_rate < 1 else f'{"{:.3f}".format(it_rate)}s/it'
        self.it_rates += it_rate

        if self.it_rates > 0 and self.it * self.steps > 0:
            # average seconds per iteration times iterations per epoch
            epoch_rate = self.it_rates / self.it * self.steps
            self.epoch_rate = f'{"{:.3f}".format(1/epoch_rate)}epoch/s' if 0 < epoch_rate and epoch_rate < 1 else f'{"{:.3f}".format(epoch_rate)}s/epoch'

        try:
            self.eta = (self.its - self.it) * (self.it_rates / self.it)
            self.eta_hhmmss = str(timedelta(seconds=int(self.eta)))
        except Exception:
            # e.g. it == 0 before the first report; the ETA is best-effort display only
            self.eta_hhmmss = "?"

    self.metrics['step'] = [f"{self.epoch}/{self.epochs}"]
    if self.epochs != self.its:
        self.metrics['step'].append(f"{self.it}/{self.its}")
    if self.steps > 1:
        self.metrics['step'].append(f"{self.step}/{self.steps}")
    self.metrics['step'] = ", ".join(self.metrics['step'])

    if args.tts_backend == "tortoise":
        # fractional epoch; steps can be 0 (e.g. fewer iterations than dataset lines)
        epoch = self.epoch + (self.step / self.steps if self.steps else 0)
    else:
        epoch = self.info['epoch'] if 'epoch' in self.info else self.it

    if self.it > 0:
        # metric keys to chart, per category
        keys = {
            'lrs': ['lr'],
            'losses': ['loss_text_ce', 'loss_mel_ce'],
            'accuracies': [],
            'precisions': [],
            'grad_norms': [],
        }
        if args.tts_backend == "vall-e":
            keys['lrs'] = [
                'ar.lr', 'nar.lr',
                'ar-half.lr', 'nar-half.lr',
                'ar-quarter.lr', 'nar-quarter.lr',
            ]
            keys['losses'] = [
                'ar.loss', 'nar.loss', 'ar+nar.loss',
                'ar-half.loss', 'nar-half.loss', 'ar-half+nar-half.loss',
                'ar-quarter.loss', 'nar-quarter.loss', 'ar-quarter+nar-quarter.loss',

                'ar.loss.nll', 'nar.loss.nll',
                'ar-half.loss.nll', 'nar-half.loss.nll',
                'ar-quarter.loss.nll', 'nar-quarter.loss.nll',
            ]
            keys['accuracies'] = [
                'ar.loss.acc', 'nar.loss.acc',
                'ar-half.loss.acc', 'nar-half.loss.acc',
                'ar-quarter.loss.acc', 'nar-quarter.loss.acc',
            ]
            keys['precisions'] = [
                'ar.loss.precision', 'nar.loss.precision',
                'ar-half.loss.precision', 'nar-half.loss.precision',
                'ar-quarter.loss.precision', 'nar-quarter.loss.precision',
            ]
            keys['grad_norms'] = ['ar.grad_norm', 'nar.grad_norm', 'ar-half.grad_norm', 'nar-half.grad_norm', 'ar-quarter.grad_norm', 'nar-quarter.grad_norm']

        for k in keys['lrs']:
            if k not in self.info:
                continue
            self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})

        # accuracies and precisions share the loss chart
        for k in keys['accuracies']:
            if k not in self.info:
                continue
            self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})

        for k in keys['precisions']:
            if k not in self.info:
                continue
            self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})

        for k in keys['losses']:
            if k not in self.info:
                continue

            prefix = ""
            if "mode" in self.info and self.info["mode"] == "validation":
                prefix = f'{self.info["name"] if "name" in self.info else "val"}_'

            self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{prefix}{k}'})
            self.losses.append(self.statistics['loss'][-1])

        for k in keys['grad_norms']:
            if k not in self.info:
                continue
            self.statistics['grad_norm'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})

    return data
2023-03-12 15:17:07 +00:00
def get_status(self):
    """Build the one-line progress summary shown in the UI.

    Format: ``[step] [rate] [ETA: hh:mm:ss] [LR: ..., Loss: ...]``, prefixed with
    ``[!NaN DETECTED! <it>]`` once a NaN has been seen.

    Side effect: ``self.metrics['rate']`` and ``self.metrics['loss']`` are replaced
    with their formatted (comma-joined) strings.

    Returns:
        str: the status message.
    """
    rates = []
    if self.epoch_rate:
        rates.append(self.epoch_rate)
        # only show the iteration rate when its number differs from the epoch rate;
        # the slices presumably strip the unit suffixes -- confirm against where
        # epoch_rate / it_rate are assigned
        if self.it_rate and self.epoch_rate[:-7] != self.it_rate[:-4]:
            rates.append(self.it_rate)
    self.metrics['rate'] = ", ".join(rates)

    eta_hhmmss = self.eta_hhmmss if self.eta_hhmmss else "?"

    losses = []
    if 'lr' in self.info:
        losses.append(f'LR: {"{:.3e}".format(self.info["lr"])}')
    if len(self.losses) > 0:
        losses.append(f'Loss: {"{:.3f}".format(self.losses[-1]["value"])}')
    # (an unreachable, disabled loss-milestone estimator used to live here)
    self.metrics['loss'] = ", ".join(losses)

    message = f"[{self.metrics['step']}] [{self.metrics['rate']}] [ETA: {eta_hhmmss}] [{self.metrics['loss']}]"
    if self.nan_detected:
        message = f"[!NaN DETECTED! {self.nan_detected}] {message}"
    return message
2023-03-09 05:54:08 +00:00
def load_statistics(self, update=False):
    """Rebuild the training statistics from the trainer's log files.

    Reads ``<training_dir>/finetune/*.log`` (tortoise) or
    ``<training_dir>/logs/*/log.txt`` (other backends), extracts every
    ``Training Metrics:`` / ``Validation Metrics:`` JSON payload, keeps one record
    per iteration/mode/name (for vall-e, repeated payloads for the same iteration
    are averaged) and hands each record to ``self.parse_metrics``.

    Args:
        update: when True only the newest log is read, already collected
            statistics are kept, and payloads at or before
            ``self.last_info_check_at`` are skipped. Otherwise the loss / lr /
            grad-norm series are reset first.

    Side effects:
        ``self.last_info_check_at`` is advanced to the highest iteration seen.
    """
    if not os.path.isdir(self.training_dir):
        return

    if args.tts_backend == "tortoise":
        log_dir = f'{self.training_dir}/finetune/'
        names = os.listdir(log_dir) if os.path.isdir(log_dir) else []
        logs = sorted([f'{self.training_dir}/finetune/{d}' for d in names if d[-4:] == ".log"])
    else:
        log_dir = f'{self.training_dir}/logs/'
        names = os.listdir(log_dir) if os.path.isdir(log_dir) else []
        candidates = [f'{self.training_dir}/logs/{d}/log.txt' for d in names]
        # a run directory may not have produced its log.txt yet
        logs = sorted([p for p in candidates if os.path.isfile(p)])
    if update:
        # slice rather than index so an empty log list does not raise
        logs = logs[-1:]

    highest_step = self.last_info_check_at

    if not update:
        self.statistics['loss'] = []
        self.statistics['lr'] = []
        self.statistics['grad_norm'] = []
        self.it_rates = 0

    unq = {}
    averager = None
    # 0 right after a training payload, then counts the validation payloads since
    prev_state = 0
    # most recent iteration / epoch, used to stamp validation payloads lacking them
    it = None
    epoch = None

    for log in logs:
        with open(log, 'r', encoding="utf-8") as f:
            lines = f.readlines()

        for line in lines:
            line = line.strip()
            if not line:
                continue
            if line[-1] == ".":
                line = line[:-1]

            try:
                if line.find('Training Metrics:') >= 0:
                    data = json.loads(line.split("Training Metrics:")[-1])
                    name = "train"
                    mode = "training"
                    prev_state = 0
                elif line.find('Validation Metrics:') >= 0:
                    data = json.loads(line.split("Validation Metrics:")[-1])
                    # only backfill once a training payload has provided values
                    if "it" not in data and it is not None:
                        data['it'] = it
                    if "epoch" not in data and epoch is not None:
                        data['epoch'] = epoch

                    mode = "validation"
                    # first validation payload after a training one is the
                    # "subtrain" set, any further ones are "val"
                    name = "subtrain" if prev_state == 0 else "val"
                    prev_state += 1
                else:
                    continue
            except json.JSONDecodeError:
                # e.g. a line a live trainer is still writing
                continue

            if "it" not in data:
                continue

            it = data['it']
            epoch = data['epoch']

            # in update mode only payloads newer than the previous check matter;
            # this must run before the record is stored to have any effect
            if update and it <= self.last_info_check_at:
                continue
            if highest_step is None or it > highest_step:
                highest_step = it

            if args.tts_backend == "vall-e":
                key = f'{it}_{name}'
                if not averager or averager['key'] != key or averager['mode'] != mode:
                    averager = {
                        'key': key,
                        'name': name,
                        'mode': mode,
                        "metrics": {}
                    }
                for k, v in data.items():
                    if v is None:
                        continue
                    averager['metrics'].setdefault(k, []).append(v)
                unq[f'{it}_{mode}_{name}'] = averager
            else:
                unq[f'{it}_{mode}_{name}'] = data

    blacklist = ["batch", "eval"]
    for stats in unq.values():
        if args.tts_backend == "vall-e":
            data = {
                k: sum(v) / len(v)
                for k, v in stats['metrics'].items()
                # non-numeric series cannot be averaged
                if k not in blacklist and all(isinstance(x, (int, float)) for x in v)
            }
            data['name'] = stats['name']
            data['mode'] = stats['mode']
            data['steps'] = len(stats['metrics']['it'])
        else:
            data = stats
        self.parse_metrics(data)

    self.last_info_check_at = highest_step
2023-02-28 01:01:50 +00:00
def cleanup_old(self, keep=2):
    """Delete all but the newest ``keep`` tortoise model checkpoints and training states.

    Model files are ``<step>_gpt.pth`` under ``finetune/models/`` and states are
    ``<step>.state`` under ``finetune/training_state/``. Files whose stem is not a
    plain integer step are ignored. Nothing happens for ``keep <= 0`` or for the
    vall-e backend.
    """
    if keep <= 0:
        return

    if args.tts_backend == "vall-e":
        return
    if not os.path.isdir(f'{self.training_dir}/finetune/'):
        return

    models_dir = f'{self.training_dir}/finetune/models/'
    states_dir = f'{self.training_dir}/finetune/training_state/'

    def _list_steps(directory, suffix):
        """Return the sorted integer steps of ``<step><suffix>`` files (empty if the dir is missing)."""
        if not os.path.isdir(directory):
            return []
        steps = []
        for d in os.listdir(directory):
            stem = d[:-len(suffix)]
            if d.endswith(suffix) and stem.isdigit():
                steps.append(int(stem))
        return sorted(steps)

    remove_models = _list_steps(models_dir, "_gpt.pth")[:-keep]
    remove_states = _list_steps(states_dir, ".state")[:-keep]

    for d in remove_models:
        path = f'{models_dir}{d}_gpt.pth'
        print("Removing", path)
        os.remove(path)
    for d in remove_states:
        path = f'{states_dir}{d}.state'
        print("Removing", path)
        os.remove(path)
2023-03-07 20:16:49 +00:00
def parse(self, line, verbose=False, keep_x_past_checkpoints=0, buffer_size=8, progress=None):
    """Consume one line of trainer stdout and update the training state.

    Recognizes the start banner (epoch / iteration extraction), checkpoint saves
    (old checkpoints are pruned via ``cleanup_old``), training / validation metric
    payloads, and the "finished" marker (which sets ``self.killed``).

    Args:
        line: raw stdout line from the trainer.
        verbose: before training starts, return the raw buffered output.
        keep_x_past_checkpoints: forwarded to ``cleanup_old`` on each save.
        buffer_size: number of recent lines kept in ``self.buffer``.
        progress: optional ``progress(percent, message)`` callback.

    Returns:
        tuple: ``(result, percent, message)`` -- ``result`` is the text for the UI
        (the joined buffer before training starts, the status message afterwards)
        or None; ``percent`` is a 0-1 fraction; ``message`` is the status or None.
    """
    self.buffer.append(f'{line}')

    data = None
    percent = 0
    message = None
    should_return = False

    MESSAGE_START = 'Start training from epoch'
    MESSAGE_FINISHED = 'Finished training'
    MESSAGE_SAVING = 'Saving models and training states.'

    MESSAGE_METRICS_TRAINING = 'Training Metrics:'
    MESSAGE_METRICS_VALIDATION = 'Validation Metrics:'

    def _load_metrics(marker, mode):
        """Decode the JSON after ``marker``; None when it is malformed or truncated."""
        try:
            payload = json.loads(line.split(marker)[-1])
        except json.JSONDecodeError:
            return None
        payload['mode'] = mode
        return payload

    if line.find(MESSAGE_FINISHED) >= 0:
        self.killed = True
    elif not self.training_started:
        # rip out iteration info
        if line.find(MESSAGE_START) >= 0:
            self.training_started = True

            match = re.findall(r'epoch: ([\d,]+)', line)
            if match:
                self.epoch = int(match[0].replace(",", ""))
            match = re.findall(r'iter: ([\d,]+)', line)
            if match:
                self.it = int(match[0].replace(",", ""))

            # an unset/zero save rate means no checkpoints can be predicted
            self.checkpoints = int((self.its - self.it) / self.save_rate) if self.save_rate else 0

            self.load_statistics()

            should_return = True
    else:
        if line.find(MESSAGE_SAVING) >= 0:
            self.checkpoint += 1
            message = f"[{self.checkpoint}/{self.checkpoints}] Saving checkpoint..."
            percent = self.checkpoint / self.checkpoints if self.checkpoints > 0 else 0
            self.cleanup_old(keep=keep_x_past_checkpoints)
        elif line.find(MESSAGE_METRICS_TRAINING) >= 0:
            data = _load_metrics(MESSAGE_METRICS_TRAINING, "training")
        elif line.find(MESSAGE_METRICS_VALIDATION) >= 0:
            data = _load_metrics(MESSAGE_METRICS_VALIDATION, "validation")

        if data is not None:
            if ': nan' in line and not self.nan_detected:
                self.nan_detected = self.it

            self.parse_metrics(data)
            message = self.get_status()

            if message:
                percent = self.it / float(self.its) if self.its else 0  # self.epoch / float(self.epochs)
                if progress is not None:
                    progress(percent, message)

                self.buffer.append(f'[{"{:.3f}".format(percent * 100)}%] {message}')
                should_return = True

    if verbose and not self.training_started:
        should_return = True

    self.buffer = self.buffer[-buffer_size:]

    result = None
    if should_return:
        result = "".join(self.buffer) if not self.training_started else message

    return (
        result,
        percent,
        message,
    )
2023-02-19 07:05:11 +00:00
2023-03-08 00:51:51 +00:00
# altair is optional: when it imports, lift its max_rows cap so long training
# runs can still be charted; any failure is printed and otherwise ignored.
try:
    import altair as alt
    alt.data_transformers.enable('default', max_rows=None)
except Exception as e:
    print(e)
2023-03-09 00:26:47 +00:00
def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress=gr.Progress(track_tqdm=True)):
    """Start a training run and stream its progress to the UI.

    This is a generator: every status string produced while parsing the trainer's
    stdout is yielded. Because it is a generator, a plain ``return "<msg>"`` would
    never reach the caller, so informational messages are yielded instead.

    Args:
        config_path: training config handed to ``TrainingState``.
        verbose: stream raw trainer output until training has started.
        keep_x_past_checkpoints: how many old checkpoints to keep on each save.
        progress: Gradio progress tracker (may be None).
    """
    global training_state
    if training_state and training_state.process:
        yield "Training already in progress"
        return

    # ensure we have the dvae.pth
    if args.tts_backend == "tortoise":
        get_model_path('dvae.pth')

    # unclear whether this is still needed; it used to be required even though
    # training runs in a separate process
    torch.multiprocessing.freeze_support()
    unload_tts()
    unload_whisper()
    unload_voicefixer()

    training_state = TrainingState(config_path=config_path, keep_x_past_checkpoints=keep_x_past_checkpoints)
    # keep a local handle: stop_training() may reset the global to None mid-loop
    state = training_state

    for line in iter(state.process.stdout.readline, ""):
        if state.killed:
            return

        result, percent, message = state.parse(line=line, verbose=verbose, keep_x_past_checkpoints=keep_x_past_checkpoints, progress=progress)
        print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
        if result:
            yield result

        if progress is not None and message:
            progress(percent, message)

    # the trainer's stdout hit EOF: reap the process unless stop_training() already did
    if training_state:
        training_state.process.stdout.close()
        training_state.process.wait()
        training_state = None
2023-02-18 02:07:22 +00:00
2023-04-26 04:48:09 +00:00
def update_training_dataplot(x_min=None, x_max=None, y_min=None, y_max=None, config_path=None):
    """Build Gradio line-plot updates for the loss, learning-rate and grad-norm series.

    Uses the running training's state when there is one. Otherwise, when
    ``config_path`` is given, a temporary state is loaded from that config's logs
    and discarded afterwards. A live training state is never modified or cleared.

    Args:
        x_min, x_max: x-axis limits; a falsy ``x_max`` defaults to the state's epochs.
        y_min, y_max: loss-plot y-axis limits; a falsy ``y_max`` disables them.
        config_path: config to read statistics from when no training is running.

    Returns:
        tuple: ``(losses, lrs, grad_norms)``, each a ``gr.LinePlot.update(...)`` or
        None when that series has no data (or no state is available).
    """
    losses = None
    lrs = None
    grad_norms = None

    x_lim = [x_min, x_max]
    y_lim = [y_min, y_max]

    state = training_state
    if not state and config_path:
        # read-only snapshot: kept local so the global training_state is untouched
        state = TrainingState(config_path=config_path, start=False)
        state.load_statistics()

    if not state:
        return (losses, lrs, grad_norms)

    if not x_lim[-1]:
        x_lim[-1] = state.epochs
    if not y_lim[-1]:
        y_lim = None

    if len(state.statistics['loss']) > 0:
        losses = gr.LinePlot.update(
            value=pd.DataFrame(state.statistics['loss']),
            x_lim=x_lim, y_lim=y_lim,
            x="epoch", y="value",  # x="it",
            title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'],
            width=500, height=350
        )
    if len(state.statistics['lr']) > 0:
        lrs = gr.LinePlot.update(
            value=pd.DataFrame(state.statistics['lr']),
            x_lim=x_lim,
            x="epoch", y="value",  # x="it",
            title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'],
            width=500, height=350
        )
    if len(state.statistics['grad_norm']) > 0:
        grad_norms = gr.LinePlot.update(
            value=pd.DataFrame(state.statistics['grad_norm']),
            x_lim=x_lim,
            x="epoch", y="value",  # x="it",
            title="Gradient Normals", color="type", tooltip=['epoch', 'it', 'value', 'type'],
            width=500, height=350
        )

    return (losses, lrs, grad_norms)
2023-02-28 01:01:50 +00:00
2023-03-04 17:37:08 +00:00
def reconnect_training(verbose=False, progress=gr.Progress(track_tqdm=True)):
    """Re-attach the UI to a training run that is already in progress.

    Generator: yields status strings as the trainer's stdout is parsed, and stops
    once the state is flagged as killed (by ``stop_training`` or on finish).
    """
    # keep a local handle: stop_training() may reset the global to None mid-loop
    state = training_state
    if not state or not state.process:
        # a bare `return "..."` inside a generator never reaches the caller
        yield "Training not in progress"
        return

    for line in iter(state.process.stdout.readline, ""):
        if state.killed:
            return

        result, percent, message = state.parse(line=line, verbose=verbose, progress=progress)
        print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
        if result:
            yield result

        if progress is not None and message:
            progress(percent, message)
2023-02-18 02:07:22 +00:00
def stop_training():
    """Stop the running training process (and, for tortoise, stray ``./src/train.py`` workers).

    Returns:
        str: a status message for the UI, including the process return code.
    """
    global training_state
    if training_state is None:
        return "No training in progress"

    print("Killing training process...")
    training_state.killed = True

    children = []
    if args.tts_backend == "tortoise":
        # best-effort scan; wrapped in try in case it fails outside of Linux
        try:
            children = [
                p.info for p in psutil.process_iter(attrs=['pid', 'name', 'cmdline'])
                # cmdline is None for processes we cannot inspect; skip them
                # instead of letting one abort the whole scan
                if './src/train.py' in (p.info['cmdline'] or [])
            ]
        except Exception as e:
            print(f"Failed to enumerate training subprocesses: {e}")

        training_state.process.stdout.close()
        training_state.process.terminate()
        training_state.process.kill()
    elif args.tts_backend == "vall-e":
        print(training_state.process.communicate(input='quit')[0])

    return_code = training_state.process.wait()

    # SIGKILL does not exist on Windows; SIGTERM there maps to TerminateProcess
    kill_signal = getattr(signal, "SIGKILL", signal.SIGTERM)
    for p in children:
        try:
            os.kill(p['pid'], kill_signal)
        except OSError:
            # the worker already exited (or is not ours to kill); nothing to do
            pass

    training_state = None
    print("Killed training process.")
    return f"Training cancelled: {return_code}"
2023-02-17 19:06:05 +00:00
2023-02-21 21:50:05 +00:00
def get_halfp_model_path():
    """Return the path of the half-precision copy of the autoregressive model.

    ``_half`` is inserted before the file extension only, so directories that
    happen to contain ``.pth`` are left intact and the result never equals the
    source path (which ``convert_to_halfp`` would otherwise overwrite).
    """
    autoregressive_model_path = get_model_path('autoregressive.pth')
    root, ext = os.path.splitext(autoregressive_model_path)
    return f"{root}_half{ext}"
2023-02-21 17:35:30 +00:00
def convert_to_halfp():
    """Save a half-precision copy of the autoregressive model's state dict.

    Only floating-point tensors are cast to fp16; integer/bool tensors and
    non-tensor entries are kept as-is. The output path comes from
    ``get_halfp_model_path()``.
    """
    autoregressive_model_path = get_model_path('autoregressive.pth')
    print(f'Converting model to half precision: {autoregressive_model_path}')

    # load onto the CPU so conversion works regardless of where the checkpoint was saved
    model = torch.load(autoregressive_model_path, map_location="cpu")
    for k in model:
        value = model[k]
        if torch.is_tensor(value) and torch.is_floating_point(value):
            model[k] = value.half()

    outfile = get_halfp_model_path()
    torch.save(model, outfile)
    print(f'Converted model to half precision: {outfile}')
2023-02-21 17:35:30 +00:00
2023-03-22 20:26:28 +00:00
# Collapses short segments into the previous segment.
def whisper_sanitize(results, min_duration=None):
    """Merge too-short transcription segments into the segment before them.

    Args:
        results: a transcription dict with a ``'segments'`` list whose items carry
            ``'start'``, ``'end'`` and ``'text'``.
        min_duration: segments shorter than this many seconds are folded into the
            previous one; defaults to ``MIN_TRAINING_DURATION``.

    Returns:
        dict: a deep copy of ``results`` whose segments are merged and renumbered
        (``'id'`` = 0, 1, ...). ``results`` itself is left unmodified.
    """
    if min_duration is None:
        min_duration = MIN_TRAINING_DURATION

    # JSON round-trip = deep copy; only the copy's segment dicts are mutated below
    sanitized = json.loads(json.dumps(results))
    segments = sanitized['segments']
    sanitized['segments'] = []
    for segment in segments:
        length = segment['end'] - segment['start']
        if length >= min_duration or len(sanitized['segments']) == 0:
            sanitized['segments'].append(segment)
            continue

        last_segment = sanitized['segments'][-1]
        # the previous segment already covers this one's time span
        if last_segment['end'] >= segment['end']:
            continue

        last_segment['text'] += segment['text']
        last_segment['end'] = segment['end']

    for i, segment in enumerate(sanitized['segments']):
        segment['id'] = i

    return sanitized
2023-02-27 19:20:06 +00:00
def whisper_transcribe(file, language=None):
    """Transcribe an audio file with the Whisper backend selected by ``args.whisper_backend``.

    Loads the Whisper model via ``load_whisper_model`` first if none is loaded.

    Args:
        file: path of the audio file to transcribe.
        language: language hint. For "openai/whisper" a falsy value is normalized to
            None; it is forwarded to whisperx only on the batched-VAD path and is
            not passed to whispercpp at all.

    Returns:
        For "openai/whisper": whatever ``whisper_model.transcribe`` returns.
        For "lightmare/whispercpp": ``{'text': str, 'segments': [...]}`` where each
        segment has ``'start'``, ``'end'``, ``'text'`` and a sequential ``'id'``.
        For "m-bain/whisperx": the transcription result with ``'segments'``
        replaced by the aligned (optionally speaker-assigned) segments, each given
        a sequential ``'id'``, and ``'text'`` rebuilt by joining the stripped
        segment texts.
        Any other backend name falls through and returns None.
    """
    global whisper_model
    global whisper_vad
    global whisper_diarize
    global whisper_align_model

    # shouldn't happen, but it's for safety
    if not whisper_model:
        load_whisper_model(language=language)

    if args.whisper_backend == "openai/whisper":
        if not language:
            language = None

        return whisper_model.transcribe(file, language=language)

    if args.whisper_backend == "lightmare/whispercpp":
        res = whisper_model.transcribe(file)
        segments = whisper_model.extract_text_and_timestamps(res)

        result = {
            'text': [],
            'segments': []
        }
        for segment in segments:
            # each segment is indexed as (start, end, text); the /100 presumably
            # converts centisecond timestamps to seconds -- confirm against whispercpp
            reparsed = {
                'start': segment[0] / 100.0,
                'end': segment[1] / 100.0,
                'text': segment[2],
                'id': len(result['segments'])
            }
            result['text'].append(segment[2])
            result['segments'].append(reparsed)
        result['text'] = " ".join(result['text'])
        return result

    if args.whisper_backend == "m-bain/whisperx":
        import whisperx
        from whisperx.diarize import assign_word_speakers
        device = "cuda" if get_device_name() == "cuda" else "cpu"
        if whisper_vad:
            # omits a considerable amount of the end
            if args.whisper_batchsize > 1:
                result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize, language=language, task="transcribe")
            else:
                result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad)
            # disabled alternative, kept as an inert string literal
            """
            result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad)
            """
        else:
            result = whisper_model.transcribe(file)

        # whisper_align_model is unpacked as an (align model, metadata) pair
        align_model, metadata = whisper_align_model
        result_aligned = whisperx.align(result["segments"], align_model, metadata, file, device)
        if whisper_diarize:
            diarize_segments = whisper_diarize(file)
            # column 0 holds objects exposing .start/.end (presumably pyannote
            # segments from itertracks(yield_label=True) -- confirm)
            diarize_df = pd.DataFrame(diarize_segments.itertracks(yield_label=True))
            diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
            diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)
            # assumes each utterance is single speaker (needs fix)
            result_segments, word_segments = assign_word_speakers(diarize_df, result_aligned["segments"], fill_nearest=True)
            result_aligned["segments"] = result_segments
            result_aligned["word_segments"] = word_segments

        # drop per-word / per-char alignment detail from each segment
        # NOTE(review): `del` raises KeyError if whisperx omits either key -- confirm
        for i in range(len(result_aligned['segments'])):
            del result_aligned['segments'][i]['word-segments']
            del result_aligned['segments'][i]['char-segments']

        result['segments'] = result_aligned['segments']
        result['text'] = []
        for segment in result['segments']:
            # sequential ids, matching the other backends
            segment['id'] = len(result['text'])
            result['text'].append(segment['text'].strip())
        result['text'] = " ".join(result['text'])
        return result
2023-03-13 04:26:00 +00:00
def validate_waveform(waveform, sample_rate, min_only=False, min_duration=None, max_duration=None):
    """Check whether a waveform is usable as a training sample.

    Args:
        waveform: audio tensor whose last dimension is time; both
            ``(channels, frames)`` and 1-D ``(frames,)`` inputs are accepted.
        sample_rate: samples per second of ``waveform``.
        min_only: only apply the minimum-duration check.
        min_duration: minimum length in seconds (default ``MIN_TRAINING_DURATION``).
        max_duration: maximum length in seconds (default ``MAX_TRAINING_DURATION``).

    Returns:
        str | None: the reason the waveform is rejected, or None when it is valid.
    """
    # heuristic: a waveform without any negative sample is treated as empty
    if not torch.any(waveform < 0):
        return "Waveform is empty"

    if min_duration is None:
        min_duration = MIN_TRAINING_DURATION

    num_frames = waveform.shape[-1]
    duration = num_frames / sample_rate

    if duration < min_duration:
        return "Duration too short ({:.3f}s < {:.3f}s)".format(duration, min_duration)

    if not min_only:
        if max_duration is None:
            max_duration = MAX_TRAINING_DURATION
        if duration > max_duration:
            return "Duration too long ({:.3f}s < {:.3f}s)".format(max_duration, duration)

    return None
2023-03-11 17:27:01 +00:00
2023-03-11 21:17:11 +00:00
def transcribe_dataset(voice, language=None, skip_existings=False, progress=None):
    """Transcribe a voice's audio files with Whisper into ``./training/<voice>/whisper.json``.

    Each transcribed file is also resampled to the backend's target rate, reduced to
    its first channel and written to ``./training/<voice>/audio/``. ``whisper.json``
    is rewritten after every file so an interrupted run can resume with
    ``skip_existings``. Afterwards short segments are merged via ``whisper_sanitize``;
    if anything changed, the previous file is kept as ``whisper.unsanitized.json``.

    Args:
        voice: voice folder name.
        language: Whisper language hint.
        skip_existings: skip files already present in ``whisper.json``.
        progress: unused; kept for interface compatibility.

    Returns:
        str: a summary message.
    """
    unload_tts()

    global whisper_model
    if whisper_model is None:
        load_whisper_model(language=language)

    results = {}

    files = get_voice(voice, load_latents=False)
    indir = f'./training/{voice}/'
    infile = f'{indir}/whisper.json'
    os.makedirs(f'{indir}/audio/', exist_ok=True)

    TARGET_SAMPLE_RATE = 22050
    if args.tts_backend != "tortoise":
        TARGET_SAMPLE_RATE = 24000
    if tts:
        TARGET_SAMPLE_RATE = tts.input_sample_rate

    if os.path.exists(infile):
        with open(infile, 'r', encoding="utf-8") as f:
            results = json.load(f)

    for file in tqdm(files, desc="Iterating through voice files"):
        basename = os.path.basename(file)

        if basename in results and skip_existings:
            print(f"Skipping already parsed file: {basename}")
            continue

        try:
            result = whisper_transcribe(file, language=language)
        except Exception as e:
            print("Failed to transcribe:", file, e)
            continue

        results[basename] = result
        waveform, sample_rate = torchaudio.load(file)
        # resample to the input rate, since it'll get resampled for training anyways
        # this should also "help" increase throughput a bit when filling the dataloaders
        waveform, sample_rate = resample(waveform, sample_rate, TARGET_SAMPLE_RATE)
        # keep only the first channel of any multi-channel input
        if waveform.shape[0] > 1:
            waveform = waveform[:1]
        torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate, encoding="PCM_S", bits_per_sample=16)

        with open(infile, 'w', encoding="utf-8") as f:
            f.write(json.dumps(results, indent='\t'))

        do_gc()

    modified = False
    for basename in results:
        try:
            sanitized = whisper_sanitize(results[basename])
            if len(sanitized['segments']) > 0 and len(sanitized['segments']) != len(results[basename]['segments']):
                results[basename] = sanitized
                modified = True
                print("Segments sanizited:", basename)
        except Exception as e:
            print("Failed to sanitize:", basename, e)

    if modified:
        # os.replace (unlike os.rename) also overwrites an existing backup on Windows
        os.replace(infile, infile.replace(".json", ".unsanitized.json"))
        with open(infile, 'w', encoding="utf-8") as f:
            f.write(json.dumps(results, indent='\t'))

    return f"Processed dataset to: {indir}"
2023-03-07 05:43:26 +00:00
2023-03-13 04:26:00 +00:00
def slice_waveform(waveform, sample_rate, start, end, trim):
    """Cut the ``[start, end)`` seconds span out of a ``(channels, frames)`` waveform.

    Args:
        waveform: audio tensor indexed as ``waveform[:, frames]``.
        sample_rate: samples per second.
        start, end: span boundaries in seconds; clamped to the waveform.
        trim: when True and the slice is valid, run torchaudio's VAD on it (per the
            torchaudio docs this trims leading silence only).

    Returns:
        tuple: ``(sliced, error)`` where ``error`` is the ``validate_waveform``
        message (minimum-duration check only) or None.
    """
    num_frames = waveform.shape[-1]
    start = max(int(start * sample_rate), 0)
    # slicing is end-exclusive: clamp to num_frames (not num_frames - 1) so the
    # final sample is kept
    end = min(int(end * sample_rate), num_frames)
    sliced = waveform[:, start:end]
    error = validate_waveform(sliced, sample_rate, min_only=True)
    if trim and not error:
        sliced = torchaudio.functional.vad(sliced, sample_rate)
    return sliced, error
2023-03-16 14:19:56 +00:00
def slice_dataset(voice, trim_silence=True, start_offset=0, end_offset=0, results=None, progress=gr.Progress()):
    """Cut each transcribed file into per-segment clips under ``./training/<voice>/audio/``.

    Source audio is looked up in ``./voices/<voice>/`` first, then
    ``./training/<voice>/``. Clips are mono (first channel), resampled to the
    backend's target rate and saved as 16-bit PCM named ``<file>_<id>.wav``.

    Args:
        voice: voice folder name.
        trim_silence: run VAD trimming on each clip.
        start_offset, end_offset: seconds added to every segment's start / end.
        results: transcription dict; loaded from ``whisper.json`` when None.
        progress: unused; kept for interface compatibility.

    Returns:
        str: newline-joined log of skipped files/segments and a final summary.
    """
    indir = f'./training/{voice}/'
    infile = f'{indir}/whisper.json'
    messages = []

    if not os.path.exists(infile):
        message = f"Missing dataset: {infile}"
        print(message)
        return message

    if results is None:
        with open(infile, 'r', encoding="utf-8") as f:
            results = json.load(f)

    TARGET_SAMPLE_RATE = 22050
    if args.tts_backend != "tortoise":
        TARGET_SAMPLE_RATE = 24000
    if tts:
        TARGET_SAMPLE_RATE = tts.input_sample_rate

    # transcribe_dataset normally creates this, but don't rely on it
    os.makedirs(f'{indir}/audio/', exist_ok=True)

    files = 0
    segments = 0
    for filename in results:
        path = f'./voices/{voice}/{filename}'
        if not os.path.exists(path):
            path = f'./training/{voice}/{filename}'

        if not os.path.exists(path):
            message = f"Missing source audio: {filename}"
            print(message)
            messages.append(message)
            continue
        files += 1

        result = results[filename]
        waveform, sample_rate = torchaudio.load(path)
        # reduce to the first channel *before* slicing so every saved clip is mono
        if waveform.shape[0] > 1:
            waveform = waveform[:1]

        for segment in result['segments']:
            file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")

            sliced, error = slice_waveform(waveform, sample_rate, segment['start'] + start_offset, segment['end'] + end_offset, trim_silence)
            if error:
                message = f"{error}, skipping... {file}"
                print(message)
                messages.append(message)
                continue

            sliced, _ = resample(sliced, sample_rate, TARGET_SAMPLE_RATE)
            torchaudio.save(f"{indir}/audio/{file}", sliced, TARGET_SAMPLE_RATE, encoding="PCM_S", bits_per_sample=16)

            segments += 1

    messages.append(f"Sliced segments: {files} => {segments}.")
    return "\n".join(messages)
2023-03-11 21:17:11 +00:00
2023-03-17 01:24:02 +00:00
# Takes an LJSpeech-dataset-formatted .txt file and phonemizes it.
def phonemize_txt_file(path):
    """Phonemize an LJSpeech-style ``audio|text|normalized_text`` metadata file.

    Writes ``<path with .txt -> .phn.txt>`` containing one ``audio|phonemes`` line
    per input line, using the normalized-text column (falling back to the text
    column for two-column files). Blank and malformed lines are skipped.

    Returns:
        str: the written content (lines joined with ``"\\n"``).
    """
    with open(path, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    reparsed = []
    # truncate once and write incrementally: partial progress survives a crash,
    # and output from a previous run is never appended to
    with open(path.replace(".txt", ".phn.txt"), 'w', encoding='utf-8') as f:
        for line in tqdm(lines, desc='Phonemizing...'):
            line = line.strip()
            if not line:
                continue
            split = line.split("|")
            if len(split) < 2:
                print(f"Skipping malformed line: {line}")
                continue
            audio = split[0]
            text = split[2] if len(split) > 2 else split[1]
            phonemes = phonemizer(text)
            entry = f'{audio}|{phonemes}'
            f.write(entry if not reparsed else f'\n{entry}')
            reparsed.append(entry)

    return "\n".join(reparsed)
2023-03-17 01:24:02 +00:00
# takes an LJSpeech-dataset-formatted .txt (and phonemized .phn.txt from the above) and creates a JSON that should slot in as whisper.json
def create_dataset_json(path):
    """Convert a pipe-separated dataset ``.txt`` (plus optional ``.phn.txt``) to JSON.

    The output maps each audio name (column 0) to ``{'text': column 1}``; if a
    sibling ``.phn.txt`` exists, its column 1 is stored as ``'phonemes'`` for audio
    names that also appear in the ``.txt``. The result is written to
    ``<path with .txt -> .json>``, tab-indented.

    Blank lines and lines without a ``|`` separator are ignored.

    Args:
        path: path to the source ``.txt`` file.
    """
    def _rows(raw_lines):
        # yield (audio, text) pairs, skipping blank / malformed lines
        for raw in raw_lines:
            if not raw.strip():
                continue
            split = raw.split("|")
            if len(split) < 2:
                continue
            yield split[0], split[1].strip()

    with open(path, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    # the phonemized file is optional; previously a missing file left this as None
    # and the loop below raised TypeError
    phonemes = []
    phn_path = path.replace(".txt", ".phn.txt")
    if os.path.exists(phn_path):
        with open(phn_path, 'r', encoding='utf-8') as f:
            phonemes = f.readlines()

    data = {}
    for audio, text in _rows(lines):
        data[audio] = {'text': text}

    for audio, text in _rows(phonemes):
        # ignore phoneme rows with no matching transcription instead of raising KeyError
        if audio in data:
            data[audio]['phonemes'] = text

    with open(path.replace(".txt", ".json"), 'w', encoding='utf-8') as f:
        f.write(json.dumps(data, indent="\t"))
2023-04-28 15:31:45 +00:00
# phonemizer backend instances keyed by f'{language}_{backend}', reused across calls
cached_backends = {}


def _get_phonemizer_backend(language="en-us", backend="espeak"):
    """Return a (cached) ``phonemizer`` backend instance for ``language``/``backend``.

    ``espeak`` keeps punctuation and stress marks, ``espeak-mbrola`` is built with
    defaults, and any other backend keeps punctuation.
    """
    from phonemizer.backend import BACKENDS

    key = f'{language}_{backend}'
    if key in cached_backends:
        return cached_backends[key]

    if backend == 'espeak':
        instance = BACKENDS[backend](language, preserve_punctuation=True, with_stress=True)
    elif backend == 'espeak-mbrola':
        instance = BACKENDS[backend](language)
    else:
        instance = BACKENDS[backend](language, preserve_punctuation=True)

    cached_backends[key] = instance
    return instance


def phonemizer(text, language="en-us"):
    """Phonemize a single string using the backend selected by ``args.phonemizer_backend``.

    ``"en"`` is mapped to ``"en-us"``.

    Args:
        text: the text to phonemize.
        language: phonemizer language code.

    Returns:
        The phonemized string. (Previously a one-element list was returned, which
        callers then formatted as ``"['...']"``; the old condition ``len(tokens) == 0``
        was inverted and would have indexed an empty list.) If the backend returns
        a number of results other than one, that list is returned unchanged.
    """
    from phonemizer import phonemize

    if language == "en":
        language = "en-us"

    backend = _get_phonemizer_backend(language=language, backend=args.phonemizer_backend)
    if backend is not None:
        tokens = backend.phonemize([text], strip=True)
    else:
        tokens = phonemize([text], language=language, strip=True, preserve_punctuation=True, with_stress=True)

    # one input string -> one output string
    return tokens[0] if len(tokens) == 1 else tokens
2023-03-17 05:33:49 +00:00
def should_phonemize():
    """Report whether dataset text should be converted to phonemes.

    True only when the configured tokenizer file name ends in ``ipa.json`` and the
    ``phonemizer`` package can actually be imported.
    """
    tokenizer = args.tokenizer_json
    if tokenizer is None or not tokenizer.endswith("ipa.json"):
        return False

    # availability probe: any failure importing phonemizer disables phonemization
    try:
        from phonemizer import phonemize  # noqa: F401
    except Exception:
        return False
    return True
2023-03-17 05:33:49 +00:00
2023-03-16 14:41:40 +00:00
def prepare_dataset(voice, use_segments=False, text_length=0, audio_length=0, progress=gr.Progress()):
    """Write ``train.txt`` / ``validation.txt`` for ``voice`` from its ``whisper.json``.

    Each transcribed file is used whole unless segments are requested via
    ``use_segments``. Segmenting is also forced when the text exceeds 200 characters
    or the audio reaches ``MAX_TRAINING_DURATION``. Missing segment audio is sliced
    on demand through ``slice_dataset``. Segments whose duration is not strictly
    between ``MIN_TRAINING_DURATION`` and ``MAX_TRAINING_DURATION`` are dropped.

    Entries that are shorter than ``text_length`` characters, or (when
    ``audio_length > 0``) shorter than ``audio_length`` seconds, go to the validation
    list instead of training. Both thresholds are forced to 0 for non-tortoise
    backends. For the ``vall-e`` backend, training entries are also quantized
    (``.qnt.pt``) and phonemized (``.phn.txt``) into ``<voice>/valle/``, unless those
    files already exist.

    Args:
        voice: directory name under ``./training/``.
        use_segments: always use whisper segments rather than whole files.
        text_length: minimum transcription length (characters) for the training split.
        audio_length: minimum duration (seconds) for the training split; 0 disables it.
        progress: gradio progress tracker (not used by this function).

    Returns:
        A newline-joined log of what happened. If ``whisper.json`` is missing, an
        error message is returned instead.
    """
    indir = f'./training/{voice}/'
    infile = f'{indir}/whisper.json'
    if not os.path.exists(infile):
        message = f"Missing dataset: {infile}"
        print(message)
        return message

    with open(infile, 'r', encoding="utf-8") as f:
        results = json.load(f)

    errored = 0
    messages = []
    normalize = True
    phonemize = should_phonemize()

    lines = {'training': [], 'validation': []}
    segments = {}

    # length-based culling into the validation split only applies to tortoise
    if args.tts_backend != "tortoise":
        text_length = 0
        audio_length = 0

    for filename in tqdm(results, desc="Parsing results"):
        use_segment = use_segments

        result = results[filename]
        lang = result['language']
        language = LANGUAGES[lang] if lang in LANGUAGES else lang
        normalizer = EnglishTextNormalizer() if language and language == "english" else BasicTextNormalizer()

        # check if unsegmented text exceeds 200 characters
        if not use_segment and len(result['text']) > 200:
            message = f"Text length too long (200 < {len(result['text'])}), using segments: {filename}"
            print(message)
            messages.append(message)
            use_segment = True

        # check if unsegmented audio reaches MAX_TRAINING_DURATION
        if not use_segment:
            path = f'{indir}/audio/{filename}'
            if not os.path.exists(path):
                messages.append(f"Missing source audio: {filename}")
                errored += 1
                continue

            metadata = torchaudio.info(path)
            duration = metadata.num_frames / metadata.sample_rate
            if duration >= MAX_TRAINING_DURATION:
                message = f"Audio too large, using segments: {filename}"
                print(message)
                messages.append(message)
                use_segment = True

        # segmenting was forced above: slice the file if any usable segment is missing on disk
        if use_segment and not use_segments:
            exists = True
            for segment in result['segments']:
                duration = segment['end'] - segment['start']
                if duration <= MIN_TRAINING_DURATION or MAX_TRAINING_DURATION <= duration:
                    continue
                path = f'{indir}/audio/' + filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")
                if not os.path.exists(path):
                    exists = False
                    break

            if not exists:
                print(f"Audio not segmented, segmenting: {filename}")
                message = slice_dataset(voice, results={filename: result})
                print(message)
                messages.extend(message.split("\n"))

        if not use_segment:
            segments[filename] = {
                'text': result['text'],
                'lang': lang,
                'language': language,
                'normalizer': normalizer,
                'phonemes': result.get('phonemes'),
            }
        else:
            for segment in result['segments']:
                duration = segment['end'] - segment['start']
                if duration <= MIN_TRAINING_DURATION or MAX_TRAINING_DURATION <= duration:
                    continue

                segments[filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")] = {
                    'text': segment['text'],
                    'lang': lang,
                    'language': language,
                    'normalizer': normalizer,
                    'phonemes': segment.get('phonemes'),
                }

    # vall-e work is batched and run after the scan
    quantize_jobs = []   # (qnt_file, waveform, sample_rate)
    phonemize_jobs = []  # (phn_file, normalized_text)

    for file in tqdm(segments, desc="Parsing segments"):
        result = segments[file]
        path = f'{indir}/audio/{file}'
        if not os.path.exists(path):
            message = f"Missing segment, skipping... {file}"
            print(message)
            messages.append(message)
            errored += 1
            continue

        text = result['text']
        if len(text) > 200:
            message = f"Text length too long (200 < {len(text)}), skipping... {file}"
            print(message)
            messages.append(message)
            errored += 1
            continue

        waveform, sample_rate = torchaudio.load(path)
        _, num_frames = waveform.shape
        duration = num_frames / sample_rate

        error = validate_waveform(waveform, sample_rate)
        if error:
            message = f"{error}, skipping... {file}"
            print(message)
            messages.append(message)
            errored += 1
            continue

        # phonemize only after the cheap checks pass; it is the slow step and its
        # result would be thrown away for skipped entries
        phonemes = result['phonemes']
        if phonemize and phonemes is None:
            phonemes = phonemizer(text, language=result['lang'])

        culled = len(text) < text_length
        if not culled and audio_length > 0:
            culled = duration < audio_length

        line = f'audio/{file}|{phonemes if phonemize and phonemes else text}'
        lines['training' if not culled else 'validation'].append(line)

        if culled or args.tts_backend != "vall-e":
            continue

        os.makedirs(f'{indir}/valle/', exist_ok=True)

        qnt_file = f'{indir}/valle/{file.replace(".wav", ".qnt.pt")}'
        if not os.path.exists(qnt_file):
            quantize_jobs.append((qnt_file, waveform, sample_rate))

        phn_file = f'{indir}/valle/{file.replace(".wav", ".phn.txt")}'
        if not os.path.exists(phn_file):
            normalized = result['normalizer'](text) if normalize else text
            phonemize_jobs.append((phn_file, normalized))

    for qnt_file, waveform, sample_rate in tqdm(quantize_jobs, desc="Quantizing"):
        quantized = valle_quantize(waveform, sample_rate).cpu()
        torch.save(quantized, qnt_file)
        print("Quantized:", qnt_file)

    for phn_file, normalized in tqdm(phonemize_jobs, desc="Phonemizing"):
        # best-effort: a failure is logged and the remaining files are still processed
        try:
            # build the content before opening, so a failure cannot leave an empty
            # .phn.txt behind (which the exists() check above would never retry)
            phonemized = " ".join(valle_phonemize(normalized))
            with open(phn_file, 'w', encoding='utf-8') as f:
                f.write(phonemized)
            print("Phonemized:", phn_file)
        except Exception as e:
            message = f"Failed to phonemize: {phn_file}: {normalized} ({e})"
            messages.append(message)
            print(message)

    training_joined = "\n".join(lines['training'])
    validation_joined = "\n".join(lines['validation'])

    with open(f'{indir}/train.txt', 'w', encoding="utf-8") as f:
        f.write(training_joined)

    with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f:
        f.write(validation_joined)

    messages.append(f"Prepared {len(lines['training'])} lines (validation: {len(lines['validation'])}, culled: {errored}).\n{training_joined}\n\n{validation_joined}")
    return "\n".join(messages)
2023-03-08 02:58:00 +00:00
2023-02-19 20:22:03 +00:00
def calc_iterations(epochs, lines, batch_size):
    """Total optimizer steps for ``epochs`` passes over ``lines`` samples.

    A partial final batch counts as a full step, and the product is rounded up
    and returned as an int (so fractional ``epochs`` are supported).
    """
    batches_per_epoch = math.ceil(lines / batch_size)
    total_steps = epochs * batches_per_epoch
    return int(math.ceil(total_steps))
2023-02-19 20:22:03 +00:00
2023-03-09 14:17:01 +00:00
def schedule_learning_rate(iterations, schedule=LEARNING_RATE_SCHEDULE):
    """Scale every entry of ``schedule`` by ``iterations``, truncating each to int.

    Callers in this file pass an iterations-per-epoch count, turning an
    epoch-based milestone list into iteration milestones.
    """
    return list(map(lambda milestone: int(iterations * milestone), schedule))
2023-02-19 20:22:03 +00:00
2023-03-09 00:26:47 +00:00
def _device_batch_size_cap(vram):
    """Suggested cap on batch_size / gradient_accumulation_size for ``vram`` GB of total VRAM."""
    DEVICE_BATCH_SIZE_MAP = [
        (70, 128),  # based on an A100-80G, I can safely get a ratio of 4096:32 = 128
        (32, 64),   # based on my two 6800XTs, I can only really safely get a ratio of 128:2 = 64
        (16, 8),    # based on an A4000, I can do a ratio of 512:64 = 8:1
        (8, 4),     # interpolated
        (6, 2),     # based on my 2060, it only really lets me have a batch ratio of 2:1
    ]
    for threshold, cap in DEVICE_BATCH_SIZE_MAP:
        if vram > (threshold - 1):
            return cap
    return 1


def _largest_divisor_at_most(value, limit):
    """Return the largest integer in [1, limit] that evenly divides ``value``."""
    divisor = max(1, limit)
    while value % divisor != 0:
        divisor -= 1
    return divisor


def optimize_training_settings(**kwargs):
    """Validate and clamp user-supplied training settings against the dataset and hardware.

    Reads ``./training/<voice>/train.txt`` to count dataset lines. The GPU count is
    clamped to the detected devices. ``batch_size`` is clamped to the line count
    and to a multiple of the GPU count. ``gradient_accumulation_size`` is made a
    divisor of ``batch_size`` and is raised if the per-step ratio exceeds a VRAM
    heuristic. ``save_rate`` and ``validation_rate`` (in epochs) are capped at
    ``epochs``. A missing ``resume_state`` is dropped, and the half-precision model
    is converted if needed.

    Returns:
        ``(settings, messages)``: the adjusted copy of ``kwargs`` and human-readable
        notes for every adjustment made.

    Raises:
        Exception: if the training dataset is empty.
    """
    messages = []
    settings = {}
    settings.update(kwargs)

    dataset_path = f"./training/{settings['voice']}/train.txt"
    with open(dataset_path, 'r', encoding="utf-8") as f:
        lines = len(f.readlines())

    if lines == 0:
        raise Exception("Empty dataset.")

    # clamp the GPU count first: the batch-size adjustment below divides by it
    # (previously it ran on the unclamped value, which could be 0 or exceed the devices)
    device_count = get_device_count()
    if settings['gpus'] > device_count:
        settings['gpus'] = device_count
        messages.append(f"GPU count exceeds defacto GPU count, clamping to: {settings['gpus']}")
    if settings['gpus'] <= 1:
        settings['gpus'] = 1
    else:
        messages.append("! EXPERIMENTAL ! Multi-GPU training is extremely particular, expect issues.")

    if settings['batch_size'] > lines:
        settings['batch_size'] = lines
        messages.append(f"Batch size is larger than your dataset, clamping batch size to: {settings['batch_size']}")

    if settings['batch_size'] % settings['gpus'] != 0:
        settings['batch_size'] -= settings['batch_size'] % settings['gpus']
        if settings['batch_size'] == 0:
            settings['batch_size'] = 1
        messages.append(f"Batch size not neatly divisible by GPU count, adjusting batch size to: {settings['batch_size']}")

    if settings['gradient_accumulation_size'] == 0:
        settings['gradient_accumulation_size'] = 1

    if settings['batch_size'] / settings['gradient_accumulation_size'] < 2:
        settings['gradient_accumulation_size'] = int(settings['batch_size'] / 2)
        if settings['gradient_accumulation_size'] == 0:
            settings['gradient_accumulation_size'] = 1

        messages.append(f"Gradient accumulation size is too large for a given batch size, clamping gradient accumulation size to: {settings['gradient_accumulation_size']}")
    elif settings['batch_size'] % settings['gradient_accumulation_size'] != 0:
        # step down to an actual divisor; subtracting the remainder once (the old
        # approach) does not guarantee one, e.g. 25 % 11 -> 8, which doesn't divide 25
        settings['gradient_accumulation_size'] = _largest_divisor_at_most(settings['batch_size'], settings['gradient_accumulation_size'])

        messages.append(f"Batch size is not evenly divisible by the gradient accumulation size, adjusting gradient accumulation size to: {settings['gradient_accumulation_size']}")

    # assuming you have equal GPUs
    vram = get_device_vram() * settings['gpus']
    batch_ratio = int(settings['batch_size'] / settings['gradient_accumulation_size'])
    batch_cap = _device_batch_size_cap(vram)
    if batch_ratio > batch_cap:
        settings['gradient_accumulation_size'] = int(settings['batch_size'] / batch_cap)
        messages.append(f"Batch ratio ({batch_ratio}) is expected to exceed your VRAM capacity ({vram:.3f}GB, suggested {batch_cap} batch size cap), adjusting gradient accumulation size to: {settings['gradient_accumulation_size']}")

    iterations = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])

    # rates are expressed in epochs, so they cannot exceed the epoch count
    if settings['epochs'] < settings['save_rate']:
        settings['save_rate'] = settings['epochs']
        messages.append(f"Save rate is too large for the given epoch count, clamping save rate to: {settings['save_rate']}")

    if settings['epochs'] < settings['validation_rate']:
        settings['validation_rate'] = settings['epochs']
        messages.append(f"Validation rate is too large for the given epoch count, clamping validation rate to: {settings['validation_rate']}")

    if settings['resume_state'] and not os.path.exists(settings['resume_state']):
        settings['resume_state'] = None
        messages.append("Resume path specified, but does not exist. Disabling...")

    if settings['bitsandbytes']:
        messages.append("! EXPERIMENTAL ! BitsAndBytes requested.")

    if settings['half_p']:
        if settings['bitsandbytes']:
            settings['half_p'] = False
            messages.append("Half Precision requested, but BitsAndBytes is also requested. Due to redundancies, disabling half precision...")
        else:
            messages.append("! EXPERIMENTAL ! Half Precision requested.")
            if not os.path.exists(get_halfp_model_path()):
                convert_to_halfp()

    steps = int(iterations / settings['epochs']) if settings['epochs'] else 0

    messages.append(f"For {settings['epochs']} epochs with {lines} lines in batches of {settings['batch_size']}, iterating for {iterations} steps ({steps} steps per epoch)")

    return settings, messages
2023-02-19 20:22:03 +00:00
2023-03-09 00:26:47 +00:00
def save_training_settings(**kwargs):
    """Derive trainer settings from the UI values and render the backend's training config.

    The settings are first written as-is to ``./training/<voice>/train.json``.
    The function then derives the iteration count, converts the epoch-based
    save/validation rates to iterations, and builds the learning-rate scheme block.
    It renders ``./models/.template.dlas.yaml`` (tortoise) or
    ``./models/.template.valle.yaml`` (vall-e) by substituting ``${key}``
    placeholders with setting values.

    Returns:
        ``(settings, messages)``: the derived settings and human-readable notes.
    """
    messages = []
    settings = {}
    settings.update(kwargs)

    # snapshot of the settings exactly as supplied, before any derived keys are added
    outjson = f'./training/{settings["voice"]}/train.json'
    with open(outjson, 'w', encoding="utf-8") as f:
        f.write(json.dumps(settings, indent='\t'))

    settings['dataset_path'] = f"./training/{settings['voice']}/train.txt"
    settings['validation_path'] = f"./training/{settings['voice']}/validation.txt"

    with open(settings['dataset_path'], 'r', encoding="utf-8") as f:
        lines = len(f.readlines())

    settings['iterations'] = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])

    if not settings['source_model'] or settings['source_model'] == "auto":
        settings['source_model'] = f"./models/tortoise/autoregressive{'_half' if settings['half_p'] else ''}.pth"

    if settings['half_p']:
        if not os.path.exists(get_halfp_model_path()):
            convert_to_halfp()

    messages.append(f"For {settings['epochs']} epochs with {lines} lines, iterating for {settings['iterations']} steps")

    # save/validation rates arrive in epochs; convert them to iteration counts
    iterations_per_epoch = settings['iterations'] / settings['epochs']
    settings['save_rate'] = int(settings['save_rate'] * iterations_per_epoch)
    settings['validation_rate'] = int(settings['validation_rate'] * iterations_per_epoch)

    iterations_per_epoch = int(iterations_per_epoch)
    if settings['save_rate'] < 1:
        settings['save_rate'] = 1
    # NOTE: validation_rate is not clamped to >= 1 here

    settings['validation_batch_size'] = int(settings['batch_size'] / settings['gradient_accumulation_size'])
    if not os.path.exists(settings['validation_path']):
        settings['validation_enabled'] = False
        messages.append("Validation not found, disabling validation...")
    elif settings['validation_batch_size'] == 0:
        settings['validation_enabled'] = False
        messages.append("Validation batch size == 0, disabling validation...")
    else:
        with open(settings['validation_path'], 'r', encoding="utf-8") as f:
            validation_lines = len(f.readlines())

        if validation_lines < settings['validation_batch_size']:
            settings['validation_batch_size'] = validation_lines
            messages.append(f"Batch size exceeds validation dataset size, clamping validation batch size to {validation_lines}")

    settings['tokenizer_json'] = args.tokenizer_json if args.tokenizer_json else get_tokenizer_jsons()[0]

    if settings['gpus'] > get_device_count():
        settings['gpus'] = get_device_count()

    # what an utter mistake this was
    settings['optimizer'] = 'adamw'  # if settings['gpus'] == 1 else 'adamw_zero'

    if 'learning_rate_scheme' not in settings or settings['learning_rate_scheme'] not in LEARNING_RATE_SCHEMES:
        settings['learning_rate_scheme'] = "Multistep"

    settings['learning_rate_scheme'] = LEARNING_RATE_SCHEMES[settings['learning_rate_scheme']]

    # continuation lines carry a two-space indent, assumed to match the indentation of
    # ${learning_rate_scheme} in the YAML template -- TODO confirm against the template
    learning_rate_schema = [f"default_lr_scheme: {settings['learning_rate_scheme']}"]
    if settings['learning_rate_scheme'] == "MultiStepLR":
        if not settings['learning_rate_schedule']:
            settings['learning_rate_schedule'] = LEARNING_RATE_SCHEDULE
        elif isinstance(settings['learning_rate_schedule'], str):
            settings['learning_rate_schedule'] = json.loads(settings['learning_rate_schedule'])

        settings['learning_rate_schedule'] = schedule_learning_rate(iterations_per_epoch, settings['learning_rate_schedule'])

        learning_rate_schema.append(f"  gen_lr_steps: {settings['learning_rate_schedule']}")
        learning_rate_schema.append("  lr_gamma: 0.5")
    elif settings['learning_rate_scheme'] == "CosineAnnealingLR_Restart":
        epochs = settings['epochs']
        restarts = settings['learning_rate_restarts']
        restart_period = int(epochs / restarts)

        if 'learning_rate_warmup' not in settings:
            settings['learning_rate_warmup'] = 0
        if 'learning_rate_min' not in settings:
            settings['learning_rate_min'] = 1e-08
        if 'learning_rate_period' not in settings:
            settings['learning_rate_period'] = [iterations_per_epoch * restart_period for _ in range(epochs)]

        settings['learning_rate_restarts'] = [iterations_per_epoch * (x + 1) * restart_period for x in range(restarts)]  # [52, 104, 156, 208]

        if 'learning_rate_restart_weights' not in settings:
            settings['learning_rate_restart_weights'] = [(restarts - x - 1) / restarts for x in range(restarts)]  # [.75, .5, .25, .125]
            # with a single restart there is no [-2] entry (the old code raised IndexError)
            if restarts > 1:
                settings['learning_rate_restart_weights'][-1] = settings['learning_rate_restart_weights'][-2] * 0.5

        learning_rate_schema.append(f"  T_period: {settings['learning_rate_period']}")
        learning_rate_schema.append(f"  warmup: {settings['learning_rate_warmup']}")
        learning_rate_schema.append(f"  eta_min: !!float {settings['learning_rate_min']}")
        learning_rate_schema.append(f"  restarts: {settings['learning_rate_restarts']}")
        learning_rate_schema.append(f"  restart_weights: {settings['learning_rate_restart_weights']}")
    settings['learning_rate_scheme'] = "\n".join(learning_rate_schema)

    # exactly one of pretrain_model_gpt / resume_state is left active; the other is commented out
    if settings['resume_state']:
        settings['source_model'] = f"# pretrain_model_gpt: '{settings['source_model']}'"
        settings['resume_state'] = f"resume_state: '{settings['resume_state']}'"
    else:
        settings['source_model'] = f"pretrain_model_gpt: '{settings['source_model']}'"
        settings['resume_state'] = f"# resume_state: '{settings['resume_state']}'"

    def use_template(template, out):
        """Render ``template`` to ``out``, replacing ``${key}`` with ``str(settings[key])``."""
        with open(template, 'r', encoding="utf-8") as f:
            content = f.read()

        # i could just load and edit the YAML directly, but this is easier, as I don't need to bother with path traversals
        for k in settings:
            if settings[k] is None:
                continue
            content = content.replace(f"${{{k}}}", str(settings[k]))

        with open(out, 'w', encoding="utf-8") as f:
            f.write(content)

    if args.tts_backend == "tortoise":
        use_template('./models/.template.dlas.yaml', f'./training/{settings["voice"]}/train.yaml')
    elif args.tts_backend == "vall-e":
        settings['model_name'] = "['ar-quarter', 'nar-quarter']"
        use_template('./models/.template.valle.yaml', f'./training/{settings["voice"]}/config.yaml')

    messages.append("Saved training output")
    return settings, messages
2023-02-18 02:07:22 +00:00
def import_voices(files, saveAs=None, progress=None):
    """Import uploaded voice files into ``<voice dir>/<saveAs>/``.

    Each file is passed to ``read_generate_settings``. If it yields embedded latents,
    those bytes are written to ``cond_latents.pth``. Otherwise the file must be a
    ``.wav``, which is copied via torchaudio, optionally after resampling to 44.1 kHz
    and running voicefixer (``args.voice_fixer``). When ``saveAs`` is not given, the
    voice name stored in the first file's embedded settings is used for all files.

    Args:
        files: a single uploaded file object or a list of them (objects exposing
            ``.name`` -- presumably gradio uploads).
        saveAs: destination voice name.
        progress: unused.

    Raises:
        Exception: if no voice name can be determined, or a non-latent file is not a WAV.
    """
    if not isinstance(files, list):
        files = [files]

    for file in tqdm(files, desc="Importing voice files"):
        j, latents = read_generate_settings(file, read_latents=True)

        if j is not None and saveAs is None:
            saveAs = j['voice']
        if saveAs is None or saveAs == "":
            raise Exception("Specify a voice name")

        outdir = f'{get_voice_dir()}/{saveAs}/'
        os.makedirs(outdir, exist_ok=True)

        if latents:
            latents_path = f'{outdir}/cond_latents.pth'
            # log the destination path; the old message printed the raw latent bytes
            print(f"Importing latents to {latents_path}")
            with open(latents_path, 'wb') as f:
                f.write(latents)
            print(f"Imported latents to {latents_path}")
            continue

        filename = file.name
        if filename[-4:] != ".wav":
            raise Exception("Please convert to a WAV first")

        path = f"{outdir}/{os.path.basename(filename)}"
        print(f"Importing voice to {path}")

        waveform, sample_rate = torchaudio.load(filename)

        if args.voice_fixer:
            if not voicefixer:
                load_voicefixer()

            waveform, sample_rate = resample(waveform, sample_rate, 44100)
            torchaudio.save(path, waveform, sample_rate)

            print(f"Running 'voicefixer' on voice sample: {path}")
            voicefixer.restore(
                input=path,
                output=path,
                cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
                #mode=mode,
            )
        else:
            torchaudio.save(path, waveform, sample_rate)

        print(f"Imported voice to {path}")
2023-02-17 03:05:27 +00:00
2023-03-15 00:37:38 +00:00
def relative_paths(dirs):
    """Express each path in ``dirs`` relative to the CWD as a ``./``-prefixed, forward-slash path."""
    converted = []
    for entry in dirs:
        relative = os.path.relpath(entry)
        converted.append('./' + relative.replace("\\", "/"))
    return converted
2023-04-13 21:10:38 +00:00
def get_voice(name, dir=get_voice_dir(), load_latents=True):
    """List the sorted audio files (and, optionally, ``.pth`` latents) of voice ``name``.

    Returns ``None`` when ``<dir>/<name>/`` is not a directory.
    """
    voice_path = f'{dir}/{name}/'
    if not os.path.isdir(voice_path):
        return None

    extensions = ['wav', 'mp3', 'flac']
    if load_latents:
        extensions.append('pth')

    found = []
    for ext in extensions:
        found.extend(glob(f'{voice_path}/*.{ext}'))
    return sorted(found)
2023-02-21 21:50:05 +00:00
def get_voice_list(dir=get_voice_dir(), append_defaults=False):
    """Return the sorted names of the voices found under ``dir``.

    A non-empty subdirectory containing files recognised by ``get_voice`` is a voice.
    A subdirectory without such files is searched one level deeper, yielding
    ``"name/subdir"`` entries. ``"random"`` and ``"microphone"`` are never scanned;
    they are appended after the sorted list when ``append_defaults`` is true.
    ``dir`` is created if missing.
    """
    defaults = ["random", "microphone"]
    os.makedirs(dir, exist_ok=True)

    found = []
    for entry in os.listdir(dir):
        entry_path = f'{dir}/{entry}'
        if entry in defaults or not os.path.isdir(entry_path):
            continue
        if not os.listdir(os.path.join(dir, entry)):
            continue

        if get_voice(entry, dir=dir):
            found.append(entry)
            continue

        # no files directly inside: look for nested voices one level down
        for subdir in os.listdir(entry_path):
            if not os.path.isdir(f'{entry_path}/{subdir}'):
                continue
            if get_voice(f'{entry}/{subdir}', dir=dir):
                found.append(f'{entry}/{subdir}')

    found.sort()
    if append_defaults:
        found = found + defaults

    return found
2023-02-17 03:05:27 +00:00
2023-03-31 03:26:00 +00:00
def get_valle_models(dir="./training/"):
    """List VALL-E ``config.yaml`` files found one level below ``dir``.

    Args:
        dir: directory whose immediate subfolders are scanned.

    Returns:
        list[str]: sorted ``"<dir>/<sub>/config.yaml"`` paths; empty when ``dir``
        does not exist (instead of raising, like ``get_tokenizer_jsons``).
    """
    if not os.path.isdir(dir):
        return []
    return sorted([f'{dir}/{d}/config.yaml' for d in os.listdir(dir) if os.path.exists(f'{dir}/{d}/config.yaml')])
2023-03-15 00:37:38 +00:00
2023-03-29 19:53:23 +00:00
def get_autoregressive_models(dir="./models/finetunes/", prefixed=False, auto=False):
    """List selectable autoregressive (GPT) model checkpoints.

    Sources, in order:
    1. the base ``autoregressive.pth``, plus the half-precision copy if it exists;
    2. ``*.pth`` files in ``dir`` (sorted);
    3. ``./training/<name>/finetune/models/<step>_gpt.pth`` checkpoints, by step.

    Args:
        dir: finetune directory to scan (created if missing).
        prefixed: when True, each entry becomes ``"[<first 8 hash chars>] <path>"``.
        auto: when True, prepend the literal ``"auto"`` entry.

    Returns:
        list[str]: ``./``-relative paths (optionally hash-prefixed).
    """
    os.makedirs(dir, exist_ok=True)

    base = [get_model_path('autoregressive.pth')]
    halfp = get_halfp_model_path()
    if os.path.exists(halfp):
        base.append(halfp)

    additionals = sorted([f'{dir}/{d}' for d in os.listdir(dir) if d[-4:] == ".pth"])

    found = []
    # A fresh checkout may have no training directory at all.
    if os.path.isdir('./training/'):
        for training in os.listdir('./training/'):
            models_dir = f'./training/{training}/finetune/models/'
            if not os.path.isdir(models_dir):
                continue
            # Only "<step>_gpt.pth" names are numeric checkpoints; skip anything else
            # instead of letting int() raise.
            steps = sorted([int(d[:-8]) for d in os.listdir(models_dir) if d[-8:] == "_gpt.pth" and d[:-8].isdecimal()])
            found += [f'{models_dir}{step}_gpt.pth' for step in steps]

    # Relativize before prefixing: relpath() on an already-prefixed string would
    # yield "./[hash] ...", which the "[hash] path" parser cannot strip.
    paths = relative_paths(base + additionals + found)
    if prefixed:
        paths = [f'[{hash_file(path)[:8]}] {path}' for path in paths]
    if auto:
        paths = ["auto"] + paths
    return paths
2023-03-15 00:37:38 +00:00
def get_diffusion_models(dir="./models/finetunes/", prefixed=False):
    """List selectable diffusion decoder models.

    Only the stock ``diffusion_decoder.pth`` is offered. ``dir`` and
    ``prefixed`` are accepted to mirror ``get_autoregressive_models`` but are
    currently unused.

    Returns:
        list[str]: a single ``./``-relative path.
    """
    stock_model = get_model_path('diffusion_decoder.pth')
    return relative_paths([stock_model])
def get_tokenizer_jsons(dir="./models/tokenizers/"):
    """List tokenizer JSON files: the bundled default first, then any in ``dir``.

    Args:
        dir: optional directory of extra ``*.json`` tokenizers; ignored if absent.

    Returns:
        list[str]: ``./``-relative paths, with extra tokenizers sorted.
    """
    extras = []
    if os.path.isdir(dir):
        extras = sorted(f'{dir}/{name}' for name in os.listdir(dir) if name.endswith(".json"))
    return relative_paths(["./modules/tortoise-tts/tortoise/data/tokenizer.json", *extras])
2023-03-18 15:14:22 +00:00
def tokenize_text(text, config=None, stringed=True, skip_specials=False):
    """Tokenize ``text`` and decode it back, for inspecting tokenizer output.

    If a TTS engine is loaded, its tokenizer is used and ``config`` is ignored.
    Otherwise a ``VoiceBpeTokenizer`` is built from ``config``, falling back to
    ``args.tokenizer_json`` and then to the first entry of ``get_tokenizer_jsons()``.

    Args:
        text: input string.
        config: tokenizer JSON path.
        stringed: when True, return a printable two-line summary.
        skip_specials: forwarded to the decoder as ``skip_special_tokens``.

    Returns:
        str | list[str]: ``"<ids>\\n<pieces>"`` when ``stringed``, else the
        decoded text split on single spaces.
    """
    from tortoise.utils.tokenizer import VoiceBpeTokenizer

    if not config:
        config = args.tokenizer_json or get_tokenizer_jsons()[0]

    tokenizer = tts.tokenizer if tts else VoiceBpeTokenizer(config)

    token_ids = tokenizer.encode(text)
    pieces = tokenizer.tokenizer.decode(token_ids, skip_special_tokens=skip_specials).split(" ")

    if not stringed:
        return pieces
    return "\n".join([str(token_ids), str(pieces)])
2023-02-20 00:21:16 +00:00
def get_dataset_list(dir="./training/"):
    """Return the sorted names of subfolders of ``dir`` that contain ``train.txt``."""
    datasets = []
    for name in os.listdir(dir):
        candidate = os.path.join(dir, name)
        if os.path.isdir(candidate) and "train.txt" in os.listdir(candidate):
            datasets.append(name)
    datasets.sort()
    return datasets
2023-02-20 00:21:16 +00:00
def get_training_list(dir="./training/"):
    """List training configs found one level below ``dir``.

    The TorToiSe backend uses ``train.yaml``; any other backend uses ``config.yaml``.

    Args:
        dir: training root to scan. Returned paths are built from this directory.
            Previously they were hard-coded to ``./training/``, which broke
            non-default roots.

    Returns:
        list[str]: sorted ``"<dir>/<name>/<config file>"`` paths. For the default
        ``dir`` these are exactly ``"./training/<name>/<config file>"``.
    """
    config_name = "train.yaml" if args.tts_backend == "tortoise" else "config.yaml"
    # Strip the trailing separator so the default root still yields "./training/<name>/...".
    root = dir.rstrip("/\\") or dir
    return sorted([f'{root}/{d}/{config_name}' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and config_name in os.listdir(os.path.join(dir, d))])
2023-02-20 00:21:16 +00:00
def pad(num, zeroes):
    """Return ``num`` as a string zero-filled to a width of ``zeroes + 1``.

    Uses ``str.zfill``, so a leading sign stays in front of the padding.
    """
    width = zeroes + 1
    text = str(num)
    return text.zfill(width)
2023-02-17 03:05:27 +00:00
def curl(url):
    """Fetch ``url`` and decode its body as JSON.

    A ``User-Agent: Python`` header is sent. Any failure (network error, decode
    error, invalid JSON) is printed and reported as None rather than raised.

    Args:
        url: the URL to request.

    Returns:
        The decoded JSON value, or None on failure.
    """
    try:
        req = urllib.request.Request(url, headers={'User-Agent': 'Python'})
        # The context manager closes the connection even if read() raises;
        # previously close() was skipped on any error.
        with urllib.request.urlopen(req) as conn:
            data = conn.read()
        return json.loads(data.decode())
    except Exception as e:
        print(e)
        return None
2023-03-07 17:04:45 +00:00
def check_for_updates(dir=None):
    """Check whether a git checkout is behind its remote (Gitea API only).

    The local commit and remote URL are parsed from the first line of the
    repo's ``FETCH_HEAD``. The remote commit comes from
    ``https://<host>/api/v1/repos/<owner>/<repo>/branches/``.

    Args:
        dir: a ``.git`` directory. When None, the main repo and the ``dlas`` and
            ``tortoise-tts`` submodules are all checked.

    Returns:
        bool: True if a newer remote commit was found (for ``dir=None``: in any
        of the checked repos), False otherwise or when the check could not run.
    """
    if dir is None:
        # Build a list (not a generator) so every repo is checked and reports
        # its status, then aggregate; previously the result was discarded.
        results = [check_for_updates(git_dir) for git_dir in ("./.git/", "./.git/modules/dlas/", "./.git/modules/tortoise-tts/")]
        return any(results)

    git_dir = dir
    fetch_head = f'{git_dir}/FETCH_HEAD'
    if not os.path.isfile(fetch_head):
        print(f"Cannot check for updates for {dir}: not from a git repo")
        return False

    with open(fetch_head, 'r', encoding="utf-8") as f:
        head = f.read()

    # Without re.MULTILINE, "^" anchors only the first line of FETCH_HEAD.
    match = re.findall(r"^([a-f0-9]+).+?https:\/\/(.+?)\/(.+?)\/(.+?)\n", head)
    if not match:
        print(f"Cannot check for updates for {dir}: cannot parse FETCH_HEAD")
        return False
    local, host, owner, repo = match[0]

    res = curl(f"https://{host}/api/v1/repos/{owner}/{repo}/branches/") # this only works for gitea instances
    # NOTE(review): compares against the first branch the API returns, which is
    # not necessarily the checked-out branch.
    try:
        remote = res[0]["commit"]["id"]
    except (IndexError, KeyError, TypeError):
        # None (fetch failed), [] (no branches), or an error dict / unexpected shape.
        print(f"Cannot check for updates for {dir}: cannot fetch from remote")
        return False

    if remote != local:
        print(f"New version found for {dir}: {local[:8]} => {remote[:8]}")
        return True
    return False
2023-02-20 00:21:16 +00:00
def notify_progress(message, progress=None, verbose=True):
    """Report a status message to the console or to a progress callback.

    Args:
        message: text to report.
        progress: optional callable (e.g. a Gradio progress tracker), invoked as
            ``progress(0, desc=message)``. When None, the message is written with
            ``tqdm.write``.
        verbose: when a ``progress`` callback is given, also print the message.

    Returns:
        None
    """
    if progress is None:
        # tqdm.write already prints the line (without breaking active progress
        # bars). Printing as well used to emit every message twice by default.
        tqdm.write(message)
        return

    if verbose:
        print(message)
    progress(0, desc=message)
2023-02-18 14:10:26 +00:00
2023-02-20 00:21:16 +00:00
def get_args():
    """Return the module-level ``args`` namespace populated by ``setup_args``."""
    # Reading a module global needs no ``global`` declaration.
    return args
2023-02-18 14:10:26 +00:00
2023-02-20 00:21:16 +00:00
def setup_args():
    """Build the global ``args`` namespace from defaults, saved config and the CLI.

    Precedence, lowest to highest:
    1. the built-in ``default_arguments`` table;
    2. keys stored in ``./config/exec.json``;
    3. command-line flags.

    Afterwards it derives ``embed_output_metadata``, forwards an explicit
    ``--device-override``, and splits ``--listen`` into ``listen_host``,
    ``listen_port`` and ``listen_path``.

    Returns:
        argparse.Namespace: the parsed arguments (also stored in global ``args``).
    """
    global args

    default_arguments = {
        'share': False,
        'listen': None,
        'check-for-updates': False,
        'models-from-local-only': False,
        'low-vram': False,
        'sample-batch-size': None,
        'unsqueeze-sample-batches': False,
        'embed-output-metadata': True,
        'latents-lean-and-mean': True,
        'voice-fixer': False, # getting tired of long initialization times in a Colab for downloading a large dataset for it
        'voice-fixer-use-cuda': True,

        'force-cpu-for-conditioning-latents': False,
        'defer-tts-load': False,
        'device-override': None,
        'prune-nonfinal-outputs': True,
        'concurrency-count': 2,
        'autocalculate-voice-chunk-duration-size': 10,

        'output-sample-rate': 44100,
        'output-volume': 1,
        'results-folder': "./results/",

        'hf-token': None,
        'tts-backend': TTSES[0],

        'autoregressive-model': None,
        'diffusion-model': None,
        'vocoder-model': VOCODERS[-1],
        'tokenizer-json': None,

        'phonemizer-backend': 'espeak',
        'valle-model': None,

        'whisper-backend': 'openai/whisper',
        'whisper-model': "base",
        'whisper-batchsize': 1,

        'training-default-halfp': False,
        'training-default-bnb': True,
    }

    # Settings persisted by a previous session override the built-in defaults.
    if os.path.isfile('./config/exec.json'):
        with open('./config/exec.json', 'r', encoding="utf-8") as f:
            try:
                overrides = json.load(f)
                for k in overrides:
                    default_arguments[k] = overrides[k]
            except Exception as e:
                # Best effort: an unreadable config falls back to the defaults.
                print(e)

    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action='store_true', default=default_arguments['share'], help="Lets Gradio return a public URL to use anywhere")
    parser.add_argument("--listen", default=default_arguments['listen'], help="Path for Gradio to listen on")
    parser.add_argument("--check-for-updates", action='store_true', default=default_arguments['check-for-updates'], help="Checks for update on startup")
    parser.add_argument("--models-from-local-only", action='store_true', default=default_arguments['models-from-local-only'], help="Only loads models from disk, does not check for updates for models")
    parser.add_argument("--low-vram", action='store_true', default=default_arguments['low-vram'], help="Disables some optimizations that increases VRAM usage")
    # store_true: passing the flag *disables* embedding. The old store_false made
    # the flag a no-op because it stored the same value the default already had.
    parser.add_argument("--no-embed-output-metadata", action='store_true', default=not default_arguments['embed-output-metadata'], help="Disables embedding output metadata into resulting WAV files for easily fetching its settings used with the web UI (data is stored in the lyrics metadata tag)")
    parser.add_argument("--latents-lean-and-mean", action='store_true', default=default_arguments['latents-lean-and-mean'], help="Exports the bare essentials for latents.")
    parser.add_argument("--voice-fixer", action='store_true', default=default_arguments['voice-fixer'], help="Uses python module 'voicefixer' to improve audio quality, if available.")
    parser.add_argument("--voice-fixer-use-cuda", action='store_true', default=default_arguments['voice-fixer-use-cuda'], help="Hints to voicefixer to use CUDA, if available.")
    parser.add_argument("--force-cpu-for-conditioning-latents", default=default_arguments['force-cpu-for-conditioning-latents'], action='store_true', help="Forces computing conditional latents to be done on the CPU (if you constantyl OOM on low chunk counts)")
    parser.add_argument("--defer-tts-load", default=default_arguments['defer-tts-load'], action='store_true', help="Defers loading TTS model")
    parser.add_argument("--prune-nonfinal-outputs", default=default_arguments['prune-nonfinal-outputs'], action='store_true', help="Deletes non-final output files on completing a generation")
    parser.add_argument("--device-override", default=default_arguments['device-override'], help="A device string to override pass through Torch")
    parser.add_argument("--sample-batch-size", default=default_arguments['sample-batch-size'], type=int, help="Sets how many batches to use during the autoregressive samples pass")
    parser.add_argument("--unsqueeze-sample-batches", default=default_arguments['unsqueeze-sample-batches'], action='store_true', help="Unsqueezes sample batches to process one by one after sampling")
    parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once")
    parser.add_argument("--autocalculate-voice-chunk-duration-size", type=float, default=default_arguments['autocalculate-voice-chunk-duration-size'], help="Number of seconds to suggest voice chunk size for (for example, 100 seconds of audio at 10 seconds per chunk will suggest 10 chunks)")
    parser.add_argument("--output-sample-rate", type=int, default=default_arguments['output-sample-rate'], help="Sample rate to resample the output to (from 24KHz)")
    parser.add_argument("--output-volume", type=float, default=default_arguments['output-volume'], help="Adjusts volume of output")
    parser.add_argument("--results-folder", type=str, default=default_arguments['results-folder'], help="Sets output directory")

    parser.add_argument("--hf-token", type=str, default=default_arguments['hf-token'], help="HuggingFace Token")
    parser.add_argument("--tts-backend", default=default_arguments['tts-backend'], help="Specifies which TTS backend to use.")

    parser.add_argument("--autoregressive-model", default=default_arguments['autoregressive-model'], help="Specifies which autoregressive model to use for sampling.")
    parser.add_argument("--diffusion-model", default=default_arguments['diffusion-model'], help="Specifies which diffusion model to use for sampling.")
    # Value option: store_true made it impossible to pass a vocoder name.
    parser.add_argument("--vocoder-model", default=default_arguments['vocoder-model'], help="Specifies with vocoder to use")
    parser.add_argument("--tokenizer-json", default=default_arguments['tokenizer-json'], help="Specifies which tokenizer json to use for tokenizing.")

    parser.add_argument("--phonemizer-backend", default=default_arguments['phonemizer-backend'], help="Specifies which phonemizer backend to use.")

    parser.add_argument("--valle-model", default=default_arguments['valle-model'], help="Specifies which VALL-E model to use for sampling.")

    # Value option: store_true made it impossible to pass a backend name.
    parser.add_argument("--whisper-backend", default=default_arguments['whisper-backend'], help="Picks which whisper backend to use (openai/whisper, lightmare/whispercpp)")
    parser.add_argument("--whisper-model", default=default_arguments['whisper-model'], help="Specifies which whisper model to use for transcription.")
    parser.add_argument("--whisper-batchsize", type=int, default=default_arguments['whisper-batchsize'], help="Specifies batch size for WhisperX")

    parser.add_argument("--training-default-halfp", action='store_true', default=default_arguments['training-default-halfp'], help="Training default: halfp")
    parser.add_argument("--training-default-bnb", action='store_true', default=default_arguments['training-default-bnb'], help="Training default: bnb")

    parser.add_argument("--os", default="unix", help="Specifies which OS, easily")
    args = parser.parse_args()

    args.embed_output_metadata = not args.no_embed_output_metadata

    # Forward the device only when one was actually requested
    # (the old `if not ...` check only ever forwarded None).
    if args.device_override:
        set_device_name(args.device_override)

    if args.sample_batch_size == 0 and get_device_batch_size() == 1:
        print("!WARNING! Automatically deduced sample batch size returned 1.")

    args.listen_host = None
    args.listen_port = None
    args.listen_path = None
    if args.listen:
        # Accepted forms: "host:port", "host:port/path" or "/path".
        try:
            match = re.findall(r"^(?:(.+?):(\d+))?(\/.*?)?$", args.listen)[0]

            args.listen_host = match[0] if match[0] != "" else "127.0.0.1"
            args.listen_port = match[1] if match[1] != "" else None
            args.listen_path = match[2] if match[2] != "" else "/"
        except Exception:
            # Unparseable value: keep host/port/path unset, but say so.
            print(f"Could not parse --listen value: {args.listen}")

    if args.listen_port is not None:
        args.listen_port = int(args.listen_port)
        if args.listen_port == 0:
            args.listen_port = None

    return args
2023-02-18 14:10:26 +00:00
2023-03-14 05:02:14 +00:00
def get_default_settings(hypenated=True):
    """Snapshot the persistable subset of the global ``args``.

    Args:
        hypenated: when True (default), keys use the dashed CLI spelling
            (``"low-vram"``). When False, dashes become underscores so keys match
            ``args`` attribute names (``"low_vram"``).

    Returns:
        dict: a fresh mapping of setting name to current value.
    """
    pairs = [
        ('listen', args.listen or None),
        ('share', args.share),
        ('low-vram', args.low_vram),
        ('check-for-updates', args.check_for_updates),
        ('models-from-local-only', args.models_from_local_only),
        ('force-cpu-for-conditioning-latents', args.force_cpu_for_conditioning_latents),
        ('defer-tts-load', args.defer_tts_load),
        ('prune-nonfinal-outputs', args.prune_nonfinal_outputs),
        ('device-override', args.device_override),
        ('sample-batch-size', args.sample_batch_size),
        ('unsqueeze-sample-batches', args.unsqueeze_sample_batches),
        ('embed-output-metadata', args.embed_output_metadata),
        ('latents-lean-and-mean', args.latents_lean_and_mean),
        ('voice-fixer', args.voice_fixer),
        ('voice-fixer-use-cuda', args.voice_fixer_use_cuda),
        ('concurrency-count', args.concurrency_count),
        ('output-sample-rate', args.output_sample_rate),
        ('autocalculate-voice-chunk-duration-size', args.autocalculate_voice_chunk_duration_size),
        ('output-volume', args.output_volume),
        ('results-folder', args.results_folder),

        ('hf-token', args.hf_token),
        ('tts-backend', args.tts_backend),

        ('autoregressive-model', args.autoregressive_model),
        ('diffusion-model', args.diffusion_model),
        ('vocoder-model', args.vocoder_model),
        ('tokenizer-json', args.tokenizer_json),

        ('phonemizer-backend', args.phonemizer_backend),
        ('valle-model', args.valle_model),

        ('whisper-backend', args.whisper_backend),
        ('whisper-model', args.whisper_model),
        ('whisper-batchsize', args.whisper_batchsize),

        ('training-default-halfp', args.training_default_halfp),
        ('training-default-bnb', args.training_default_bnb),
    ]
    if hypenated:
        return dict(pairs)
    return {name.replace("-", "_"): value for name, value in pairs}
2023-03-09 00:26:47 +00:00
def update_args(**kwargs):
    """Apply setting overrides to the global ``args`` and persist them.

    Each keyword named like a key of ``get_default_settings(hypenated=False)``
    (e.g. ``low_vram=True``) replaces that setting. Unrecognised keywords are
    ignored. Every recognised setting is then written back onto ``args`` and
    saved via ``save_args_settings()``.

    Args:
        **kwargs: underscore-spelled setting names mapped to new values.

    Returns:
        None
    """
    global args

    current = get_default_settings(hypenated=False)
    settings = dict(current)
    settings.update(kwargs)

    # Setting names are exactly the ``args`` attribute names. Looping over them
    # replaces 30+ hand-written assignments, one of which hard-coded
    # output_sample_rate to 44000 and silently discarded the user's choice.
    for key in current:
        setattr(args, key, settings[key])

    save_args_settings()
def save_args_settings():
    """Write the current ``args`` settings to ``./config/exec.json`` (tab-indented JSON)."""
    snapshot = get_default_settings()

    os.makedirs('./config/', exist_ok=True)

    with open('./config/exec.json', 'w', encoding="utf-8") as handle:
        json.dump(snapshot, handle, indent='\t')
2023-03-09 00:26:47 +00:00
def import_generate_settings(file=None):
    """Load generation settings, filling every missing key with a built-in default.

    NOTE: the original author described this defaults-then-overlay approach as
    "super kludgy".

    Args:
        file: any source accepted by ``read_generate_settings``. When falsy,
            ``./config/generate.json`` is used.

    Returns:
        dict: the defaults below, overlaid with whatever settings were read
        (latents are never decoded here).
    """
    source = file if file else "./config/generate.json"

    defaults = dict(
        text=None,
        delimiter=None,
        emotion=None,
        prompt=None,
        voice="random",
        mic_audio=None,
        voice_latents_chunks=None,
        candidates=None,
        seed=None,
        num_autoregressive_samples=16,
        diffusion_iterations=30,
        temperature=0.8,
        diffusion_sampler="DDIM",
        breathing_room=8,
        cvvp_weight=0.0,
        top_p=0.8,
        diffusion_temperature=1.0,
        length_penalty=1.0,
        repetition_penalty=2.0,
        cond_free_k=2.0,
        experimentals=None,
    )

    loaded = read_generate_settings(source, read_latents=False)[0]
    if loaded is not None:
        defaults.update(loaded)

    return defaults
2023-02-20 00:21:16 +00:00
2023-03-14 18:46:20 +00:00
def reset_generate_settings():
    """Clear ``./config/generate.json`` and return the built-in generation defaults.

    Returns:
        dict: the result of ``import_generate_settings()`` after the reset.
    """
    # Create the config directory first (as save_args_settings does), so a reset
    # on a fresh checkout does not fail with FileNotFoundError.
    os.makedirs('./config/', exist_ok=True)
    with open('./config/generate.json', 'w', encoding="utf-8") as f:
        f.write(json.dumps({}, indent='\t'))
    return import_generate_settings()
2023-02-21 03:00:45 +00:00
def read_generate_settings(file, read_latents=True):
    """Read generation settings embedded in a WAV file or stored in a JSON file.

    WAV files carry the settings as JSON in their ``lyrics`` metadata tag.
    ``.json`` files are parsed directly. Other extensions yield no settings.

    Args:
        file: a path string, an object with a ``name`` attribute (e.g. an
            uploaded file), a one-element list of either, or None.
        read_latents: when True, base64-decode the ``latents`` entry.

    Returns:
        tuple: ``(settings, latents)``.
        - ``settings`` is the parsed JSON without its ``latents`` key, with a
          numeric ``time`` formatted to three decimals, or None if nothing was read.
        - ``latents`` is the decoded bytes, or None.
    """
    j = None
    latents = None

    if isinstance(file, list) and len(file) == 1:
        file = file[0]

    try:
        if file is not None:
            if hasattr(file, 'name'):
                file = file.name
            if file[-4:] == ".wav":
                metadata = music_tag.load_file(file)
                if 'lyrics' in metadata:
                    j = json.loads(str(metadata['lyrics']))
            elif file[-5:] == ".json":
                # Explicit UTF-8, matching every other config read/write in this
                # module; the platform default (e.g. cp1252) breaks on non-ASCII text.
                with open(file, 'r', encoding="utf-8") as f:
                    j = json.load(f)
    except FileNotFoundError:
        # A missing settings file (e.g. first run) simply means "no settings".
        pass
    except Exception as e:
        # Still best effort, but no longer silent.
        print(f"Failed to read generation settings from {file}: {e}")

    if j is not None:
        if 'latents' in j:
            if read_latents:
                latents = base64.b64decode(j['latents'])
            del j['latents']
        # Only format numbers: an already-formatted string would make format() raise.
        if "time" in j and isinstance(j["time"], (int, float)):
            j["time"] = "{:.3f}".format(j["time"])

    return (
        j,
        latents,
    )
2023-03-06 21:48:34 +00:00
def version_check_tts(min_version):
    """Compare ``min_version`` component-wise against the loaded TTS's ``version``.

    Returns True as soon as one of the three checks below succeeds. Returns False
    if none does, or if the TTS object has no ``version`` attribute.

    NOTE(review): this is not a lexicographic comparison.
    - The minor component is checked even when the major components differ.
    - The patch component uses ``>=``.
    - True means ``min_version`` is *greater* (or, for patch, equal) in some
      component, which reads inverted relative to the name.
    Confirm the intended semantics against callers before relying on it.

    Args:
        min_version: indexable with at least three components (presumably
            ``[major, minor, patch]`` -- TODO confirm).

    Raises:
        Exception: if no TTS object is loaded.
    """
    global tts
    if not tts:
        raise Exception("TTS is not initialized")
    # Objects without a version attribute are treated as failing the check.
    if not hasattr(tts, 'version'):
        return False
    if min_version[0] > tts.version[0]:
        return True
    if min_version[1] > tts.version[1]:
        return True
    if min_version[2] >= tts.version[2]:
        return True
    return False
2023-03-31 03:26:00 +00:00
def load_tts(restart=False,
             # TorToiSe configs
             autoregressive_model=None, diffusion_model=None, vocoder_model=None, tokenizer_json=None,
             # VALL-E configs
             valle_model=None,
             ):
    """Load the TTS engine selected by ``args.tts_backend`` into the global ``tts``.

    Model arguments that are passed are also written back onto ``args``.
    Omitted ones fall back to the current ``args`` values.

    Args:
        restart: unload any currently loaded engine first.
        autoregressive_model, diffusion_model, vocoder_model, tokenizer_json:
            TorToiSe model selections. An autoregressive value of ``"auto"`` is
            resolved via ``deduce_autoregressive_model()``.
        valle_model: VALL-E config path.

    Returns:
        The global ``tts`` after loading. If the backend name matches none of
        the handled backends, this is the previous value.
    """
    global args
    global tts
    # Without this declaration the assignments below only created a local, so
    # the module-level flag read by the update_*_model helpers was never set.
    global tts_loading

    if restart:
        unload_tts()

    tts_loading = True
    try:
        if args.tts_backend == "tortoise":
            if autoregressive_model:
                args.autoregressive_model = autoregressive_model
            else:
                autoregressive_model = args.autoregressive_model

            if autoregressive_model == "auto":
                autoregressive_model = deduce_autoregressive_model()

            if diffusion_model:
                args.diffusion_model = diffusion_model
            else:
                diffusion_model = args.diffusion_model

            if vocoder_model:
                args.vocoder_model = vocoder_model
            else:
                vocoder_model = args.vocoder_model

            if tokenizer_json:
                args.tokenizer_json = tokenizer_json
            else:
                tokenizer_json = args.tokenizer_json

            if get_device_name() == "cpu":
                print("!!!! WARNING !!!! No GPU available in PyTorch. You may need to reinstall PyTorch.")

            print(f"Loading TorToiSe... (AR: {autoregressive_model}, diffusion: {diffusion_model}, vocoder: {vocoder_model})")
            tts = TorToise_TTS(minor_optimizations=not args.low_vram, autoregressive_model_path=autoregressive_model, diffusion_model_path=diffusion_model, vocoder_model=vocoder_model, tokenizer_json=tokenizer_json, unsqueeze_sample_batches=args.unsqueeze_sample_batches)
        elif args.tts_backend == "vall-e":
            if valle_model:
                args.valle_model = valle_model
            else:
                valle_model = args.valle_model

            print(f"Loading VALL-E... (Config: {valle_model})")
            tts = VALLE_TTS(config=args.valle_model)
        elif args.tts_backend == "bark":
            print(f"Loading Bark...")
            tts = Bark_TTS(small=args.low_vram)

        print("Loaded TTS, ready for generation.")
    finally:
        # Clear the flag even if loading raised, so callers are not told to wait forever.
        tts_loading = False
    return tts
2023-02-20 00:21:16 +00:00
def unload_tts():
    """Drop the loaded TTS engine (if any), then run garbage collection."""
    global tts

    if tts:
        # Release the reference before the memory sweep below.
        del tts
        tts = None
        print("Unloaded TTS")

    do_gc()
2023-03-16 14:24:44 +00:00
def reload_tts():
    """Unload the current TTS engine and load a fresh one from the current ``args``."""
    # restart=True makes load_tts call unload_tts() before loading.
    load_tts(restart=True)
2023-02-20 00:21:16 +00:00
2023-03-07 04:34:39 +00:00
def get_current_voice():
    """Return the active voice name.

    Uses the global ``current_voice`` if it is set. Otherwise falls back to the
    ``voice`` entry saved in ``./config/generate.json``.

    Returns:
        str | None: the voice name, or None if none is known.
    """
    global current_voice
    if current_voice:
        return current_voice

    settings, _ = read_generate_settings("./config/generate.json", read_latents=False)
    # Test for the key itself. The old `"voice" in settings['voice']` did a
    # substring test on the value (or raised KeyError when the key was absent).
    if settings and "voice" in settings:
        return settings["voice"]
    return None
def deduce_autoregressive_model(voice=None):
    """Pick an autoregressive model for ``voice``.

    Preference order:
    1. ``./models/finetunes/<voice>.pth``, if it exists;
    2. the highest-step ``<step>_gpt.pth`` in ``./training/<voice>/finetune/models/``;
    3. ``args.autoregressive_model``, unless it is ``"auto"``;
    4. the stock ``autoregressive.pth``.

    Args:
        voice: voice name. When falsy, ``get_current_voice()`` is used.

    Returns:
        str: path to the chosen model.
    """
    if not voice:
        voice = get_current_voice()

    if voice:
        finetune = f'./models/finetunes/{voice}.pth'
        if os.path.exists(finetune):
            return finetune

        models_dir = f'./training/{voice}/finetune/models/'
        if os.path.isdir(models_dir):
            # Only numeric "<step>_gpt.pth" names are checkpoints; skip anything
            # else instead of letting int() raise.
            steps = [int(d[:-8]) for d in os.listdir(models_dir) if d[-8:] == "_gpt.pth" and d[:-8].isdecimal()]
            if steps:
                return f'{models_dir}/{max(steps)}_gpt.pth'

    if args.autoregressive_model != "auto":
        return args.autoregressive_model
    return get_model_path('autoregressive.pth')
2023-02-20 00:21:16 +00:00
def update_autoregressive_model(autoregressive_model_path):
    """Select an autoregressive model, persist it, and hot-swap it into the loaded TTS.

    Args:
        autoregressive_model_path: a model path, optionally prefixed with an
            8-hex-digit hash as ``"[abcdef12] <path>"``, or ``"auto"`` to deduce
            the model from the current voice.

    Returns:
        str | None: the path that was loaded. None if the path was invalid, no
        TTS is loaded, or that model is already active.

    Raises:
        Exception: if the backend is not TorToiSe, or the TTS is still initializing.
    """
    global tts

    if args.tts_backend != "tortoise":
        # Raising a bare string is itself a TypeError in Python 3.
        raise Exception(f"Unsupported backend: {args.tts_backend}")

    # Check before the regex: re.findall(None) would raise TypeError.
    if not autoregressive_model_path:
        print(f"Invalid model: {autoregressive_model_path}")
        return

    # Strip an optional "[hash] " display prefix.
    match = re.findall(r'^\[[a-fA-F0-9]{8}\] (.+?)$', autoregressive_model_path)
    if match:
        autoregressive_model_path = match[0]

    # "auto" names no file; it is resolved below. Previously it was always
    # rejected here, which made that branch unreachable.
    if autoregressive_model_path != "auto" and not os.path.exists(autoregressive_model_path):
        print(f"Invalid model: {autoregressive_model_path}")
        return

    args.autoregressive_model = autoregressive_model_path
    save_args_settings()
    print(f'Stored autoregressive model to settings: {autoregressive_model_path}')

    if not tts:
        if tts_loading:
            raise Exception("TTS is still initializing...")
        return

    if hasattr(tts, "loading") and tts.loading:
        raise Exception("TTS is still initializing...")

    if autoregressive_model_path == "auto":
        autoregressive_model_path = deduce_autoregressive_model()

    if autoregressive_model_path == tts.autoregressive_model_path:
        return

    tts.load_autoregressive_model(autoregressive_model_path)

    do_gc()
    return autoregressive_model_path
2023-02-20 00:21:16 +00:00
2023-03-15 00:37:38 +00:00
def update_diffusion_model(diffusion_model_path):
    """Persist a new diffusion model choice and hot-swap it into the loaded TTS.

    A leading ``[xxxxxxxx] `` prefix (8 hex digits, presumably a short hash added by the
    model dropdown) is stripped from the path. ``"auto"`` is accepted as a sentinel: it is
    stored as-is and, when a TTS instance exists, resolved via deduce_diffusion_model().
    The choice is saved to the settings even when no TTS instance has been created yet.

    :param diffusion_model_path: model file path (optionally prefixed) or ``"auto"``.
    :returns: the path that was loaded, or None when nothing was loaded (invalid path,
        no TTS instance yet, or that model is already active).
    :raises Exception: if the TTS backend is not "tortoise", or TTS is still initializing.
    """
    global tts

    if args.tts_backend != "tortoise":
        # Raising a bare string is itself a TypeError in Python 3; raise a real exception.
        raise Exception(f"Unsupported backend: {args.tts_backend}")

    # Guard first so a None/empty value reaches the "Invalid model" path instead of
    # crashing inside re.findall.
    if diffusion_model_path:
        match = re.findall(r'^\[[a-fA-F0-9]{8}\] (.+?)$', diffusion_model_path)
        if match:
            diffusion_model_path = match[0]

    # "auto" is not a file; it is resolved below, so it must bypass the existence check
    # (otherwise the "auto" branch further down could never be reached).
    is_auto = diffusion_model_path == "auto"
    if not is_auto and (not diffusion_model_path or not os.path.exists(diffusion_model_path)):
        print(f"Invalid model: {diffusion_model_path}")
        return

    args.diffusion_model = diffusion_model_path
    save_args_settings()
    print(f'Stored diffusion model to settings: {diffusion_model_path}')

    if not tts:
        if tts_loading:
            raise Exception("TTS is still initializing...")
        return

    if hasattr(tts, "loading") and tts.loading:
        raise Exception("TTS is still initializing...")

    if is_auto:
        diffusion_model_path = deduce_diffusion_model()

    # Nothing to do if the requested model is already the active one.
    if diffusion_model_path == tts.diffusion_model_path:
        return

    tts.load_diffusion_model(diffusion_model_path)

    do_gc()
    return diffusion_model_path
2023-03-07 02:45:22 +00:00
def update_vocoder_model(vocoder_model):
    """Persist a new vocoder choice and load it into the running TTS instance.

    The choice is saved to the settings even when no TTS instance has been created yet.

    :param vocoder_model: vocoder identifier forwarded to ``tts.load_vocoder_model()``.
    :returns: ``vocoder_model`` once loaded, or None if no TTS instance exists yet.
    :raises Exception: if the TTS backend is not "tortoise", or TTS is still initializing.
    """
    global tts

    if args.tts_backend != "tortoise":
        # Raising a bare string is itself a TypeError in Python 3; raise a real exception.
        raise Exception(f"Unsupported backend: {args.tts_backend}")

    args.vocoder_model = vocoder_model
    save_args_settings()
    print(f'Stored vocoder model to settings: {vocoder_model}')

    if not tts:
        if tts_loading:
            raise Exception("TTS is still initializing...")
        return

    if hasattr(tts, "loading") and tts.loading:
        raise Exception("TTS is still initializing...")

    print(f"Loading model: {vocoder_model}")
    tts.load_vocoder_model(vocoder_model)
    print(f"Loaded model: {tts.vocoder_model}")

    do_gc()
    return vocoder_model
2023-02-20 00:21:16 +00:00
2023-03-15 00:37:38 +00:00
def update_tokenizer(tokenizer_json):
    """Persist a new tokenizer vocab choice and load it into the running TTS instance.

    The choice is saved to the settings even when no TTS instance has been created yet.

    :param tokenizer_json: tokenizer vocab identifier forwarded to ``tts.load_tokenizer_json()``.
    :returns: ``tokenizer_json`` once loaded, or None if no TTS instance exists yet.
    :raises Exception: if the TTS backend is not "tortoise", or TTS is still initializing.
    """
    global tts

    if args.tts_backend != "tortoise":
        # Raising a bare string is itself a TypeError in Python 3; raise a real exception.
        raise Exception(f"Unsupported backend: {args.tts_backend}")

    args.tokenizer_json = tokenizer_json
    save_args_settings()
    print(f'Stored tokenizer to settings: {tokenizer_json}')

    if not tts:
        if tts_loading:
            raise Exception("TTS is still initializing...")
        return

    if hasattr(tts, "loading") and tts.loading:
        raise Exception("TTS is still initializing...")

    print(f"Loading tokenizer vocab: {tokenizer_json}")
    tts.load_tokenizer_json(tokenizer_json)
    print(f"Loaded tokenizer vocab: {tts.tokenizer_json}")

    do_gc()
    # Previously returned the undefined name `vocoder_model` (copy-paste slip -> NameError).
    return tokenizer_json
2023-02-20 00:21:16 +00:00
def load_voicefixer(restart=False):
    """Create the module-level VoiceFixer instance.

    :param restart: when true, unload_voicefixer() is called before loading.

    Any exception raised while importing or constructing VoiceFixer is printed rather
    than propagated; in that case the global is reset to None if it held a truthy value.
    """
    global voicefixer

    if restart:
        unload_voicefixer()

    try:
        print("Loading Voicefixer")
        from voicefixer import VoiceFixer as VoiceFixerCls
        voicefixer = VoiceFixerCls()
        print("Loaded Voicefixer")
    except Exception as err:
        print(f"Error occurred while tring to initialize voicefixer: {err}")
        # Drop whatever stale instance the global may still reference.
        if voicefixer:
            voicefixer = None
2023-02-20 00:21:16 +00:00
def unload_voicefixer():
    """Drop the module-level VoiceFixer instance, if one is loaded, then run do_gc()."""
    global voicefixer

    had_instance = bool(voicefixer)
    if had_instance:
        # Rebinding to None releases this module's reference to the instance.
        voicefixer = None
        print("Unloaded Voicefixer")

    do_gc()
2023-03-05 05:17:19 +00:00
def load_whisper_model(language=None, model_name=None, progress=None):
    """Load the speech-recognition model for the configured Whisper backend.

    Sets the module global ``whisper_model``; for the ``m-bain/whisperx`` backend it also
    sets ``whisper_vad`` (best-effort) and ``whisper_align_model``.

    :param language: optional language code. Selects a specialised model when
        ``f'{model_name}.{language}'`` is in WHISPER_SPECIALIZED_MODELS, and is passed to
        the backend (the whispercpp backend defaults it to ``'auto'``).
    :param model_name: Whisper model name. When omitted, ``args.whisper_model`` is used;
        when given, it is stored to ``args.whisper_model`` and the settings are saved.
    :param progress: optional progress handle forwarded to notify_progress().
    :raises Exception: if ``args.whisper_backend`` is not one of WHISPER_BACKENDS.
    """
    global whisper_model
    global whisper_vad
    global whisper_diarize
    global whisper_align_model

    if args.whisper_backend not in WHISPER_BACKENDS:
        raise Exception(f"unavailable backend: {args.whisper_backend}")

    if not model_name:
        model_name = args.whisper_model
    else:
        args.whisper_model = model_name
        save_args_settings()

    if language and f'{model_name}.{language}' in WHISPER_SPECIALIZED_MODELS:
        model_name = f'{model_name}.{language}'
        print(f"Loading specialized model for language: {language}")

    notify_progress(f"Loading Whisper model: {model_name}", progress=progress)

    if args.whisper_backend == "openai/whisper":
        import whisper
        try:
            # NOTE: the model may fit in VRAM here and still run out of memory later during inference.
            whisper_model = whisper.load_model(model_name)
        except Exception as e:
            # Was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit and hid the cause.
            print(f"Out of VRAM memory. falling back to loading Whisper on CPU. ({e})")
            whisper_model = whisper.load_model(model_name, device="cpu")
    elif args.whisper_backend == "lightmare/whispercpp":
        from whispercpp import Whisper
        if not language:
            language = 'auto'

        b_lang = language.encode('ascii')
        whisper_model = Whisper(model_name, models_dir='./models/', language=b_lang)
    elif args.whisper_backend == "m-bain/whisperx":
        import whisper, whisperx
        device = "cuda" if get_device_name() == "cuda" else "cpu"
        try:
            whisper_model = whisperx.load_model(model_name, device)
        except Exception as e:
            # Keep the fallback, but report why whisperx could not load the model.
            print(f"Failed to load Whisper model through whisperx ({e}); falling back to openai/whisper.")
            whisper_model = whisper.load_model(model_name, device)

        if not args.hf_token:
            print("No huggingface token used, needs to be saved in environment variable, otherwise will throw error loading VAD model.")
        try:
            from pyannote.audio import Inference
            whisper_vad = Inference(
                "pyannote/segmentation",
                pre_aggregation_hook=lambda segmentation: segmentation,
                use_auth_token=args.hf_token,
                device=torch.device(device),
            )
            # Speaker diarization (pyannote/speaker-diarization@2.1) is currently disabled,
            # so whisper_diarize is not assigned here.
        except Exception as e:
            # VAD is optional: stay best-effort, but do not drop the error silently.
            print(f"Failed to load VAD model: {e}")

        whisper_align_model = whisperx.load_align_model(model_name="WAV2VEC2_ASR_LARGE_LV60K_960H" if language == "en" else None, language_code=language, device=device)

    print("Loaded Whisper model")
2023-02-20 00:21:16 +00:00
def unload_whisper():
    """Clear every loaded Whisper-related global, then run do_gc().

    The auxiliary globals (whisper_vad, whisper_diarize, whisper_align_model) are cleared
    first; "Unloaded Whisper" is printed only when whisper_model itself was loaded.
    """
    global whisper_model, whisper_vad, whisper_diarize, whisper_align_model

    # Rebinding a truthy global to None releases this module's reference to it.
    if whisper_vad:
        whisper_vad = None
    if whisper_diarize:
        whisper_diarize = None
    if whisper_align_model:
        whisper_align_model = None

    model_was_loaded = bool(whisper_model)
    if model_was_loaded:
        whisper_model = None
        print("Unloaded Whisper")

    do_gc()
# shamelessly borrowed from Voldy's Web UI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/extras.py#L74
def merge_models(primary_model_name, secondary_model_name, alpha, progress=gr.Progress()):
    """Linearly interpolate two checkpoints and save the merged state dict.

    Every entry whose tensors are float32/float16 on both sides (with matching shapes)
    becomes ``(1 - alpha) * primary + alpha * secondary``. All other entries keep the
    primary checkpoint's value, as do keys missing from the secondary. The result is
    written to ``./models/finetunes/{primary}_{secondary}_{alpha:.3f}_merge.pth``.

    :param primary_model_name: path to the primary checkpoint (weight ``1 - alpha``).
    :param secondary_model_name: path to the secondary checkpoint (weight ``alpha``).
    :param alpha: interpolation weight given to the secondary checkpoint.
    :param progress: Gradio progress tracker; accepted for the Gradio signature but not
        used in this body (progress is shown through tqdm).
    :returns: a status message naming the output path.
    """
    # Keys to leave untouched (kept from the primary); empty by default.
    key_blacklist = []

    def weighted_sum(theta0, theta1, alpha):
        return ((1 - alpha) * theta0) + (alpha * theta1)

    def read_model(filename):
        print(f"Loading {filename}")
        # NOTE(review): torch.load unpickles the file; only merge trusted checkpoints.
        # map_location="cpu" keeps this pure-arithmetic merge from occupying GPU memory.
        return torch.load(filename, map_location="cpu")

    theta_func = weighted_sum
    theta_0 = read_model(primary_model_name)
    theta_1 = read_model(secondary_model_name)

    mergeable = (torch.float32, torch.float16)
    for key in tqdm(theta_0.keys(), desc="Merging..."):
        if key in key_blacklist:
            print("Skipping ignored key:", key)
            continue
        # Previously raised KeyError when the secondary lacked a key.
        if key not in theta_1:
            print("Skipping key missing from secondary model:", key)
            continue
        a = theta_0[key]
        b = theta_1[key]
        # Non-tensor entries have no .dtype and cannot be interpolated.
        if not isinstance(a, torch.Tensor) or not isinstance(b, torch.Tensor):
            print("Skipping non-tensor key:", key)
            continue
        if a.dtype not in mergeable:
            print("Skipping key:", key, a.dtype)
            continue
        if b.dtype not in mergeable:
            print("Skipping key:", key, b.dtype)
            continue
        if a.shape != b.shape:
            print("Skipping key with mismatched shapes:", key, tuple(a.shape), tuple(b.shape))
            continue
        # Reassigning an existing key does not change the dict's size, so this is safe mid-iteration.
        theta_0[key] = theta_func(a, b, alpha)

    del theta_1

    primary_basename = os.path.splitext(os.path.basename(primary_model_name))[0]
    secondary_basename = os.path.splitext(os.path.basename(secondary_model_name))[0]
    suffix = "{:.3f}".format(alpha)

    # torch.save does not create missing directories.
    os.makedirs('./models/finetunes/', exist_ok=True)
    output_path = f'./models/finetunes/{primary_basename}_{secondary_basename}_{suffix}_merge.pth'

    torch.save(theta_0, output_path)
    message = f"Saved to {output_path}"
    print(message)
    return message