big cleanup to make my life easier when i add more parameters

This commit is contained in:
mrq 2023-03-09 00:26:47 +00:00
parent 0ab091e7ff
commit 3f321fe664
5 changed files with 482 additions and 702 deletions

View File

@ -1,16 +1,18 @@
name: ${name} name: ${voice}
model: extensibletrainer model: extensibletrainer
scale: 1 scale: 1
gpu_ids: [0] # Superfluous, redundant, unnecessary, the way you launch the training script will set this gpu_ids: [0] # Superfluous, redundant, unnecessary, the way you launch the training script will set this
start_step: 0 start_step: 0
checkpointing_enabled: true checkpointing_enabled: true
fp16: ${float16} fp16: ${half_p}
bitsandbytes: ${bitsandbytes}
gpus: ${gpus}
wandb: false wandb: false
use_tb_logger: true use_tb_logger: true
datasets: datasets:
train: train:
name: ${dataset_name} name: training
n_workers: ${workers} n_workers: ${workers}
batch_size: ${batch_size} batch_size: ${batch_size}
mode: paired_voice_audio mode: paired_voice_audio
@ -27,7 +29,7 @@ datasets:
tokenizer_vocab: ./models/tortoise/bpe_lowercase_asr_256.json tokenizer_vocab: ./models/tortoise/bpe_lowercase_asr_256.json
load_aligned_codes: False load_aligned_codes: False
val: # I really do not care about validation right now val: # I really do not care about validation right now
name: ${validation_name} name: validation
n_workers: ${workers} n_workers: ${workers}
batch_size: ${validation_batch_size} batch_size: ${validation_batch_size}
mode: paired_voice_audio mode: paired_voice_audio
@ -114,8 +116,8 @@ networks:
#only_alignment_head: False # uv3/4 #only_alignment_head: False # uv3/4
path: path:
${pretrain_model_gpt}
strict_load: true strict_load: true
${source_model}
${resume_state} ${resume_state}
train: train:

View File

@ -1,35 +0,0 @@
import os
import sys
indir = f'./training/{sys.argv[1]}/'
cap = int(sys.argv[2])
if not os.path.isdir(indir):
raise Exception(f"Invalid directory: {indir}")
if not os.path.exists(f'{indir}/train.txt'):
raise Exception(f"Missing dataset: {indir}/train.txt")
with open(f'{indir}/train.txt', 'r', encoding="utf-8") as f:
lines = f.readlines()
validation = []
training = []
for line in lines:
split = line.split("|")
filename = split[0]
text = split[1]
if len(text) < cap:
validation.append(line.strip())
else:
training.append(line.strip())
with open(f'{indir}/train_culled.txt', 'w', encoding="utf-8") as f:
f.write("\n".join(training))
with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f:
f.write("\n".join(validation))
print(f"Culled {len(validation)} lines")

View File

@ -46,6 +46,7 @@ sys.path.insert(0, './dlas/')
# without kludge, it'll have to be accessible as `codes` and not `dlas` # without kludge, it'll have to be accessible as `codes` and not `dlas`
import torch import torch
import datetime
from codes import train as tr from codes import train as tr
from utils import util, options as option from utils import util, options as option
@ -71,7 +72,7 @@ def train(yaml, launcher='none'):
print('Disabled distributed training.') print('Disabled distributed training.')
else: else:
opt['dist'] = True opt['dist'] = True
tr.init_dist('nccl') tr.init_dist('nccl', timeout=datetime.timedelta(seconds=5*60))
trainer.world_size = torch.distributed.get_world_size() trainer.world_size = torch.distributed.get_world_size()
trainer.rank = torch.distributed.get_rank() trainer.rank = torch.distributed.get_rank()
torch.cuda.set_device(torch.distributed.get_rank()) torch.cuda.set_device(torch.distributed.get_rank())

View File

@ -34,7 +34,7 @@ from datetime import timedelta
from tortoise.api import TextToSpeech, MODELS, get_model_path, pad_or_truncate from tortoise.api import TextToSpeech, MODELS, get_model_path, pad_or_truncate
from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir
from tortoise.utils.text import split_and_recombine_text from tortoise.utils.text import split_and_recombine_text
from tortoise.utils.device import get_device_name, set_device_name from tortoise.utils.device import get_device_name, set_device_name, get_device_count
MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth" MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
@ -44,6 +44,8 @@ WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp", "m-bain/whisperx"]
VOCODERS = ['univnet', 'bigvgan_base_24khz_100band', 'bigvgan_24khz_100band'] VOCODERS = ['univnet', 'bigvgan_base_24khz_100band', 'bigvgan_24khz_100band']
GENERATE_SETTINGS_ARGS = None
EPOCH_SCHEDULE = [ 9, 18, 25, 33 ] EPOCH_SCHEDULE = [ 9, 18, 25, 33 ]
args = None args = None
@ -56,30 +58,17 @@ training_state = None
current_voice = None current_voice = None
def generate( def generate(**kwargs):
text, parameters = {}
delimiter, parameters.update(kwargs)
emotion,
prompt, voice = parameters['voice']
voice, progress = parameters['progress'] if 'progress' in parameters else None
mic_audio, if parameters['seed'] == 0:
voice_latents_chunks, parameters['seed'] = None
seed,
candidates, usedSeed = parameters['seed']
num_autoregressive_samples,
diffusion_iterations,
temperature,
diffusion_sampler,
breathing_room,
cvvp_weight,
top_p,
diffusion_temperature,
length_penalty,
repetition_penalty,
cond_free_k,
experimental_checkboxes,
progress=None
):
global args global args
global tts global tts
@ -90,6 +79,8 @@ def generate(
# should check if it's loading or unloaded, and load it if it's unloaded # should check if it's loading or unloaded, and load it if it's unloaded
if tts_loading: if tts_loading:
raise Exception("TTS is still initializing...") raise Exception("TTS is still initializing...")
if progress is not None:
progress(0, "Initializing TTS...")
load_tts() load_tts()
if hasattr(tts, "loading") and tts.loading: if hasattr(tts, "loading") and tts.loading:
raise Exception("TTS is still initializing...") raise Exception("TTS is still initializing...")
@ -100,9 +91,6 @@ def generate(
conditioning_latents =None conditioning_latents =None
sample_voice = None sample_voice = None
if seed == 0:
seed = None
voice_cache = {} voice_cache = {}
def fetch_voice( voice ): def fetch_voice( voice ):
print(f"Loading voice: {voice} with model {tts.autoregressive_model_hash[:8]}") print(f"Loading voice: {voice} with model {tts.autoregressive_model_hash[:8]}")
@ -112,9 +100,9 @@ def generate(
sample_voice = None sample_voice = None
if voice == "microphone": if voice == "microphone":
if mic_audio is None: if parameters['mic_audio'] is None:
raise Exception("Please provide audio from mic when choosing `microphone` as a voice input") raise Exception("Please provide audio from mic when choosing `microphone` as a voice input")
voice_samples, conditioning_latents = [load_audio(mic_audio, tts.input_sample_rate)], None voice_samples, conditioning_latents = [load_audio(parameters['mic_audio'], tts.input_sample_rate)], None
elif voice == "random": elif voice == "random":
voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents() voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents()
else: else:
@ -125,7 +113,7 @@ def generate(
if voice_samples and len(voice_samples) > 0: if voice_samples and len(voice_samples) > 0:
if conditioning_latents is None: if conditioning_latents is None:
conditioning_latents = compute_latents(voice=voice, voice_samples=voice_samples, voice_latents_chunks=voice_latents_chunks) conditioning_latents = compute_latents(voice=voice, voice_samples=voice_samples, voice_latents_chunks=parameters['voice_latents_chunks'])
sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu() sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu()
voice_samples = None voice_samples = None
@ -135,30 +123,30 @@ def generate(
def get_settings( override=None ): def get_settings( override=None ):
settings = { settings = {
'temperature': float(temperature), 'temperature': float(parameters['temperature']),
'top_p': float(top_p), 'top_p': float(parameters['top_p']),
'diffusion_temperature': float(diffusion_temperature), 'diffusion_temperature': float(parameters['diffusion_temperature']),
'length_penalty': float(length_penalty), 'length_penalty': float(parameters['length_penalty']),
'repetition_penalty': float(repetition_penalty), 'repetition_penalty': float(parameters['repetition_penalty']),
'cond_free_k': float(cond_free_k), 'cond_free_k': float(parameters['cond_free_k']),
'num_autoregressive_samples': num_autoregressive_samples, 'num_autoregressive_samples': parameters['num_autoregressive_samples'],
'sample_batch_size': args.sample_batch_size, 'sample_batch_size': args.sample_batch_size,
'diffusion_iterations': diffusion_iterations, 'diffusion_iterations': parameters['diffusion_iterations'],
'voice_samples': None, 'voice_samples': None,
'conditioning_latents': None, 'conditioning_latents': None,
'use_deterministic_seed': seed, 'use_deterministic_seed': parameters['seed'],
'return_deterministic_state': True, 'return_deterministic_state': True,
'k': candidates, 'k': parameters['candidates'],
'diffusion_sampler': diffusion_sampler, 'diffusion_sampler': parameters['diffusion_sampler'],
'breathing_room': breathing_room, 'breathing_room': parameters['breathing_room'],
'progress': progress, 'progress': parameters['progress'],
'half_p': "Half Precision" in experimental_checkboxes, 'half_p': "Half Precision" in parameters['experimentals'],
'cond_free': "Conditioning-Free" in experimental_checkboxes, 'cond_free': "Conditioning-Free" in parameters['experimentals'],
'cvvp_amount': cvvp_weight, 'cvvp_amount': parameters['cvvp_weight'],
'autoregressive_model': args.autoregressive_model, 'autoregressive_model': args.autoregressive_model,
} }
@ -182,11 +170,11 @@ def generate(
# clamp it down for the insane users who want this # clamp it down for the insane users who want this
# it would be wiser to enforce the sample size to the batch size, but this is what the user wants # it would be wiser to enforce the sample size to the batch size, but this is what the user wants
sample_batch_size = args.sample_batch_size settings['sample_batch_size'] = args.sample_batch_size
if not sample_batch_size: if not settings['sample_batch_size']:
sample_batch_size = tts.autoregressive_batch_size settings['sample_batch_size'] = tts.autoregressive_batch_size
if num_autoregressive_samples < sample_batch_size: if settings['num_autoregressive_samples'] < settings['sample_batch_size']:
settings['sample_batch_size'] = num_autoregressive_samples settings['sample_batch_size'] = settings['num_autoregressive_samples']
if settings['conditioning_latents'] is not None and len(settings['conditioning_latents']) == 2 and settings['cvvp_amount'] > 0: if settings['conditioning_latents'] is not None and len(settings['conditioning_latents']) == 2 and settings['cvvp_amount'] > 0:
print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents with 'Slimmer voice latents' unchecked.") print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents with 'Slimmer voice latents' unchecked.")
@ -194,15 +182,15 @@ def generate(
return settings return settings
if not delimiter: if not parameters['delimiter']:
delimiter = "\n" parameters['delimiter'] = "\n"
elif delimiter == "\\n": elif parameters['delimiter'] == "\\n":
delimiter = "\n" parameters['delimiter'] = "\n"
if delimiter and delimiter != "" and delimiter in text: if parameters['delimiter'] and parameters['delimiter'] != "" and parameters['delimiter'] in parameters['text']:
texts = text.split(delimiter) texts = parameters['text'].split(parameters['delimiter'])
else: else:
texts = split_and_recombine_text(text) texts = split_and_recombine_text(parameters['text'])
full_start_time = time.time() full_start_time = time.time()
@ -248,37 +236,23 @@ def generate(
name = f"{name}_combined" name = f"{name}_combined"
elif len(texts) > 1: elif len(texts) > 1:
name = f"{name}_{line}" name = f"{name}_{line}"
if candidates > 1: if parameters['candidates'] > 1:
name = f"{name}_{candidate}" name = f"{name}_{candidate}"
return name return name
def get_info( voice, settings = None, latents = True ): def get_info( voice, settings = None, latents = True ):
info = { info = {}
'text': text, info.update(parameters)
'delimiter': '\\n' if delimiter and delimiter == "\n" else delimiter, info['time'] = time.time()-full_start_time,
'emotion': emotion,
'prompt': prompt,
'voice': voice,
'seed': seed,
'candidates': candidates,
'num_autoregressive_samples': num_autoregressive_samples,
'diffusion_iterations': diffusion_iterations,
'temperature': temperature,
'diffusion_sampler': diffusion_sampler,
'breathing_room': breathing_room,
'cvvp_weight': cvvp_weight,
'top_p': top_p,
'diffusion_temperature': diffusion_temperature,
'length_penalty': length_penalty,
'repetition_penalty': repetition_penalty,
'cond_free_k': cond_free_k,
'experimentals': experimental_checkboxes,
'time': time.time()-full_start_time,
'datetime': datetime.now().isoformat(), info['datetime'] = datetime.now().isoformat(),
'model': tts.autoregressive_model_path, info['model'] = tts.autoregressive_model_path,
'model_hash': tts.autoregressive_model_hash info['model_hash'] = tts.autoregressive_model_hash
} info['progress'] = None
del info['progress']
if info['delimiter'] == "\n":
info['delimiter'] = "\\n"
if settings is not None: if settings is not None:
for k in settings: for k in settings:
@ -319,11 +293,11 @@ def generate(
return info return info
for line, cut_text in enumerate(texts): for line, cut_text in enumerate(texts):
if emotion == "Custom": if parameters['emotion'] == "Custom":
if prompt and prompt.strip() != "": if parameters['prompt'] and parameters['prompt'].strip() != "":
cut_text = f"[{prompt},] {cut_text}" cut_text = f"[{parameters['prompt']},] {cut_text}"
elif emotion != "None" and emotion: elif parameters['emotion'] != "None" and parameters['emotion']:
cut_text = f"[I am really {emotion.lower()},] {cut_text}" cut_text = f"[I am really {parameters['emotion'].lower()},] {cut_text}"
progress.msg_prefix = f'[{str(line+1)}/{str(len(texts))}]' progress.msg_prefix = f'[{str(line+1)}/{str(len(texts))}]'
print(f"{progress.msg_prefix} Generating line: {cut_text}") print(f"{progress.msg_prefix} Generating line: {cut_text}")
@ -343,10 +317,10 @@ def generate(
settings = get_settings( override=override ) settings = get_settings( override=override )
gen, additionals = tts.tts(cut_text, **settings ) gen, additionals = tts.tts(cut_text, **settings )
seed = additionals[0] parameters['seed'] = additionals[0]
run_time = time.time()-start_time run_time = time.time()-start_time
print(f"Generating line took {run_time} seconds") print(f"Generating line took {run_time} seconds")
if not isinstance(gen, list): if not isinstance(gen, list):
gen = [gen] gen = [gen]
@ -382,7 +356,7 @@ def generate(
torchaudio.save(f'{outdir}/{voice}_{k}.wav', audio, args.output_sample_rate) torchaudio.save(f'{outdir}/{voice}_{k}.wav', audio, args.output_sample_rate)
output_voices = [] output_voices = []
for candidate in range(candidates): for candidate in range(parameters['candidates']):
if len(texts) > 1: if len(texts) > 1:
audio_clips = [] audio_clips = []
for line in range(len(texts)): for line in range(len(texts)):
@ -466,7 +440,7 @@ def generate(
info = get_info(voice=voice, latents=False) info = get_info(voice=voice, latents=False)
print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n") print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n")
info['seed'] = seed info['seed'] = usedSeed
if 'latents' in info: if 'latents' in info:
del info['latents'] del info['latents']
@ -475,7 +449,7 @@ def generate(
f.write(json.dumps(info, indent='\t') ) f.write(json.dumps(info, indent='\t') )
stats = [ stats = [
[ seed, "{:.3f}".format(info['time']) ] [ parameters['seed'], "{:.3f}".format(info['time']) ]
] ]
return ( return (
@ -609,14 +583,16 @@ def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, prog
# superfluous, but it cleans up some things # superfluous, but it cleans up some things
class TrainingState(): class TrainingState():
def __init__(self, config_path, keep_x_past_checkpoints=0, start=True, gpus=1): def __init__(self, config_path, keep_x_past_checkpoints=0, start=True):
# parse config to get its iteration # parse config to get its iteration
with open(config_path, 'r') as file: with open(config_path, 'r') as file:
self.config = yaml.safe_load(file) self.config = yaml.safe_load(file)
gpus = self.config["gpus"]
self.killed = False self.killed = False
self.dataset_dir = f"./training/{self.config['name']}/" self.dataset_dir = f"./training/{self.config['name']}/finetune/"
self.batch_size = self.config['datasets']['train']['batch_size'] self.batch_size = self.config['datasets']['train']['batch_size']
self.dataset_path = self.config['datasets']['train']['path'] self.dataset_path = self.config['datasets']['train']['path']
with open(self.dataset_path, 'r', encoding="utf-8") as f: with open(self.dataset_path, 'r', encoding="utf-8") as f:
@ -996,7 +972,7 @@ except Exception as e:
print(e) print(e)
pass pass
def run_training(config_path, verbose=False, gpus=1, keep_x_past_checkpoints=0, progress=gr.Progress(track_tqdm=True)): def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress=gr.Progress(track_tqdm=True)):
global training_state global training_state
if training_state and training_state.process: if training_state and training_state.process:
return "Training already in progress" return "Training already in progress"
@ -1008,26 +984,11 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_checkpoints=0,
# I don't know if this is still necessary, as it was bitching at me for not doing this, despite it being in a separate process # I don't know if this is still necessary, as it was bitching at me for not doing this, despite it being in a separate process
torch.multiprocessing.freeze_support() torch.multiprocessing.freeze_support()
# edit any gpu-count-specific variables
with open(config_path, 'r', encoding="utf-8") as f:
yaml_string = f.read()
edited = False
if gpus > 1:
yaml_string = yaml_string.replace(" adamw ", " adamw_zero ")
edited = True
else:
yaml_string = yaml_string.replace(" adamw_zero ", " adamw ")
edited = True
if edited:
print(f'Modified YAML config')
with open(config_path, 'w', encoding="utf-8") as f:
f.write(yaml_string)
unload_tts() unload_tts()
unload_whisper() unload_whisper()
unload_voicefixer() unload_voicefixer()
training_state = TrainingState(config_path=config_path, keep_x_past_checkpoints=keep_x_past_checkpoints, gpus=gpus) training_state = TrainingState(config_path=config_path, keep_x_past_checkpoints=keep_x_past_checkpoints)
for line in iter(training_state.process.stdout.readline, ""): for line in iter(training_state.process.stdout.readline, ""):
if training_state.killed: if training_state.killed:
@ -1169,7 +1130,7 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
if whisper_model is None: if whisper_model is None:
load_whisper_model(language=language) load_whisper_model(language=language)
os.makedirs(outdir, exist_ok=True) os.makedirs(f'{outdir}/audio/', exist_ok=True)
results = {} results = {}
transcription = [] transcription = []
@ -1216,10 +1177,10 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
print(f"Error with {sliced_name}, skipping...") print(f"Error with {sliced_name}, skipping...")
continue continue
torchaudio.save(f"{outdir}/{sliced_name}", sliced_waveform, sampling_rate) torchaudio.save(f"{outdir}/audio/{sliced_name}", sliced_waveform, sampling_rate)
idx = idx + 1 idx = idx + 1
line = f"{sliced_name}|{segment['text'].strip()}" line = f"audio/{sliced_name}|{segment['text'].strip()}"
transcription.append(line) transcription.append(line)
with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f: with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
f.write(f'\n{line}') f.write(f'\n{line}')
@ -1283,125 +1244,142 @@ def calc_iterations( epochs, lines, batch_size ):
def schedule_learning_rate( iterations, schedule=EPOCH_SCHEDULE ): def schedule_learning_rate( iterations, schedule=EPOCH_SCHEDULE ):
return [int(iterations * d) for d in schedule] return [int(iterations * d) for d in schedule]
def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, validation_rate, resume_path, half_p, bnb, workers, source_model, voice ): def optimize_training_settings( **kwargs ):
name = f"{voice}-finetune" messages = []
dataset_path = f"./training/{voice}/train.txt" settings = {}
settings.update(kwargs)
dataset_path = f"./training/{settings['voice']}/train.txt"
with open(dataset_path, 'r', encoding="utf-8") as f: with open(dataset_path, 'r', encoding="utf-8") as f:
lines = len(f.readlines()) lines = len(f.readlines())
messages = [] if settings['batch_size'] > lines:
settings['batch_size'] = lines
messages.append(f"Batch size is larger than your dataset, clamping batch size to: {settings['batch_size']}")
if batch_size > lines: if settings['batch_size'] % lines != 0:
batch_size = lines nearest_slice = int(lines / settings['batch_size']) + 1
messages.append(f"Batch size is larger than your dataset, clamping batch size to: {batch_size}") settings['batch_size'] = int(lines / nearest_slice)
messages.append(f"Batch size not neatly divisible by dataset size, adjusting batch size to: {settings['batch_size']} ({nearest_slice} steps per epoch)")
if batch_size % lines != 0:
nearest_slice = int(lines / batch_size) + 1
batch_size = int(lines / nearest_slice)
messages.append(f"Batch size not neatly divisible by dataset size, adjusting batch size to: {batch_size} ({nearest_slice} steps per epoch)")
if gradient_accumulation_size == 0: if settings['gradient_accumulation_size'] == 0:
gradient_accumulation_size = 1 settings['gradient_accumulation_size'] = 1
if batch_size / gradient_accumulation_size < 2: if settings['batch_size'] / settings['gradient_accumulation_size'] < 2:
gradient_accumulation_size = int(batch_size / 2) settings['gradient_accumulation_size'] = int(settings['batch_size'] / 2)
if gradient_accumulation_size == 0: if settings['gradient_accumulation_size'] == 0:
gradient_accumulation_size = 1 settings['gradient_accumulation_size'] = 1
messages.append(f"Gradient accumulation size is too large for a given batch size, clamping gradient accumulation size to: {gradient_accumulation_size}") messages.append(f"Gradient accumulation size is too large for a given batch size, clamping gradient accumulation size to: {settings['gradient_accumulation_size']}")
elif batch_size % gradient_accumulation_size != 0: elif settings['batch_size'] % settings['gradient_accumulation_size'] != 0:
gradient_accumulation_size = int(batch_size / gradient_accumulation_size) settings['gradient_accumulation_size'] = int(settings['batch_size'] / settings['gradient_accumulation_size'])
if gradient_accumulation_size == 0: if settings['gradient_accumulation_size'] == 0:
gradient_accumulation_size = 1 settings['gradient_accumulation_size'] = 1
messages.append(f"Batch size is not evenly divisible by the gradient accumulation size, adjusting gradient accumulation size to: {gradient_accumulation_size}") messages.append(f"Batch size is not evenly divisible by the gradient accumulation size, adjusting gradient accumulation size to: {settings['gradient_accumulation_size']}")
iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size) iterations = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])
if epochs < print_rate: if settings['epochs'] < settings['print_rate']:
print_rate = epochs settings['print_rate'] = settings['epochs']
messages.append(f"Print rate is too small for the given iteration step, clamping print rate to: {print_rate}") messages.append(f"Print rate is too small for the given iteration step, clamping print rate to: {settings['print_rate']}")
if epochs < save_rate: if settings['epochs'] < settings['save_rate']:
save_rate = epochs settings['save_rate'] = settings['epochs']
messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {save_rate}") messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {settings['save_rate']}")
if epochs < validation_rate: if settings['epochs'] < settings['validation_rate']:
validation_rate = epochs settings['validation_rate'] = settings['epochs']
messages.append(f"Validation rate is too small for the given iteration step, clamping validation rate to: {validation_rate}") messages.append(f"Validation rate is too small for the given iteration step, clamping validation rate to: {settings['validation_rate']}")
if resume_path and not os.path.exists(resume_path): if settings['resume_state'] and not os.path.exists(settings['resume_state']):
resume_path = None settings['resume_state'] = None
messages.append("Resume path specified, but does not exist. Disabling...") messages.append("Resume path specified, but does not exist. Disabling...")
if bnb: if settings['bitsandbytes']:
messages.append("BitsAndBytes requested. Please note this is ! EXPERIMENTAL !") messages.append("BitsAndBytes requested. Please note this is ! EXPERIMENTAL !")
if half_p: if settings['half_p']:
if bnb: if settings['bitsandbytes']:
half_p = False settings['half_p'] = False
messages.append("Half Precision requested, but BitsAndBytes is also requested. Due to redundancies, disabling half precision...") messages.append("Half Precision requested, but BitsAndBytes is also requested. Due to redundancies, disabling half precision...")
else: else:
messages.append("Half Precision requested. Please note this is ! EXPERIMENTAL !") messages.append("Half Precision requested. Please note this is ! EXPERIMENTAL !")
if not os.path.exists(get_halfp_model_path()): if not os.path.exists(get_halfp_model_path()):
convert_to_halfp() convert_to_halfp()
messages.append(f"For {epochs} epochs with {lines} lines in batches of {batch_size}, iterating for {iterations} steps ({int(iterations / epochs)} steps per epoch)") messages.append(f"For {settings['epochs']} epochs with {lines} lines in batches of {settings['batch_size']}, iterating for {iterations} steps ({int(iterations / settings['epochs'])} steps per epoch)")
return ( return settings, messages
learning_rate,
text_ce_lr_weight,
learning_rate_schedule,
batch_size,
gradient_accumulation_size,
print_rate,
save_rate,
validation_rate,
resume_path,
messages
)
def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_scheme=None, learning_rate_schedule=None, batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, validation_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, validation_batch_size=None, output_name=None, resume_path=None, half_p=None, bnb=None, workers=None, source_model=None ): def save_training_settings( **kwargs ):
if not source_model: messages = []
source_model = f"./models/tortoise/autoregressive{'_half' if half_p else ''}.pth" settings = {}
settings.update(kwargs)
settings = { settings['dataset_path'] = f"./training/{settings['voice']}/train.txt"
"iterations": iterations if iterations else 500, settings['validation_path'] = f"./training/{settings['voice']}/validation.txt"
"batch_size": batch_size if batch_size else 64,
"learning_rate": learning_rate if learning_rate else 1e-5,
"gradient_accumulation_size": gradient_accumulation_size if gradient_accumulation_size else 4,
"print_rate": print_rate if print_rate else 1,
"save_rate": save_rate if save_rate else 50,
"name": name if name else "finetune",
"dataset_name": dataset_name if dataset_name else "finetune",
"dataset_path": dataset_path if dataset_path else "./training/finetune/train.txt",
"validation_name": validation_name if validation_name else "finetune",
"validation_path": validation_path if validation_path else "./training/finetune/train.txt",
'validation_rate': validation_rate if validation_rate else iterations,
"validation_batch_size": validation_batch_size if validation_batch_size else batch_size,
'validation_enabled': "true",
"text_ce_lr_weight": text_ce_lr_weight if text_ce_lr_weight else 0.01, with open(settings['dataset_path'], 'r', encoding="utf-8") as f:
lines = len(f.readlines())
'resume_state': f"resume_state: '{resume_path}'", if not settings['source_model'] or settings['source_model'] == "auto":
'pretrain_model_gpt': f"pretrain_model_gpt: '{source_model}'", settings['source_model'] = f"./models/tortoise/autoregressive{'_half' if settings['half_p'] else ''}.pth"
'float16': 'true' if half_p else 'false', if settings['half_p']:
'bitsandbytes': 'true' if bnb else 'false', if not os.path.exists(get_halfp_model_path()):
convert_to_halfp()
'workers': workers if workers else 2, settings['iterations'] = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])
} messages.append(f"For {settings['epochs']} epochs with {lines} lines, iterating for {settings['iterations']} steps")
settings['print_rate'] = int(settings['print_rate'] * settings['iterations'] / settings['epochs'])
settings['save_rate'] = int(settings['save_rate'] * settings['iterations'] / settings['epochs'])
settings['validation_rate'] = int(settings['validation_rate'] * settings['iterations'] / settings['epochs'])
settings['validation_batch_size'] = int(settings['batch_size'] / settings['gradient_accumulation_size'])
settings['iterations'] = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])
if settings['iterations'] % settings['save_rate'] != 0:
adjustment = int(settings['iterations'] / settings['save_rate']) * settings['save_rate']
messages.append(f"Iteration rate is not evenly divisible by save rate, adjusting: {settings['iterations']} => {adjustment}")
settings['iterations'] = adjustment
if not os.path.exists(settings['validation_path']):
settings['validation_enabled'] = False
messages.append("Validation not found, disabling validation...")
elif settings['validation_batch_size'] == 0:
settings['validation_enabled'] = False
messages.append("Validation batch size == 0, disabling validation...")
else:
settings['validation_enabled'] = True
with open(settings['validation_path'], 'r', encoding="utf-8") as f:
validation_lines = len(f.readlines())
if validation_lines < settings['validation_batch_size']:
settings['validation_batch_size'] = validation_lines
messages.append(f"Batch size exceeds validation dataset size, clamping validation batch size to {validation_lines}")
if settings['gpus'] > get_device_count():
settings['gpus'] = get_device_count()
LEARNING_RATE_SCHEMES = ["MultiStepLR", "CosineAnnealingLR_Restart"] LEARNING_RATE_SCHEMES = ["MultiStepLR", "CosineAnnealingLR_Restart"]
if learning_rate_scheme not in LEARNING_RATE_SCHEMES: if 'learning_rate_scheme' not in settings or settings['learning_rate_scheme'] not in LEARNING_RATE_SCHEMES:
learning_rate_scheme = LEARNING_RATE_SCHEMES[0] settings['learning_rate_scheme'] = LEARNING_RATE_SCHEMES[0]
learning_rate_schema = [f"default_lr_scheme: {learning_rate_scheme}"] learning_rate_schema = [f"default_lr_scheme: {settings['learning_rate_scheme']}"]
if learning_rate_scheme == "MultiStepLR": if settings['learning_rate_scheme'] == "MultiStepLR":
learning_rate_schema.append(f" gen_lr_steps: {learning_rate_schedule if learning_rate_schedule else EPOCH_SCHEDULE}") if not settings['learning_rate_schedule']:
settings['learning_rate_schedule'] = EPOCH_SCHEDULE
elif isinstance(settings['learning_rate_schedule'],str):
settings['learning_rate_schedule'] = json.loads(settings['learning_rate_schedule'])
settings['learning_rate_schedule'] = schedule_learning_rate( settings['iterations'] / settings['epochs'], settings['learning_rate_schedule'] )
learning_rate_schema.append(f" gen_lr_steps: {settings['learning_rate_schedule']}")
learning_rate_schema.append(f" lr_gamma: 0.5") learning_rate_schema.append(f" lr_gamma: 0.5")
elif learning_rate_scheme == "CosineAnnealingLR_Restart": elif settings['learning_rate_scheme'] == "CosineAnnealingLR_Restart":
learning_rate_schema.append(f" T_period: [120000, 120000, 120000]") learning_rate_schema.append(f" T_period: [120000, 120000, 120000]")
learning_rate_schema.append(f" warmup: 10000") learning_rate_schema.append(f" warmup: 10000")
learning_rate_schema.append(f" eta_min: .01") learning_rate_schema.append(f" eta_min: .01")
@ -1409,23 +1387,26 @@ def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weig
learning_rate_schema.append(f" restart_weights: [.5, .25]") learning_rate_schema.append(f" restart_weights: [.5, .25]")
settings['learning_rate_scheme'] = "\n".join(learning_rate_schema) settings['learning_rate_scheme'] = "\n".join(learning_rate_schema)
if resume_path: """
if resume_state:
settings['pretrain_model_gpt'] = f"# {settings['pretrain_model_gpt']}" settings['pretrain_model_gpt'] = f"# {settings['pretrain_model_gpt']}"
else: else:
settings['resume_state'] = f"# resume_state: './training/{name if name else 'finetune'}/training_state/#.state'" settings['resume_state'] = f"# resume_state: './training/{voice}/training_state/#.state'"
# also disable validation if it doesn't make sense to do it # also disable validation if it doesn't make sense to do it
if settings['dataset_path'] == settings['validation_path'] or not os.path.exists(settings['validation_path']): if settings['dataset_path'] == settings['validation_path'] or not os.path.exists(settings['validation_path']):
settings['validation_enabled'] = 'false' settings['validation_enabled'] = 'false'
"""
outjson = f'./training/{settings["voice"]}/train.json'
with open(outjson, 'w', encoding="utf-8") as f:
f.write(json.dumps(settings, indent='\t') )
if settings['resume_state']:
if half_p: settings['source_model'] = f"# pretrain_model_gpt: {settings['source_model']}"
if not os.path.exists(get_halfp_model_path()): settings['resume_state'] = f"resume_state: {settings['resume_state']}'"
convert_to_halfp() else:
settings['source_model'] = f"pretrain_model_gpt: {settings['source_model']}"
if not output_name: settings['resume_state'] = f"# resume_state: {settings['resume_state']}'"
output_name = f'{settings["name"]}.yaml'
with open(f'./models/.template.yaml', 'r', encoding="utf-8") as f: with open(f'./models/.template.yaml', 'r', encoding="utf-8") as f:
yaml = f.read() yaml = f.read()
@ -1436,11 +1417,13 @@ def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weig
continue continue
yaml = yaml.replace(f"${{{k}}}", str(settings[k])) yaml = yaml.replace(f"${{{k}}}", str(settings[k]))
outfile = f'./training/{output_name}' outyaml = f'./training/{settings["voice"]}/train.yaml'
with open(outfile, 'w', encoding="utf-8") as f: with open(outyaml, 'w', encoding="utf-8") as f:
f.write(yaml) f.write(yaml)
return f"Training settings saved to: {outfile}" messages.append(f"Saved training output to: {outyaml}")
return settings, messages
def import_voices(files, saveAs=None, progress=None): def import_voices(files, saveAs=None, progress=None):
global args global args
@ -1524,10 +1507,10 @@ def get_autoregressive_models(dir="./models/finetunes/", prefixed=False):
additionals = sorted([f'{dir}/{d}' for d in os.listdir(dir) if d[-4:] == ".pth" ]) additionals = sorted([f'{dir}/{d}' for d in os.listdir(dir) if d[-4:] == ".pth" ])
found = [] found = []
for training in os.listdir(f'./training/'): for training in os.listdir(f'./training/'):
if not os.path.isdir(f'./training/{training}/') or not os.path.isdir(f'./training/{training}/models/'): if not os.path.isdir(f'./training/{training}/') or not os.path.isdir(f'./training/{training}/finetunes/') or not os.path.isdir(f'./training/{training}/finetunes/models/'):
continue continue
models = sorted([ int(d[:-8]) for d in os.listdir(f'./training/{training}/models/') if d[-8:] == "_gpt.pth" ]) models = sorted([ int(d[:-8]) for d in os.listdir(f'./training/{training}/finetunes/models/') if d[-8:] == "_gpt.pth" ])
found = found + [ f'./training/{training}/models/{d}_gpt.pth' for d in models ] found = found + [ f'./training/{training}/finetunes/models/{d}_gpt.pth' for d in models ]
if len(found) > 0 or len(additionals) > 0: if len(found) > 0 or len(additionals) > 0:
base = ["auto"] + base base = ["auto"] + base
@ -1545,10 +1528,10 @@ def get_autoregressive_models(dir="./models/finetunes/", prefixed=False):
return res return res
def get_dataset_list(dir="./training/"): def get_dataset_list(dir="./training/"):
return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.txt" in os.listdir(os.path.join(dir, d)) ]) return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "train.txt" in os.listdir(os.path.join(dir, d)) ])
def get_training_list(dir="./training/"): def get_training_list(dir="./training/"):
return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.yaml" in os.listdir(os.path.join(dir, d)) ]) return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "train.yaml" in os.listdir(os.path.join(dir, d)) ])
def do_gc(): def do_gc():
gc.collect() gc.collect()
@ -1734,35 +1717,38 @@ def setup_args():
return args return args
def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, prune_nonfinal_outputs, device_override, sample_batch_size, concurrency_count, autocalculate_voice_chunk_duration_size, output_volume, autoregressive_model, vocoder_model, whisper_backend, whisper_model, training_default_halfp, training_default_bnb ): def update_args( **kwargs ):
global args global args
args.listen = listen settings = {}
args.share = share settings.update(kwargs)
args.check_for_updates = check_for_updates
args.models_from_local_only = models_from_local_only
args.low_vram = low_vram
args.force_cpu_for_conditioning_latents = force_cpu_for_conditioning_latents
args.defer_tts_load = defer_tts_load
args.prune_nonfinal_outputs = prune_nonfinal_outputs
args.device_override = device_override
args.sample_batch_size = sample_batch_size
args.embed_output_metadata = embed_output_metadata
args.latents_lean_and_mean = latents_lean_and_mean
args.voice_fixer = voice_fixer
args.voice_fixer_use_cuda = voice_fixer_use_cuda
args.concurrency_count = concurrency_count
args.output_sample_rate = 44000
args.autocalculate_voice_chunk_duration_size = autocalculate_voice_chunk_duration_size
args.output_volume = output_volume
args.autoregressive_model = autoregressive_model
args.vocoder_model = vocoder_model
args.whisper_backend = whisper_backend
args.whisper_model = whisper_model
args.training_default_halfp = training_default_halfp args.listen = settings['listen']
args.training_default_bnb = training_default_bnb args.share = settings['share']
args.check_for_updates = settings['check_for_updates']
args.models_from_local_only = settings['models_from_local_only']
args.low_vram = settings['low_vram']
args.force_cpu_for_conditioning_latents = settings['force_cpu_for_conditioning_latents']
args.defer_tts_load = settings['defer_tts_load']
args.prune_nonfinal_outputs = settings['prune_nonfinal_outputs']
args.device_override = settings['device_override']
args.sample_batch_size = settings['sample_batch_size']
args.embed_output_metadata = settings['embed_output_metadata']
args.latents_lean_and_mean = settings['latents_lean_and_mean']
args.voice_fixer = settings['voice_fixer']
args.voice_fixer_use_cuda = settings['voice_fixer_use_cuda']
args.concurrency_count = settings['concurrency_count']
args.output_sample_rate = 44000
args.autocalculate_voice_chunk_duration_size = settings['autocalculate_voice_chunk_duration_size']
args.output_volume = settings['output_volume']
args.autoregressive_model = settings['autoregressive_model']
args.vocoder_model = settings['vocoder_model']
args.whisper_backend = settings['whisper_backend']
args.whisper_model = settings['whisper_model']
args.training_default_halfp = settings['training_default_halfp']
args.training_default_bnb = settings['training_default_bnb']
save_args_settings() save_args_settings()
@ -1801,37 +1787,49 @@ def save_args_settings():
with open(f'./config/exec.json', 'w', encoding="utf-8") as f: with open(f'./config/exec.json', 'w', encoding="utf-8") as f:
f.write(json.dumps(settings, indent='\t') ) f.write(json.dumps(settings, indent='\t') )
# super kludgy )`;
def set_generate_settings_arg_order(args):
global GENERATE_SETTINGS_ARGS
GENERATE_SETTINGS_ARGS = args
def import_generate_settings(file="./config/generate.json"): def import_generate_settings(file="./config/generate.json"):
global GENERATE_SETTINGS_ARGS
defaults = {
'text': None,
'delimiter': None,
'emotion': None,
'prompt': None,
'voice': None,
'mic_audio': None,
'voice_latents_chunks': None,
'candidates': None,
'seed': None,
'num_autoregressive_samples': 16,
'diffusion_iterations': 30,
'temperature': 0.8,
'diffusion_sampler': "DDIM",
'breathing_room': 8 ,
'cvvp_weight': 0.0,
'top_p': 0.8,
'diffusion_temperature': 1.0,
'length_penalty': 1.0,
'repetition_penalty': 2.0,
'cond_free_k': 2.0,
'experimentals': None,
}
settings, _ = read_generate_settings(file, read_latents=False) settings, _ = read_generate_settings(file, read_latents=False)
if settings is None: res = []
return None if GENERATE_SETTINGS_ARGS is not None:
for k in GENERATE_SETTINGS_ARGS:
res.append(defaults[k] if not settings or settings[k] is None else settings[k])
else:
for k in defaults:
res.append(defaults[k] if not settings or settings[k] is None else settings[k])
return ( return tuple(res)
None if 'text' not in settings else settings['text'],
None if 'delimiter' not in settings else settings['delimiter'],
None if 'emotion' not in settings else settings['emotion'],
None if 'prompt' not in settings else settings['prompt'],
None if 'voice' not in settings else settings['voice'],
None,
None,
None if 'seed' not in settings else settings['seed'],
None if 'candidates' not in settings else settings['candidates'],
None if 'num_autoregressive_samples' not in settings else settings['num_autoregressive_samples'],
None if 'diffusion_iterations' not in settings else settings['diffusion_iterations'],
0.8 if 'temperature' not in settings else settings['temperature'],
"DDIM" if 'diffusion_sampler' not in settings else settings['diffusion_sampler'],
8 if 'breathing_room' not in settings else settings['breathing_room'],
0.0 if 'cvvp_weight' not in settings else settings['cvvp_weight'],
0.8 if 'top_p' not in settings else settings['top_p'],
1.0 if 'diffusion_temperature' not in settings else settings['diffusion_temperature'],
1.0 if 'length_penalty' not in settings else settings['length_penalty'],
2.0 if 'repetition_penalty' not in settings else settings['repetition_penalty'],
2.0 if 'cond_free_k' not in settings else settings['cond_free_k'],
None if 'experimentals' not in settings else settings['experimentals'],
)
def reset_generation_settings(): def reset_generation_settings():
@ -1955,10 +1953,10 @@ def deduce_autoregressive_model(voice=None):
voice = get_current_voice() voice = get_current_voice()
if voice: if voice:
dir = f'./training/{voice}-finetune/models/' if os.path.exists(f'./models/finetunes/{voice}.pth'):
if os.path.exists(f'./training/finetunes/{voice}.pth'): return f'./models/finetunes/{voice}.pth'
return f'./training/finetunes/{voice}.pth'
dir = f'./training/{voice}/finetune/models/'
if os.path.isdir(dir): if os.path.isdir(dir):
counts = sorted([ int(d[:-8]) for d in os.listdir(dir) if d[-8:] == "_gpt.pth" ]) counts = sorted([ int(d[:-8]) for d in os.listdir(dir) if d[-8:] == "_gpt.pth" ])
names = [ f'{dir}/{d}_gpt.pth' for d in counts ] names = [ f'{dir}/{d}_gpt.pth' for d in counts ]

View File

@ -4,6 +4,7 @@ import time
import json import json
import base64 import base64
import re import re
import inspect
import urllib.request import urllib.request
import torch import torch
@ -22,7 +23,38 @@ from utils import *
args = setup_args() args = setup_args()
def run_generation( GENERATE_SETTINGS = {}
TRANSCRIBE_SETTINGS = {}
EXEC_SETTINGS = {}
TRAINING_SETTINGS = {}
PRESETS = {
'Ultra Fast': {'num_autoregressive_samples': 16, 'diffusion_iterations': 30, 'cond_free': False},
'Fast': {'num_autoregressive_samples': 96, 'diffusion_iterations': 80},
'Standard': {'num_autoregressive_samples': 256, 'diffusion_iterations': 200},
'High Quality': {'num_autoregressive_samples': 256, 'diffusion_iterations': 400},
}
HISTORY_HEADERS = {
"Name": "",
"Samples": "num_autoregressive_samples",
"Iterations": "diffusion_iterations",
"Temp.": "temperature",
"Sampler": "diffusion_sampler",
"CVVP": "cvvp_weight",
"Top P": "top_p",
"Diff. Temp.": "diffusion_temperature",
"Len Pen": "length_penalty",
"Rep Pen": "repetition_penalty",
"Cond-Free K": "cond_free_k",
"Time": "time",
"Datetime": "datetime",
"Model": "model",
"Model Hash": "model_hash",
}
# can't use *args OR **kwargs if I want to retain the ability to use progress
def generate_proxy(
text, text,
delimiter, delimiter,
emotion, emotion,
@ -30,8 +62,8 @@ def run_generation(
voice, voice,
mic_audio, mic_audio,
voice_latents_chunks, voice_latents_chunks,
seed,
candidates, candidates,
seed,
num_autoregressive_samples, num_autoregressive_samples,
diffusion_iterations, diffusion_iterations,
temperature, temperature,
@ -43,47 +75,20 @@ def run_generation(
length_penalty, length_penalty,
repetition_penalty, repetition_penalty,
cond_free_k, cond_free_k,
experimental_checkboxes, experimentals,
progress=gr.Progress(track_tqdm=True) progress=gr.Progress(track_tqdm=True)
): ):
if not text: kwargs = locals()
raise gr.Error("Please provide text.")
if not voice:
raise gr.Error("Please provide a voice.")
try: try:
sample, outputs, stats = generate( sample, outputs, stats = generate(**kwargs)
text=text,
delimiter=delimiter,
emotion=emotion,
prompt=prompt,
voice=voice,
mic_audio=mic_audio,
voice_latents_chunks=voice_latents_chunks,
seed=seed,
candidates=candidates,
num_autoregressive_samples=num_autoregressive_samples,
diffusion_iterations=diffusion_iterations,
temperature=temperature,
diffusion_sampler=diffusion_sampler,
breathing_room=breathing_room,
cvvp_weight=cvvp_weight,
top_p=top_p,
diffusion_temperature=diffusion_temperature,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
cond_free_k=cond_free_k,
experimental_checkboxes=experimental_checkboxes,
progress=progress
)
except Exception as e: except Exception as e:
message = str(e) message = str(e)
if message == "Kill signal detected": if message == "Kill signal detected":
unload_tts() unload_tts()
raise gr.Error(message) raise e
return ( return (
outputs[0], outputs[0],
gr.update(value=sample, visible=sample is not None), gr.update(value=sample, visible=sample is not None),
@ -91,14 +96,8 @@ def run_generation(
gr.update(value=stats, visible=True), gr.update(value=stats, visible=True),
) )
def update_presets(value): def update_presets(value):
PRESETS = {
'Ultra Fast': {'num_autoregressive_samples': 16, 'diffusion_iterations': 30, 'cond_free': False},
'Fast': {'num_autoregressive_samples': 96, 'diffusion_iterations': 80},
'Standard': {'num_autoregressive_samples': 256, 'diffusion_iterations': 200},
'High Quality': {'num_autoregressive_samples': 256, 'diffusion_iterations': 400},
}
if value in PRESETS: if value in PRESETS:
preset = PRESETS[value] preset = PRESETS[value]
return (gr.update(value=preset['num_autoregressive_samples']), gr.update(value=preset['diffusion_iterations'])) return (gr.update(value=preset['num_autoregressive_samples']), gr.update(value=preset['diffusion_iterations']))
@ -117,24 +116,6 @@ def get_training_configs():
def update_training_configs(): def update_training_configs():
return gr.update(choices=get_training_list()) return gr.update(choices=get_training_list())
history_headers = {
"Name": "",
"Samples": "num_autoregressive_samples",
"Iterations": "diffusion_iterations",
"Temp.": "temperature",
"Sampler": "diffusion_sampler",
"CVVP": "cvvp_weight",
"Top P": "top_p",
"Diff. Temp.": "diffusion_temperature",
"Len Pen": "length_penalty",
"Rep Pen": "repetition_penalty",
"Cond-Free K": "cond_free_k",
"Time": "time",
"Datetime": "datetime",
"Model": "model",
"Model Hash": "model_hash",
}
def history_view_results( voice ): def history_view_results( voice ):
results = [] results = []
files = [] files = []
@ -148,10 +129,10 @@ def history_view_results( voice ):
continue continue
values = [] values = []
for k in history_headers: for k in HISTORY_HEADERS:
v = file v = file
if k != "Name": if k != "Name":
v = metadata[history_headers[k]] if history_headers[k] in metadata else '?' v = metadata[HISTORY_HEADERS[k]] if HISTORY_HEADERS[k] in metadata else '?'
values.append(v) values.append(v)
@ -193,181 +174,55 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
def prepare_dataset_proxy( voice, language, skip_existings, progress=gr.Progress(track_tqdm=True) ): def prepare_dataset_proxy( voice, language, skip_existings, progress=gr.Progress(track_tqdm=True) ):
return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, skip_existings=skip_existings, progress=progress ) return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, skip_existings=skip_existings, progress=progress )
def optimize_training_settings_proxy( *args, **kwargs ): def update_args_proxy( *args ):
tup = optimize_training_settings(*args, **kwargs) kwargs = {}
keys = list(EXEC_SETTINGS.keys())
for i in range(len(args)):
k = keys[i]
v = args[i]
kwargs[k] = v
return ( update_args(**kwargs)
gr.update(value=tup[0]), def optimize_training_settings_proxy( *args ):
gr.update(value=tup[1]), kwargs = {}
gr.update(value=tup[2]), keys = list(TRAINING_SETTINGS.keys())
gr.update(value=tup[3]), for i in range(len(args)):
gr.update(value=tup[4]), k = keys[i]
gr.update(value=tup[5]), v = args[i]
gr.update(value=tup[6]), kwargs[k] = v
gr.update(value=tup[7]),
gr.update(value=tup[8]), settings, messages = optimize_training_settings(**kwargs)
"\n".join(tup[9]) output = list(settings.values())
) return output[:-1] + ["\n".join(messages)]
def import_training_settings_proxy( voice ): def import_training_settings_proxy( voice ):
indir = f'./training/{voice}/'
outdir = f'./training/{voice}-finetune/'
in_config_path = f"{indir}/train.yaml"
out_config_path = None
out_configs = []
if os.path.isdir(outdir):
out_configs = sorted([d[:-5] for d in os.listdir(outdir) if d[-5:] == ".yaml" ])
if len(out_configs) > 0:
out_config_path = f'{outdir}/{out_configs[-1]}.yaml'
config_path = out_config_path if out_config_path else in_config_path
messages = [] messages = []
with open(config_path, 'r') as file: injson = f'./training/{voice}/train.json'
config = yaml.safe_load(file) statedir = f'./training/{voice}/training_state/'
messages.append(f"Importing from: {config_path}")
dataset_path = f"./training/{voice}/train.txt"
with open(dataset_path, 'r', encoding="utf-8") as f:
lines = len(f.readlines())
messages.append(f"Basing epoch size to {lines} lines")
batch_size = config['datasets']['train']['batch_size']
gradient_accumulation_size = config['train']['mega_batch_factor']
iterations = config['train']['niter']
steps_per_iteration = int(lines / batch_size)
epochs = int(iterations / steps_per_iteration)
learning_rate = config['steps']['gpt_train']['optimizer_params']['lr']
text_ce_lr_weight = config['steps']['gpt_train']['losses']['text_ce']['weight']
learning_rate_schedule = [ int(x / steps_per_iteration) for x in config['train']['gen_lr_steps'] ]
print_rate = int(config['logger']['print_freq'] / steps_per_iteration)
save_rate = int(config['logger']['save_checkpoint_freq'] / steps_per_iteration)
validation_rate = int(config['train']['val_freq'] / steps_per_iteration)
half_p = config['fp16']
bnb = True
statedir = f'{outdir}/training_state/'
resumes = []
resume_path = None
source_model = get_halfp_model_path() if half_p else get_model_path('autoregressive.pth')
if "pretrain_model_gpt" in config['path']:
source_model = config['path']['pretrain_model_gpt']
elif "resume_state" in config['path']:
resume_path = config['path']['resume_state']
with open(injson, 'r', encoding="utf-8") as f:
settings = json.loads(f.read())
if os.path.isdir(statedir): if os.path.isdir(statedir):
resumes = sorted([int(d[:-6]) for d in os.listdir(statedir) if d[-6:] == ".state" ]) resumes = sorted([int(d[:-6]) for d in os.listdir(statedir) if d[-6:] == ".state" ])
if len(resumes) > 0: if len(resumes) > 0:
resume_path = f'{statedir}/{resumes[-1]}.state' settings['resume_state'] = f'{statedir}/{resumes[-1]}.state'
messages.append(f"Latest resume found: {resume_path}") messages.append(f"Found most recent training state: {settings['resume_state']}")
output = list(settings.values())
messages.append(f"Imported training settings: {injson}")
return output[:-1] + ["\n".join(messages)]
def save_training_settings_proxy( *args ):
kwargs = {}
keys = list(TRAINING_SETTINGS.keys())
for i in range(len(args)):
k = keys[i]
v = args[i]
kwargs[k] = v
if "ext" in config and "bitsandbytes" in config["ext"]: settings, messages = save_training_settings(**kwargs)
bnb = config["ext"]["bitsandbytes"]
workers = config['datasets']['train']['n_workers']
messages = "\n".join(messages)
return (
epochs,
learning_rate,
text_ce_lr_weight,
learning_rate_schedule,
batch_size,
gradient_accumulation_size,
print_rate,
save_rate,
validation_rate,
resume_path,
half_p,
bnb,
workers,
source_model,
messages
)
def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, validation_rate, resume_path, half_p, bnb, workers, source_model, voice ):
name = f"{voice}-finetune"
dataset_name = f"{voice}-train"
dataset_path = f"./training/{voice}/train.txt"
validation_name = f"{voice}-val"
validation_path = f"./training/{voice}/validation.txt"
with open(dataset_path, 'r', encoding="utf-8") as f:
lines = len(f.readlines())
messages = []
iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size)
messages.append(f"For {epochs} epochs with {lines} lines, iterating for {iterations} steps")
print_rate = int(print_rate * iterations / epochs)
save_rate = int(save_rate * iterations / epochs)
validation_rate = int(validation_rate * iterations / epochs)
validation_batch_size = int(batch_size / gradient_accumulation_size)
if iterations % save_rate != 0:
adjustment = int(iterations / save_rate) * save_rate
messages.append(f"Iteration rate is not evenly divisible by save rate, adjusting: {iterations} => {adjustment}")
iterations = adjustment
if not os.path.exists(validation_path):
validation_rate = iterations
validation_path = dataset_path
messages.append("Validation not found, disabling validation...")
else:
with open(validation_path, 'r', encoding="utf-8") as f:
validation_lines = len(f.readlines())
if validation_lines < validation_batch_size:
validation_batch_size = validation_lines
messages.append(f"Batch size exceeds validation dataset size, clamping validation batch size to {validation_lines}")
if not learning_rate_schedule:
learning_rate_schedule = EPOCH_SCHEDULE
elif isinstance(learning_rate_schedule,str):
learning_rate_schedule = json.loads(learning_rate_schedule)
learning_rate_schedule = schedule_learning_rate( iterations / epochs, learning_rate_schedule )
messages.append(save_training_settings(
iterations=iterations,
batch_size=batch_size,
learning_rate=learning_rate,
text_ce_lr_weight=text_ce_lr_weight,
learning_rate_schedule=learning_rate_schedule,
gradient_accumulation_size=gradient_accumulation_size,
print_rate=print_rate,
save_rate=save_rate,
validation_rate=validation_rate,
name=name,
dataset_name=dataset_name,
dataset_path=dataset_path,
validation_name=validation_name,
validation_path=validation_path,
validation_batch_size=validation_batch_size,
output_name=f"{voice}/train.yaml",
resume_path=resume_path,
half_p=half_p,
bnb=bnb,
workers=workers,
source_model=source_model,
))
return "\n".join(messages) return "\n".join(messages)
def update_voices(): def update_voices():
@ -406,60 +261,68 @@ def setup_gradio():
autoregressive_models = get_autoregressive_models() autoregressive_models = get_autoregressive_models()
dataset_list = get_dataset_list() dataset_list = get_dataset_list()
GENERATE_SETTINGS_ARGS = list(inspect.signature(generate_proxy).parameters.keys())[:-1]
for i in range(len(GENERATE_SETTINGS_ARGS)):
arg = GENERATE_SETTINGS_ARGS[i]
GENERATE_SETTINGS[arg] = None
set_generate_settings_arg_order(GENERATE_SETTINGS_ARGS)
with gr.Blocks() as ui: with gr.Blocks() as ui:
with gr.Tab("Generate"): with gr.Tab("Generate"):
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
text = gr.Textbox(lines=4, label="Input Prompt") GENERATE_SETTINGS["text"] = gr.Textbox(lines=4, label="Input Prompt")
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
delimiter = gr.Textbox(lines=1, label="Line Delimiter", placeholder="\\n") GENERATE_SETTINGS["delimiter"] = gr.Textbox(lines=1, label="Line Delimiter", placeholder="\\n")
emotion = gr.Radio( ["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom", "None"], value="None", label="Emotion", type="value", interactive=True ) GENERATE_SETTINGS["emotion"] = gr.Radio( ["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom", "None"], value="None", label="Emotion", type="value", interactive=True )
prompt = gr.Textbox(lines=1, label="Custom Emotion") GENERATE_SETTINGS["prompt"] = gr.Textbox(lines=1, label="Custom Emotion", visible=False)
voice = gr.Dropdown(choices=voice_list_with_defaults, label="Voice", type="value", value=voice_list_with_defaults[0]) # it'd be very cash money if gradio was able to default to the first value in the list without this shit GENERATE_SETTINGS["voice"] = gr.Dropdown(choices=voice_list_with_defaults, label="Voice", type="value", value=voice_list_with_defaults[0]) # it'd be very cash money if gradio was able to default to the first value in the list without this shit
mic_audio = gr.Audio( label="Microphone Source", source="microphone", type="filepath", visible=False ) GENERATE_SETTINGS["mic_audio"] = gr.Audio( label="Microphone Source", source="microphone", type="filepath", visible=False )
voice_latents_chunks = gr.Number(label="Voice Chunks", precision=0, value=0) GENERATE_SETTINGS["voice_latents_chunks"] = gr.Number(label="Voice Chunks", precision=0, value=0)
with gr.Row(): with gr.Row():
refresh_voices = gr.Button(value="Refresh Voice List") refresh_voices = gr.Button(value="Refresh Voice List")
recompute_voice_latents = gr.Button(value="(Re)Compute Voice Latents") recompute_voice_latents = gr.Button(value="(Re)Compute Voice Latents")
voice.change( GENERATE_SETTINGS["voice"].change(
fn=update_baseline_for_latents_chunks, fn=update_baseline_for_latents_chunks,
inputs=voice, inputs=GENERATE_SETTINGS["voice"],
outputs=voice_latents_chunks outputs=GENERATE_SETTINGS["voice_latents_chunks"]
) )
voice.change( GENERATE_SETTINGS["voice"].change(
fn=lambda value: gr.update(visible=value == "microphone"), fn=lambda value: gr.update(visible=value == "microphone"),
inputs=voice, inputs=GENERATE_SETTINGS["voice"],
outputs=mic_audio, outputs=GENERATE_SETTINGS["mic_audio"],
) )
with gr.Column(): with gr.Column():
candidates = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates") GENERATE_SETTINGS["candidates"] = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates")
seed = gr.Number(value=0, precision=0, label="Seed") GENERATE_SETTINGS["seed"] = gr.Number(value=0, precision=0, label="Seed")
preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value" ) preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value" )
num_autoregressive_samples = gr.Slider(value=128, minimum=2, maximum=512, step=1, label="Samples")
diffusion_iterations = gr.Slider(value=128, minimum=0, maximum=512, step=1, label="Iterations")
temperature = gr.Slider(value=0.2, minimum=0, maximum=1, step=0.1, label="Temperature") GENERATE_SETTINGS["num_autoregressive_samples"] = gr.Slider(value=128, minimum=2, maximum=512, step=1, label="Samples")
GENERATE_SETTINGS["diffusion_iterations"] = gr.Slider(value=128, minimum=0, maximum=512, step=1, label="Iterations")
GENERATE_SETTINGS["temperature"] = gr.Slider(value=0.2, minimum=0, maximum=1, step=0.1, label="Temperature")
show_experimental_settings = gr.Checkbox(label="Show Experimental Settings") show_experimental_settings = gr.Checkbox(label="Show Experimental Settings")
reset_generation_settings_button = gr.Button(value="Reset to Default") reset_generation_settings_button = gr.Button(value="Reset to Default")
with gr.Column(visible=False) as col: with gr.Column(visible=False) as col:
experimental_column = col experimental_column = col
experimental_checkboxes = gr.CheckboxGroup(["Half Precision", "Conditioning-Free"], value=["Conditioning-Free"], label="Experimental Flags") GENERATE_SETTINGS["experimentals"] = gr.CheckboxGroup(["Half Precision", "Conditioning-Free"], value=["Conditioning-Free"], label="Experimental Flags")
breathing_room = gr.Slider(value=8, minimum=1, maximum=32, step=1, label="Pause Size") GENERATE_SETTINGS["breathing_room"] = gr.Slider(value=8, minimum=1, maximum=32, step=1, label="Pause Size")
diffusion_sampler = gr.Radio( GENERATE_SETTINGS["diffusion_sampler"] = gr.Radio(
["P", "DDIM"], # + ["K_Euler_A", "DPM++2M"], ["P", "DDIM"], # + ["K_Euler_A", "DPM++2M"],
value="DDIM", label="Diffusion Samplers", type="value" value="DDIM", label="Diffusion Samplers", type="value"
) )
cvvp_weight = gr.Slider(value=0, minimum=0, maximum=1, label="CVVP Weight") GENERATE_SETTINGS["cvvp_weight"] = gr.Slider(value=0, minimum=0, maximum=1, label="CVVP Weight")
top_p = gr.Slider(value=0.8, minimum=0, maximum=1, label="Top P") GENERATE_SETTINGS["top_p"] = gr.Slider(value=0.8, minimum=0, maximum=1, label="Top P")
diffusion_temperature = gr.Slider(value=1.0, minimum=0, maximum=1, label="Diffusion Temperature") GENERATE_SETTINGS["diffusion_temperature"] = gr.Slider(value=1.0, minimum=0, maximum=1, label="Diffusion Temperature")
length_penalty = gr.Slider(value=1.0, minimum=0, maximum=8, label="Length Penalty") GENERATE_SETTINGS["length_penalty"] = gr.Slider(value=1.0, minimum=0, maximum=8, label="Length Penalty")
repetition_penalty = gr.Slider(value=2.0, minimum=0, maximum=8, label="Repetition Penalty") GENERATE_SETTINGS["repetition_penalty"] = gr.Slider(value=2.0, minimum=0, maximum=8, label="Repetition Penalty")
cond_free_k = gr.Slider(value=2.0, minimum=0, maximum=4, label="Conditioning-Free K") GENERATE_SETTINGS["cond_free_k"] = gr.Slider(value=2.0, minimum=0, maximum=4, label="Conditioning-Free K")
with gr.Column(): with gr.Column():
with gr.Row(): with gr.Row():
submit = gr.Button(value="Generate") submit = gr.Button(value="Generate")
@ -483,7 +346,7 @@ def setup_gradio():
with gr.Tab("History"): with gr.Tab("History"):
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
history_info = gr.Dataframe(label="Results", headers=list(history_headers.keys())) history_info = gr.Dataframe(label="Results", headers=list(HISTORY_HEADERS.keys()))
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
history_voices = gr.Dropdown(choices=result_voices, label="Voice", type="value", value=result_voices[0] if len(result_voices) > 0 else "") history_voices = gr.Dropdown(choices=result_voices, label="Voice", type="value", value=result_voices[0] if len(result_voices) > 0 else "")
@ -521,51 +384,40 @@ def setup_gradio():
with gr.Tab("Generate Configuration"): with gr.Tab("Generate Configuration"):
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
training_settings = [ TRAINING_SETTINGS["epochs"] = gr.Number(label="Epochs", value=500, precision=0)
gr.Number(label="Epochs", value=500, precision=0),
]
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
training_settings = training_settings + [ TRAINING_SETTINGS["learning_rate"] = gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6)
gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6), TRAINING_SETTINGS["text_ce_lr_weight"] = gr.Slider(label="Text_CE LR Ratio", value=0.01, minimum=0, maximum=1)
gr.Slider(label="Text_CE LR Ratio", value=0.01, minimum=0, maximum=1),
] TRAINING_SETTINGS["learning_rate_schedule"] = gr.Textbox(label="Learning Rate Schedule", placeholder=str(EPOCH_SCHEDULE))
training_settings = training_settings + [
gr.Textbox(label="Learning Rate Schedule", placeholder=str(EPOCH_SCHEDULE)),
]
with gr.Row(): with gr.Row():
training_settings = training_settings + [ TRAINING_SETTINGS["batch_size"] = gr.Number(label="Batch Size", value=128, precision=0)
gr.Number(label="Batch Size", value=128, precision=0), TRAINING_SETTINGS["gradient_accumulation_size"] = gr.Number(label="Gradient Accumulation Size", value=4, precision=0)
gr.Number(label="Gradient Accumulation Size", value=4, precision=0),
]
with gr.Row(): with gr.Row():
training_settings = training_settings + [ TRAINING_SETTINGS["print_rate"] = gr.Number(label="Print Frequency (in epochs)", value=5, precision=0)
gr.Number(label="Print Frequency (in epochs)", value=5, precision=0), TRAINING_SETTINGS["save_rate"] = gr.Number(label="Save Frequency (in epochs)", value=5, precision=0)
gr.Number(label="Save Frequency (in epochs)", value=5, precision=0), TRAINING_SETTINGS["validation_rate"] = gr.Number(label="Validation Frequency (in epochs)", value=5, precision=0)
gr.Number(label="Validation Frequency (in epochs)", value=5, precision=0),
]
training_settings = training_settings + [
gr.Textbox(label="Resume State Path", placeholder="./training/${voice}-finetune/training_state/${last_state}.state"),
]
with gr.Row(): with gr.Row():
training_halfp = gr.Checkbox(label="Half Precision", value=args.training_default_halfp) TRAINING_SETTINGS["half_p"] = gr.Checkbox(label="Half Precision", value=args.training_default_halfp)
training_bnb = gr.Checkbox(label="BitsAndBytes", value=args.training_default_bnb) TRAINING_SETTINGS["bitsandbytes"] = gr.Checkbox(label="BitsAndBytes", value=args.training_default_bnb)
training_workers = gr.Number(label="Worker Processes", value=2, precision=0) TRAINING_SETTINGS["workers"] = gr.Number(label="Worker Processes", value=2, precision=0)
TRAINING_SETTINGS["gpus"] = gr.Number(label="GPUs", value=get_device_count(), precision=0)
source_model = gr.Dropdown( choices=autoregressive_models, label="Source Model", type="value", value=autoregressive_models[0] ) TRAINING_SETTINGS["source_model"] = gr.Dropdown( choices=autoregressive_models, label="Source Model", type="value", value=autoregressive_models[0] )
dataset_list_dropdown = gr.Dropdown( choices=dataset_list, label="Dataset", type="value", value=dataset_list[0] if len(dataset_list) else "" ) TRAINING_SETTINGS["resume_state"] = gr.Textbox(label="Resume State Path", placeholder="./training/${voice}/training_state/${last_state}.state")
training_settings = training_settings + [ training_halfp, training_bnb, training_workers, source_model, dataset_list_dropdown ]
TRAINING_SETTINGS["voice"] = gr.Dropdown( choices=dataset_list, label="Dataset", type="value", value=dataset_list[0] if len(dataset_list) else "" )
with gr.Row(): with gr.Row():
refresh_dataset_list = gr.Button(value="Refresh Dataset List") training_refresh_dataset = gr.Button(value="Refresh Dataset List")
import_dataset_button = gr.Button(value="Reuse/Import Dataset") training_import_settings = gr.Button(value="Reuse/Import Dataset")
with gr.Column(): with gr.Column():
save_yaml_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8) training_configuration_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
with gr.Row(): with gr.Row():
optimize_yaml_button = gr.Button(value="Validate Training Configuration") training_optimize_configuration = gr.Button(value="Validate Training Configuration")
save_yaml_button = gr.Button(value="Save Training Configuration") training_save_configuration = gr.Button(value="Save Training Configuration")
with gr.Tab("Run Training"): with gr.Tab("Run Training"):
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
@ -588,9 +440,7 @@ def setup_gradio():
training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8) training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
verbose_training = gr.Checkbox(label="Verbose Console Output", value=True) verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
with gr.Row(): training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
training_gpu_count = gr.Number(label="GPUs", value=get_device_count())
with gr.Row(): with gr.Row():
start_training_button = gr.Button(value="Train") start_training_button = gr.Button(value="Train")
stop_training_button = gr.Button(value="Stop") stop_training_button = gr.Button(value="Stop")
@ -599,43 +449,40 @@ def setup_gradio():
with gr.Row(): with gr.Row():
exec_inputs = [] exec_inputs = []
with gr.Column(): with gr.Column():
exec_inputs = exec_inputs + [ EXEC_SETTINGS['listen'] = gr.Textbox(label="Listen", value=args.listen, placeholder="127.0.0.1:7860/")
gr.Textbox(label="Listen", value=args.listen, placeholder="127.0.0.1:7860/"), EXEC_SETTINGS['share'] = gr.Checkbox(label="Public Share Gradio", value=args.share)
gr.Checkbox(label="Public Share Gradio", value=args.share), EXEC_SETTINGS['check_for_updates'] = gr.Checkbox(label="Check For Updates", value=args.check_for_updates)
gr.Checkbox(label="Check For Updates", value=args.check_for_updates), EXEC_SETTINGS['models_from_local_only'] = gr.Checkbox(label="Only Load Models Locally", value=args.models_from_local_only)
gr.Checkbox(label="Only Load Models Locally", value=args.models_from_local_only), EXEC_SETTINGS['low_vram'] = gr.Checkbox(label="Low VRAM", value=args.low_vram)
gr.Checkbox(label="Low VRAM", value=args.low_vram), EXEC_SETTINGS['embed_output_metadata'] = gr.Checkbox(label="Embed Output Metadata", value=args.embed_output_metadata)
gr.Checkbox(label="Embed Output Metadata", value=args.embed_output_metadata), EXEC_SETTINGS['latents_lean_and_mean'] = gr.Checkbox(label="Slimmer Computed Latents", value=args.latents_lean_and_mean)
gr.Checkbox(label="Slimmer Computed Latents", value=args.latents_lean_and_mean), EXEC_SETTINGS['voice_fixer'] = gr.Checkbox(label="Use Voice Fixer on Generated Output", value=args.voice_fixer)
gr.Checkbox(label="Use Voice Fixer on Generated Output", value=args.voice_fixer), EXEC_SETTINGS['voice_fixer_use_cuda'] = gr.Checkbox(label="Use CUDA for Voice Fixer", value=args.voice_fixer_use_cuda)
gr.Checkbox(label="Use CUDA for Voice Fixer", value=args.voice_fixer_use_cuda), EXEC_SETTINGS['force_cpu_for_conditioning_latents'] = gr.Checkbox(label="Force CPU for Conditioning Latents", value=args.force_cpu_for_conditioning_latents)
gr.Checkbox(label="Force CPU for Conditioning Latents", value=args.force_cpu_for_conditioning_latents), EXEC_SETTINGS['defer_tts_load'] = gr.Checkbox(label="Do Not Load TTS On Startup", value=args.defer_tts_load)
gr.Checkbox(label="Do Not Load TTS On Startup", value=args.defer_tts_load), EXEC_SETTINGS['prune_nonfinal_outputs'] = gr.Checkbox(label="Delete Non-Final Output", value=args.prune_nonfinal_outputs)
gr.Checkbox(label="Delete Non-Final Output", value=args.prune_nonfinal_outputs), EXEC_SETTINGS['device_override'] = gr.Textbox(label="Device Override", value=args.device_override)
gr.Textbox(label="Device Override", value=args.device_override),
]
with gr.Column(): with gr.Column():
exec_inputs = exec_inputs + [ EXEC_SETTINGS['sample_batch_size'] = gr.Number(label="Sample Batch Size", precision=0, value=args.sample_batch_size)
gr.Number(label="Sample Batch Size", precision=0, value=args.sample_batch_size), EXEC_SETTINGS['concurrency_count'] = gr.Number(label="Gradio Concurrency Count", precision=0, value=args.concurrency_count)
gr.Number(label="Gradio Concurrency Count", precision=0, value=args.concurrency_count), EXEC_SETTINGS['autocalculate_voice_chunk_duration_size'] = gr.Number(label="Auto-Calculate Voice Chunk Duration (in seconds)", precision=0, value=args.autocalculate_voice_chunk_duration_size)
gr.Number(label="Auto-Calculate Voice Chunk Duration (in seconds)", precision=0, value=args.autocalculate_voice_chunk_duration_size), EXEC_SETTINGS['output_volume'] = gr.Slider(label="Output Volume", minimum=0, maximum=2, value=args.output_volume)
gr.Slider(label="Output Volume", minimum=0, maximum=2, value=args.output_volume),
]
autoregressive_model_dropdown = gr.Dropdown(choices=autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else autoregressive_models[0]) EXEC_SETTINGS['autoregressive_model'] = gr.Dropdown(choices=autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else autoregressive_models[0])
vocoder_models = gr.Dropdown(VOCODERS, label="Vocoder", value=args.vocoder_model if args.vocoder_model else VOCODERS[-1]) EXEC_SETTINGS['vocoder_model'] = gr.Dropdown(VOCODERS, label="Vocoder", value=args.vocoder_model if args.vocoder_model else VOCODERS[-1])
whisper_backend = gr.Dropdown(WHISPER_BACKENDS, label="Whisper Backends", value=args.whisper_backend) EXEC_SETTINGS['whisper_backend'] = gr.Dropdown(WHISPER_BACKENDS, label="Whisper Backends", value=args.whisper_backend)
whisper_model_dropdown = gr.Dropdown(WHISPER_MODELS, label="Whisper Model", value=args.whisper_model) EXEC_SETTINGS['whisper_model'] = gr.Dropdown(WHISPER_MODELS, label="Whisper Model", value=args.whisper_model)
exec_inputs = exec_inputs + [ autoregressive_model_dropdown, vocoder_models, whisper_backend, whisper_model_dropdown, training_halfp, training_bnb ] EXEC_SETTINGS['training_default_halfp'] = TRAINING_SETTINGS['half_p']
EXEC_SETTINGS['training_default_bnb'] = TRAINING_SETTINGS['bitsandbytes']
with gr.Row(): with gr.Row():
autoregressive_models_update_button = gr.Button(value="Refresh Model List") autoregressive_models_update_button = gr.Button(value="Refresh Model List")
gr.Button(value="Check for Updates").click(check_for_updates) gr.Button(value="Check for Updates").click(check_for_updates)
gr.Button(value="(Re)Load TTS").click( gr.Button(value="(Re)Load TTS").click(
reload_tts, reload_tts,
inputs=autoregressive_model_dropdown, inputs=EXEC_SETTINGS['autoregressive_model'],
outputs=None outputs=None
) )
# kill_button = gr.Button(value="Close UI") # kill_button = gr.Button(value="Close UI")
@ -648,49 +495,26 @@ def setup_gradio():
autoregressive_models_update_button.click( autoregressive_models_update_button.click(
update_model_list_proxy, update_model_list_proxy,
inputs=autoregressive_model_dropdown, inputs=EXEC_SETTINGS['autoregressive_model'],
outputs=autoregressive_model_dropdown, outputs=EXEC_SETTINGS['autoregressive_model'],
) )
for i in exec_inputs: exec_inputs = list(EXEC_SETTINGS.values())
i.change( fn=update_args, inputs=exec_inputs ) for k in EXEC_SETTINGS:
EXEC_SETTINGS[k].change( fn=update_args_proxy, inputs=exec_inputs )
autoregressive_model_dropdown.change( EXEC_SETTINGS['autoregressive_model'].change(
fn=update_autoregressive_model, fn=update_autoregressive_model,
inputs=autoregressive_model_dropdown, inputs=EXEC_SETTINGS['autoregressive_model'],
outputs=None outputs=None
) )
vocoder_models.change( EXEC_SETTINGS['vocoder_model'].change(
fn=update_vocoder_model, fn=update_vocoder_model,
inputs=vocoder_models, inputs=EXEC_SETTINGS['vocoder_model'],
outputs=None outputs=None
) )
input_settings = [
text,
delimiter,
emotion,
prompt,
voice,
mic_audio,
voice_latents_chunks,
seed,
candidates,
num_autoregressive_samples,
diffusion_iterations,
temperature,
diffusion_sampler,
breathing_room,
cvvp_weight,
top_p,
diffusion_temperature,
length_penalty,
repetition_penalty,
cond_free_k,
experimental_checkboxes,
]
history_voices.change( history_voices.change(
fn=history_view_results, fn=history_view_results,
inputs=history_voices, inputs=history_voices,
@ -734,45 +558,46 @@ def setup_gradio():
preset.change(fn=update_presets, preset.change(fn=update_presets,
inputs=preset, inputs=preset,
outputs=[ outputs=[
num_autoregressive_samples, GENERATE_SETTINGS['num_autoregressive_samples'],
diffusion_iterations, GENERATE_SETTINGS['diffusion_iterations'],
], ],
) )
recompute_voice_latents.click(compute_latents_proxy, recompute_voice_latents.click(compute_latents_proxy,
inputs=[ inputs=[
voice, GENERATE_SETTINGS['voice'],
voice_latents_chunks, GENERATE_SETTINGS['voice_latents_chunks'],
], ],
outputs=voice, outputs=GENERATE_SETTINGS['voice'],
) )
emotion.change( GENERATE_SETTINGS['emotion'].change(
fn=lambda value: gr.update(visible=value == "Custom"), fn=lambda value: gr.update(visible=value == "Custom"),
inputs=emotion, inputs=GENERATE_SETTINGS['emotion'],
outputs=prompt outputs=GENERATE_SETTINGS['prompt']
) )
mic_audio.change(fn=lambda value: gr.update(value="microphone"), GENERATE_SETTINGS['mic_audio'].change(fn=lambda value: gr.update(value="microphone"),
inputs=mic_audio, inputs=GENERATE_SETTINGS['mic_audio'],
outputs=voice outputs=GENERATE_SETTINGS['voice']
) )
refresh_voices.click(update_voices, refresh_voices.click(update_voices,
inputs=None, inputs=None,
outputs=[ outputs=[
voice, GENERATE_SETTINGS['voice'],
dataset_settings[0], dataset_settings[0],
history_voices history_voices
] ]
) )
generate_settings = list(GENERATE_SETTINGS.values())
submit.click( submit.click(
lambda: (gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)), lambda: (gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)),
outputs=[source_sample, candidates_list, generation_results], outputs=[source_sample, candidates_list, generation_results],
) )
submit_event = submit.click(run_generation, submit_event = submit.click(generate_proxy,
inputs=input_settings, inputs=generate_settings,
outputs=[output_audio, source_sample, candidates_list, generation_results], outputs=[output_audio, source_sample, candidates_list, generation_results],
api_name="generate", api_name="generate",
) )
@ -780,13 +605,13 @@ def setup_gradio():
copy_button.click(import_generate_settings, copy_button.click(import_generate_settings,
inputs=audio_in, # JSON elements cannot be used as inputs inputs=audio_in, # JSON elements cannot be used as inputs
outputs=input_settings outputs=generate_settings
) )
reset_generation_settings_button.click( reset_generation_settings_button.click(
fn=reset_generation_settings, fn=reset_generation_settings,
inputs=None, inputs=None,
outputs=input_settings outputs=generate_settings
) )
history_copy_settings_button.click(history_copy_settings, history_copy_settings_button.click(history_copy_settings,
@ -794,7 +619,7 @@ def setup_gradio():
history_voices, history_voices,
history_results_list, history_results_list,
], ],
outputs=input_settings outputs=generate_settings
) )
refresh_configs.click( refresh_configs.click(
@ -806,7 +631,6 @@ def setup_gradio():
inputs=[ inputs=[
training_configs, training_configs,
verbose_training, verbose_training,
training_gpu_count,
training_keep_x_past_datasets, training_keep_x_past_datasets,
], ],
outputs=[ outputs=[
@ -855,38 +679,28 @@ def setup_gradio():
], ],
outputs=prepare_dataset_output #console_output outputs=prepare_dataset_output #console_output
) )
refresh_dataset_list.click(
training_refresh_dataset.click(
lambda: gr.update(choices=get_dataset_list()), lambda: gr.update(choices=get_dataset_list()),
inputs=None, inputs=None,
outputs=dataset_list_dropdown, outputs=TRAINING_SETTINGS["voice"],
) )
optimize_yaml_button.click(optimize_training_settings_proxy, training_settings = list(TRAINING_SETTINGS.values())
training_optimize_configuration.click(optimize_training_settings_proxy,
inputs=training_settings, inputs=training_settings,
outputs=training_settings[1:10] + [save_yaml_output] #console_output outputs=training_settings[:-1] + [training_configuration_output] #console_output
) )
import_dataset_button.click(import_training_settings_proxy, training_import_settings.click(import_training_settings_proxy,
inputs=dataset_list_dropdown, inputs=TRAINING_SETTINGS['voice'],
outputs=training_settings[:14] + [save_yaml_output] #console_output outputs=training_settings[:-1] + [training_configuration_output] #console_output
) )
save_yaml_button.click(save_training_settings_proxy, training_save_configuration.click(save_training_settings_proxy,
inputs=training_settings, inputs=training_settings,
outputs=save_yaml_output #console_output outputs=training_configuration_output #console_output
) )
"""
def kill_process():
ui.close()
exit()
kill_button.click(
kill_process,
inputs=None,
outputs=None
)
"""
if os.path.isfile('./config/generate.json'): if os.path.isfile('./config/generate.json'):
ui.load(import_generate_settings, inputs=None, outputs=input_settings) ui.load(import_generate_settings, inputs=None, outputs=generate_settings)
if args.check_for_updates: if args.check_for_updates:
ui.load(check_for_updates) ui.load(check_for_updates)