diff --git a/models/.template.yaml b/models/.template.yaml index 15cfbe3..90ab039 100755 --- a/models/.template.yaml +++ b/models/.template.yaml @@ -1,16 +1,18 @@ -name: ${name} +name: ${voice} model: extensibletrainer scale: 1 gpu_ids: [0] # Superfluous, redundant, unnecessary, the way you launch the training script will set this start_step: 0 checkpointing_enabled: true -fp16: ${float16} +fp16: ${half_p} +bitsandbytes: ${bitsandbytes} +gpus: ${gpus} wandb: false use_tb_logger: true datasets: train: - name: ${dataset_name} + name: training n_workers: ${workers} batch_size: ${batch_size} mode: paired_voice_audio @@ -27,7 +29,7 @@ datasets: tokenizer_vocab: ./models/tortoise/bpe_lowercase_asr_256.json load_aligned_codes: False val: # I really do not care about validation right now - name: ${validation_name} + name: validation n_workers: ${workers} batch_size: ${validation_batch_size} mode: paired_voice_audio @@ -114,8 +116,8 @@ networks: #only_alignment_head: False # uv3/4 path: - ${pretrain_model_gpt} strict_load: true + ${source_model} ${resume_state} train: diff --git a/src/cull_dataset.py b/src/cull_dataset.py deleted file mode 100755 index 0572405..0000000 --- a/src/cull_dataset.py +++ /dev/null @@ -1,35 +0,0 @@ -import os -import sys - -indir = f'./training/{sys.argv[1]}/' -cap = int(sys.argv[2]) - -if not os.path.isdir(indir): - raise Exception(f"Invalid directory: {indir}") - -if not os.path.exists(f'{indir}/train.txt'): - raise Exception(f"Missing dataset: {indir}/train.txt") - -with open(f'{indir}/train.txt', 'r', encoding="utf-8") as f: - lines = f.readlines() - -validation = [] -training = [] - -for line in lines: - split = line.split("|") - filename = split[0] - text = split[1] - - if len(text) < cap: - validation.append(line.strip()) - else: - training.append(line.strip()) - -with open(f'{indir}/train_culled.txt', 'w', encoding="utf-8") as f: - f.write("\n".join(training)) - -with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f: - f.write("\n".join(validation)) - -print(f"Culled {len(validation)} lines") \ No newline at end of file diff --git a/src/train.py b/src/train.py index 144ecc0..78f4ff3 100755 --- a/src/train.py +++ b/src/train.py @@ -46,6 +46,7 @@ sys.path.insert(0, './dlas/') # without kludge, it'll have to be accessible as `codes` and not `dlas` import torch +import datetime from codes import train as tr from utils import util, options as option @@ -71,7 +72,7 @@ def train(yaml, launcher='none'): print('Disabled distributed training.') else: opt['dist'] = True - tr.init_dist('nccl') + tr.init_dist('nccl', timeout=datetime.timedelta(seconds=5*60)) trainer.world_size = torch.distributed.get_world_size() trainer.rank = torch.distributed.get_rank() torch.cuda.set_device(torch.distributed.get_rank()) diff --git a/src/utils.py b/src/utils.py index 75ab473..d7f3f44 100755 --- a/src/utils.py +++ b/src/utils.py @@ -34,7 +34,7 @@ from datetime import timedelta from tortoise.api import TextToSpeech, MODELS, get_model_path, pad_or_truncate from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir from tortoise.utils.text import split_and_recombine_text -from tortoise.utils.device import get_device_name, set_device_name +from tortoise.utils.device import get_device_name, set_device_name, get_device_count MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth" @@ -44,6 +44,8 @@ WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp", "m-bain/whisperx"] VOCODERS = ['univnet', 'bigvgan_base_24khz_100band', 'bigvgan_24khz_100band'] +GENERATE_SETTINGS_ARGS = None + EPOCH_SCHEDULE = [ 9, 18, 25, 33 ] args = None @@ -56,30 +58,17 @@ training_state = None current_voice = None -def generate( - text, - delimiter, - emotion, - prompt, - voice, - mic_audio, - voice_latents_chunks, - seed, - candidates, - num_autoregressive_samples, - diffusion_iterations, - temperature, - diffusion_sampler, - breathing_room, - cvvp_weight, - top_p, - diffusion_temperature, - length_penalty, - repetition_penalty, - cond_free_k, - experimental_checkboxes, - progress=None -): +def generate(**kwargs): + parameters = {} + parameters.update(kwargs) + + voice = parameters['voice'] + progress = parameters['progress'] if 'progress' in parameters else None + if parameters['seed'] == 0: + parameters['seed'] = None + + usedSeed = parameters['seed'] + global args global tts @@ -90,6 +79,8 @@ def generate( # should check if it's loading or unloaded, and load it if it's unloaded if tts_loading: raise Exception("TTS is still initializing...") + if progress is not None: + progress(0, "Initializing TTS...") load_tts() if hasattr(tts, "loading") and tts.loading: raise Exception("TTS is still initializing...") @@ -100,9 +91,6 @@ def generate( conditioning_latents =None sample_voice = None - if seed == 0: - seed = None - voice_cache = {} def fetch_voice( voice ): print(f"Loading voice: {voice} with model {tts.autoregressive_model_hash[:8]}") @@ -112,9 +100,9 @@ def generate( sample_voice = None if voice == "microphone": - if mic_audio is None: + if parameters['mic_audio'] is None: raise Exception("Please provide audio from mic when choosing `microphone` as a voice input") - voice_samples, conditioning_latents = [load_audio(mic_audio, tts.input_sample_rate)], None + voice_samples, conditioning_latents = [load_audio(parameters['mic_audio'], tts.input_sample_rate)], None elif voice == "random": voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents() else: @@ -125,7 +113,7 @@ def generate( if voice_samples and len(voice_samples) > 0: if conditioning_latents is None: - conditioning_latents = compute_latents(voice=voice, voice_samples=voice_samples, voice_latents_chunks=voice_latents_chunks) + conditioning_latents = compute_latents(voice=voice, voice_samples=voice_samples, voice_latents_chunks=parameters['voice_latents_chunks']) sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu() voice_samples = None @@ -135,30 +123,30 @@ def generate( def get_settings( override=None ): settings = { - 'temperature': float(temperature), + 'temperature': float(parameters['temperature']), - 'top_p': float(top_p), - 'diffusion_temperature': float(diffusion_temperature), - 'length_penalty': float(length_penalty), - 'repetition_penalty': float(repetition_penalty), - 'cond_free_k': float(cond_free_k), + 'top_p': float(parameters['top_p']), + 'diffusion_temperature': float(parameters['diffusion_temperature']), + 'length_penalty': float(parameters['length_penalty']), + 'repetition_penalty': float(parameters['repetition_penalty']), + 'cond_free_k': float(parameters['cond_free_k']), - 'num_autoregressive_samples': num_autoregressive_samples, + 'num_autoregressive_samples': parameters['num_autoregressive_samples'], 'sample_batch_size': args.sample_batch_size, - 'diffusion_iterations': diffusion_iterations, + 'diffusion_iterations': parameters['diffusion_iterations'], 'voice_samples': None, 'conditioning_latents': None, - 'use_deterministic_seed': seed, + 'use_deterministic_seed': parameters['seed'], 'return_deterministic_state': True, - 'k': candidates, - 'diffusion_sampler': diffusion_sampler, - 'breathing_room': breathing_room, - 'progress': progress, - 'half_p': "Half Precision" in experimental_checkboxes, - 'cond_free': "Conditioning-Free" in experimental_checkboxes, - 'cvvp_amount': cvvp_weight, + 'k': parameters['candidates'], + 'diffusion_sampler': parameters['diffusion_sampler'], + 'breathing_room': parameters['breathing_room'], + 'progress': parameters['progress'], + 'half_p': "Half Precision" in parameters['experimentals'], + 'cond_free': "Conditioning-Free" in parameters['experimentals'], + 'cvvp_amount': parameters['cvvp_weight'], 'autoregressive_model': args.autoregressive_model, } @@ -182,11 +170,11 @@ def generate( # clamp it down for the insane users who want this # it would be wiser to enforce the sample size to the batch size, but this is what the user wants - sample_batch_size = args.sample_batch_size - if not sample_batch_size: - sample_batch_size = tts.autoregressive_batch_size - if num_autoregressive_samples < sample_batch_size: - settings['sample_batch_size'] = num_autoregressive_samples + settings['sample_batch_size'] = args.sample_batch_size + if not settings['sample_batch_size']: + settings['sample_batch_size'] = tts.autoregressive_batch_size + if settings['num_autoregressive_samples'] < settings['sample_batch_size']: + settings['sample_batch_size'] = settings['num_autoregressive_samples'] if settings['conditioning_latents'] is not None and len(settings['conditioning_latents']) == 2 and settings['cvvp_amount'] > 0: print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents with 'Slimmer voice latents' unchecked.") @@ -194,15 +182,15 @@ def generate( return settings - if not delimiter: - delimiter = "\n" - elif delimiter == "\\n": - delimiter = "\n" + if not parameters['delimiter']: + parameters['delimiter'] = "\n" + elif parameters['delimiter'] == "\\n": + parameters['delimiter'] = "\n" - if delimiter and delimiter != "" and delimiter in text: - texts = text.split(delimiter) + if parameters['delimiter'] and parameters['delimiter'] != "" and parameters['delimiter'] in parameters['text']: + texts = parameters['text'].split(parameters['delimiter']) else: - texts = split_and_recombine_text(text) + texts = split_and_recombine_text(parameters['text']) full_start_time = time.time() @@ -248,37 +236,23 @@ def generate( name = f"{name}_combined" elif len(texts) > 1: name = f"{name}_{line}" - if candidates > 1: + if parameters['candidates'] > 1: name = f"{name}_{candidate}" return name def get_info( voice, settings = None, latents = True ): - info = { - 'text': text, - 'delimiter': '\\n' if delimiter and delimiter == "\n" else delimiter, - 'emotion': emotion, - 'prompt': prompt, - 'voice': voice, - 'seed': seed, - 'candidates': candidates, - 'num_autoregressive_samples': num_autoregressive_samples, - 'diffusion_iterations': diffusion_iterations, - 'temperature': temperature, - 'diffusion_sampler': diffusion_sampler, - 'breathing_room': breathing_room, - 'cvvp_weight': cvvp_weight, - 'top_p': top_p, - 'diffusion_temperature': diffusion_temperature, - 'length_penalty': length_penalty, - 'repetition_penalty': repetition_penalty, - 'cond_free_k': cond_free_k, - 'experimentals': experimental_checkboxes, - 'time': time.time()-full_start_time, + info = {} + info.update(parameters) + info['time'] = time.time()-full_start_time, - 'datetime': datetime.now().isoformat(), - 'model': tts.autoregressive_model_path, - 'model_hash': tts.autoregressive_model_hash - } + info['datetime'] = datetime.now().isoformat(), + info['model'] = tts.autoregressive_model_path, + info['model_hash'] = tts.autoregressive_model_hash + info['progress'] = None + del info['progress'] + + if info['delimiter'] == "\n": + info['delimiter'] = "\\n" if settings is not None: for k in settings: @@ -319,11 +293,11 @@ def generate( return info for line, cut_text in enumerate(texts): - if emotion == "Custom": - if prompt and prompt.strip() != "": - cut_text = f"[{prompt},] {cut_text}" - elif emotion != "None" and emotion: - cut_text = f"[I am really {emotion.lower()},] {cut_text}" + if parameters['emotion'] == "Custom": + if parameters['prompt'] and parameters['prompt'].strip() != "": + cut_text = f"[{parameters['prompt']},] {cut_text}" + elif parameters['emotion'] != "None" and parameters['emotion']: + cut_text = f"[I am really {parameters['emotion'].lower()},] {cut_text}" progress.msg_prefix = f'[{str(line+1)}/{str(len(texts))}]' print(f"{progress.msg_prefix} Generating line: {cut_text}") @@ -343,10 +317,10 @@ def generate( settings = get_settings( override=override ) gen, additionals = tts.tts(cut_text, **settings ) - seed = additionals[0] + parameters['seed'] = additionals[0] run_time = time.time()-start_time print(f"Generating line took {run_time} seconds") - + if not isinstance(gen, list): gen = [gen] @@ -382,7 +356,7 @@ def generate( torchaudio.save(f'{outdir}/{voice}_{k}.wav', audio, args.output_sample_rate) output_voices = [] - for candidate in range(candidates): + for candidate in range(parameters['candidates']): if len(texts) > 1: audio_clips = [] for line in range(len(texts)): @@ -466,7 +440,7 @@ def generate( info = get_info(voice=voice, latents=False) print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n") - info['seed'] = seed + info['seed'] = usedSeed if 'latents' in info: del info['latents'] @@ -475,7 +449,7 @@ def generate( f.write(json.dumps(info, indent='\t') ) stats = [ - [ seed, "{:.3f}".format(info['time']) ] + [ parameters['seed'], "{:.3f}".format(info['time']) ] ] return ( @@ -609,14 +583,16 @@ def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, prog # superfluous, but it cleans up some things class TrainingState(): - def __init__(self, config_path, keep_x_past_checkpoints=0, start=True, gpus=1): + def __init__(self, config_path, keep_x_past_checkpoints=0, start=True): # parse config to get its iteration with open(config_path, 'r') as file: self.config = yaml.safe_load(file) + gpus = self.config["gpus"] + self.killed = False - self.dataset_dir = f"./training/{self.config['name']}/" + self.dataset_dir = f"./training/{self.config['name']}/finetune/" self.batch_size = self.config['datasets']['train']['batch_size'] self.dataset_path = self.config['datasets']['train']['path'] with open(self.dataset_path, 'r', encoding="utf-8") as f: @@ -996,7 +972,7 @@ except Exception as e: print(e) pass -def run_training(config_path, verbose=False, gpus=1, keep_x_past_checkpoints=0, progress=gr.Progress(track_tqdm=True)): +def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress=gr.Progress(track_tqdm=True)): global training_state if training_state and training_state.process: return "Training already in progress" @@ -1008,26 +984,11 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_checkpoints=0, # I don't know if this is still necessary, as it was bitching at me for not doing this, despite it being in a separate process torch.multiprocessing.freeze_support() - # edit any gpu-count-specific variables - with open(config_path, 'r', encoding="utf-8") as f: - yaml_string = f.read() - edited = False - if gpus > 1: - yaml_string = yaml_string.replace(" adamw ", " adamw_zero ") - edited = True - else: - yaml_string = yaml_string.replace(" adamw_zero ", " adamw ") - edited = True - if edited: - print(f'Modified YAML config') - with open(config_path, 'w', encoding="utf-8") as f: - f.write(yaml_string) - unload_tts() unload_whisper() unload_voicefixer() - training_state = TrainingState(config_path=config_path, keep_x_past_checkpoints=keep_x_past_checkpoints, gpus=gpus) + training_state = TrainingState(config_path=config_path, keep_x_past_checkpoints=keep_x_past_checkpoints) for line in iter(training_state.process.stdout.readline, ""): if training_state.killed: @@ -1169,7 +1130,7 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres if whisper_model is None: load_whisper_model(language=language) - os.makedirs(outdir, exist_ok=True) + os.makedirs(f'{outdir}/audio/', exist_ok=True) results = {} transcription = [] @@ -1216,10 +1177,10 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres print(f"Error with {sliced_name}, skipping...") continue - torchaudio.save(f"{outdir}/{sliced_name}", sliced_waveform, sampling_rate) + torchaudio.save(f"{outdir}/audio/{sliced_name}", sliced_waveform, sampling_rate) idx = idx + 1 - line = f"{sliced_name}|{segment['text'].strip()}" + line = f"audio/{sliced_name}|{segment['text'].strip()}" transcription.append(line) with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f: f.write(f'\n{line}') @@ -1283,125 +1244,142 @@ def calc_iterations( epochs, lines, batch_size ): def schedule_learning_rate( iterations, schedule=EPOCH_SCHEDULE ): return [int(iterations * d) for d in schedule] -def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, validation_rate, resume_path, half_p, bnb, workers, source_model, voice ): - name = f"{voice}-finetune" - dataset_path = f"./training/{voice}/train.txt" +def optimize_training_settings( **kwargs ): + messages = [] + settings = {} + settings.update(kwargs) + dataset_path = f"./training/{settings['voice']}/train.txt" with open(dataset_path, 'r', encoding="utf-8") as f: lines = len(f.readlines()) - messages = [] + if settings['batch_size'] > lines: + settings['batch_size'] = lines + messages.append(f"Batch size is larger than your dataset, clamping batch size to: {settings['batch_size']}") - if batch_size > lines: - batch_size = lines - messages.append(f"Batch size is larger than your dataset, clamping batch size to: {batch_size}") - - if batch_size % lines != 0: - nearest_slice = int(lines / batch_size) + 1 - batch_size = int(lines / nearest_slice) - messages.append(f"Batch size not neatly divisible by dataset size, adjusting batch size to: {batch_size} ({nearest_slice} steps per epoch)") + if settings['batch_size'] % lines != 0: + nearest_slice = int(lines / settings['batch_size']) + 1 + settings['batch_size'] = int(lines / nearest_slice) + messages.append(f"Batch size not neatly divisible by dataset size, adjusting batch size to: {settings['batch_size']} ({nearest_slice} steps per epoch)") - if gradient_accumulation_size == 0: - gradient_accumulation_size = 1 + if settings['gradient_accumulation_size'] == 0: + settings['gradient_accumulation_size'] = 1 - if batch_size / gradient_accumulation_size < 2: - gradient_accumulation_size = int(batch_size / 2) - if gradient_accumulation_size == 0: - gradient_accumulation_size = 1 + if settings['batch_size'] / settings['gradient_accumulation_size'] < 2: + settings['gradient_accumulation_size'] = int(settings['batch_size'] / 2) + if settings['gradient_accumulation_size'] == 0: + settings['gradient_accumulation_size'] = 1 - messages.append(f"Gradient accumulation size is too large for a given batch size, clamping gradient accumulation size to: {gradient_accumulation_size}") - elif batch_size % gradient_accumulation_size != 0: - gradient_accumulation_size = int(batch_size / gradient_accumulation_size) - if gradient_accumulation_size == 0: - gradient_accumulation_size = 1 + messages.append(f"Gradient accumulation size is too large for a given batch size, clamping gradient accumulation size to: {settings['gradient_accumulation_size']}") + elif settings['batch_size'] % settings['gradient_accumulation_size'] != 0: + settings['gradient_accumulation_size'] = int(settings['batch_size'] / settings['gradient_accumulation_size']) + if settings['gradient_accumulation_size'] == 0: + settings['gradient_accumulation_size'] = 1 - messages.append(f"Batch size is not evenly divisible by the gradient accumulation size, adjusting gradient accumulation size to: {gradient_accumulation_size}") + messages.append(f"Batch size is not evenly divisible by the gradient accumulation size, adjusting gradient accumulation size to: {settings['gradient_accumulation_size']}") - iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size) + iterations = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size']) - if epochs < print_rate: - print_rate = epochs - messages.append(f"Print rate is too small for the given iteration step, clamping print rate to: {print_rate}") + if settings['epochs'] < settings['print_rate']: + settings['print_rate'] = settings['epochs'] + messages.append(f"Print rate is too small for the given iteration step, clamping print rate to: {settings['print_rate']}") - if epochs < save_rate: - save_rate = epochs - messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {save_rate}") + if settings['epochs'] < settings['save_rate']: + settings['save_rate'] = settings['epochs'] + messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {settings['save_rate']}") - if epochs < validation_rate: - validation_rate = epochs - messages.append(f"Validation rate is too small for the given iteration step, clamping validation rate to: {validation_rate}") + if settings['epochs'] < settings['validation_rate']: + settings['validation_rate'] = settings['epochs'] + messages.append(f"Validation rate is too small for the given iteration step, clamping validation rate to: {settings['validation_rate']}") - if resume_path and not os.path.exists(resume_path): - resume_path = None + if settings['resume_state'] and not os.path.exists(settings['resume_state']): + settings['resume_state'] = None messages.append("Resume path specified, but does not exist. Disabling...") - if bnb: + if settings['bitsandbytes']: messages.append("BitsAndBytes requested. Please note this is ! EXPERIMENTAL !") - if half_p: - if bnb: - half_p = False + if settings['half_p']: + if settings['bitsandbytes']: + settings['half_p'] = False messages.append("Half Precision requested, but BitsAndBytes is also requested. Due to redundancies, disabling half precision...") else: messages.append("Half Precision requested. Please note this is ! EXPERIMENTAL !") if not os.path.exists(get_halfp_model_path()): convert_to_halfp() - messages.append(f"For {epochs} epochs with {lines} lines in batches of {batch_size}, iterating for {iterations} steps ({int(iterations / epochs)} steps per epoch)") + messages.append(f"For {settings['epochs']} epochs with {lines} lines in batches of {settings['batch_size']}, iterating for {iterations} steps ({int(iterations / settings['epochs'])} steps per epoch)") - return ( - learning_rate, - text_ce_lr_weight, - learning_rate_schedule, - batch_size, - gradient_accumulation_size, - print_rate, - save_rate, - validation_rate, - resume_path, - messages - ) + return settings, messages -def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_scheme=None, learning_rate_schedule=None, batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, validation_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, validation_batch_size=None, output_name=None, resume_path=None, half_p=None, bnb=None, workers=None, source_model=None ): - if not source_model: - source_model = f"./models/tortoise/autoregressive{'_half' if half_p else ''}.pth" +def save_training_settings( **kwargs ): + messages = [] + settings = {} + settings.update(kwargs) - settings = { - "iterations": iterations if iterations else 500, - "batch_size": batch_size if batch_size else 64, - "learning_rate": learning_rate if learning_rate else 1e-5, - "gradient_accumulation_size": gradient_accumulation_size if gradient_accumulation_size else 4, - "print_rate": print_rate if print_rate else 1, - "save_rate": save_rate if save_rate else 50, - "name": name if name else "finetune", - "dataset_name": dataset_name if dataset_name else "finetune", - "dataset_path": dataset_path if dataset_path else "./training/finetune/train.txt", - "validation_name": validation_name if validation_name else "finetune", - "validation_path": validation_path if validation_path else "./training/finetune/train.txt", - 'validation_rate': validation_rate if validation_rate else iterations, - "validation_batch_size": validation_batch_size if validation_batch_size else batch_size, - 'validation_enabled': "true", + settings['dataset_path'] = f"./training/{settings['voice']}/train.txt" + settings['validation_path'] = f"./training/{settings['voice']}/validation.txt" - "text_ce_lr_weight": text_ce_lr_weight if text_ce_lr_weight else 0.01, + with open(settings['dataset_path'], 'r', encoding="utf-8") as f: + lines = len(f.readlines()) - 'resume_state': f"resume_state: '{resume_path}'", - 'pretrain_model_gpt': f"pretrain_model_gpt: '{source_model}'", + if not settings['source_model'] or settings['source_model'] == "auto": + settings['source_model'] = f"./models/tortoise/autoregressive{'_half' if settings['half_p'] else ''}.pth" - 'float16': 'true' if half_p else 'false', - 'bitsandbytes': 'true' if bnb else 'false', + if settings['half_p']: + if not os.path.exists(get_halfp_model_path()): + convert_to_halfp() - 'workers': workers if workers else 2, - } + settings['iterations'] = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size']) + messages.append(f"For {settings['epochs']} epochs with {lines} lines, iterating for {settings['iterations']} steps") + + settings['print_rate'] = int(settings['print_rate'] * settings['iterations'] / settings['epochs']) + settings['save_rate'] = int(settings['save_rate'] * settings['iterations'] / settings['epochs']) + settings['validation_rate'] = int(settings['validation_rate'] * settings['iterations'] / settings['epochs']) + + settings['validation_batch_size'] = int(settings['batch_size'] / settings['gradient_accumulation_size']) + + settings['iterations'] = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size']) + if settings['iterations'] % settings['save_rate'] != 0: + adjustment = int(settings['iterations'] / settings['save_rate']) * settings['save_rate'] + messages.append(f"Iteration rate is not evenly divisible by save rate, adjusting: {settings['iterations']} => {adjustment}") + settings['iterations'] = adjustment + + if not os.path.exists(settings['validation_path']): + settings['validation_enabled'] = False + messages.append("Validation not found, disabling validation...") + elif settings['validation_batch_size'] == 0: + settings['validation_enabled'] = False + messages.append("Validation batch size == 0, disabling validation...") + else: + settings['validation_enabled'] = True + with open(settings['validation_path'], 'r', encoding="utf-8") as f: + validation_lines = len(f.readlines()) + + if validation_lines < settings['validation_batch_size']: + settings['validation_batch_size'] = validation_lines + messages.append(f"Batch size exceeds validation dataset size, clamping validation batch size to {validation_lines}") + + + if settings['gpus'] > get_device_count(): + settings['gpus'] = get_device_count() LEARNING_RATE_SCHEMES = ["MultiStepLR", "CosineAnnealingLR_Restart"] - if learning_rate_scheme not in LEARNING_RATE_SCHEMES: - learning_rate_scheme = LEARNING_RATE_SCHEMES[0] + if 'learning_rate_scheme' not in settings or settings['learning_rate_scheme'] not in LEARNING_RATE_SCHEMES: + settings['learning_rate_scheme'] = LEARNING_RATE_SCHEMES[0] - learning_rate_schema = [f"default_lr_scheme: {learning_rate_scheme}"] - if learning_rate_scheme == "MultiStepLR": - learning_rate_schema.append(f" gen_lr_steps: {learning_rate_schedule if learning_rate_schedule else EPOCH_SCHEDULE}") + learning_rate_schema = [f"default_lr_scheme: {settings['learning_rate_scheme']}"] + if settings['learning_rate_scheme'] == "MultiStepLR": + if not settings['learning_rate_schedule']: + settings['learning_rate_schedule'] = EPOCH_SCHEDULE + elif isinstance(settings['learning_rate_schedule'],str): + settings['learning_rate_schedule'] = json.loads(settings['learning_rate_schedule']) + + settings['learning_rate_schedule'] = schedule_learning_rate( settings['iterations'] / settings['epochs'], settings['learning_rate_schedule'] ) + + learning_rate_schema.append(f" gen_lr_steps: {settings['learning_rate_schedule']}") learning_rate_schema.append(f" lr_gamma: 0.5") - elif learning_rate_scheme == "CosineAnnealingLR_Restart": + elif settings['learning_rate_scheme'] == "CosineAnnealingLR_Restart": learning_rate_schema.append(f" T_period: [120000, 120000, 120000]") learning_rate_schema.append(f" warmup: 10000") learning_rate_schema.append(f" eta_min: .01") @@ -1409,23 +1387,26 @@ def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weig learning_rate_schema.append(f" restart_weights: [.5, .25]") settings['learning_rate_scheme'] = "\n".join(learning_rate_schema) - if resume_path: + """ + if resume_state: settings['pretrain_model_gpt'] = f"# {settings['pretrain_model_gpt']}" else: - settings['resume_state'] = f"# resume_state: './training/{name if name else 'finetune'}/training_state/#.state'" + settings['resume_state'] = f"# resume_state: './training/{voice}/training_state/#.state'" # also disable validation if it doesn't make sense to do it if settings['dataset_path'] == settings['validation_path'] or not os.path.exists(settings['validation_path']): settings['validation_enabled'] = 'false' + """ + outjson = f'./training/{settings["voice"]}/train.json' + with open(outjson, 'w', encoding="utf-8") as f: + f.write(json.dumps(settings, indent='\t') ) - - if half_p: - if not os.path.exists(get_halfp_model_path()): - convert_to_halfp() - - if not output_name: - output_name = f'{settings["name"]}.yaml' - + if settings['resume_state']: + settings['source_model'] = f"# pretrain_model_gpt: {settings['source_model']}" + settings['resume_state'] = f"resume_state: {settings['resume_state']}'" + else: + settings['source_model'] = f"pretrain_model_gpt: {settings['source_model']}" + settings['resume_state'] = f"# resume_state: {settings['resume_state']}'" with open(f'./models/.template.yaml', 'r', encoding="utf-8") as f: yaml = f.read() @@ -1436,11 +1417,13 @@ def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weig continue yaml = yaml.replace(f"${{{k}}}", str(settings[k])) - outfile = f'./training/{output_name}' - with open(outfile, 'w', encoding="utf-8") as f: + outyaml = f'./training/{settings["voice"]}/train.yaml' + with open(outyaml, 'w', encoding="utf-8") as f: f.write(yaml) + - return f"Training settings saved to: {outfile}" + messages.append(f"Saved training output to: {outyaml}") + return settings, messages def import_voices(files, saveAs=None, progress=None): global args @@ -1524,10 +1507,10 @@ def get_autoregressive_models(dir="./models/finetunes/", prefixed=False): additionals = sorted([f'{dir}/{d}' for d in os.listdir(dir) if d[-4:] == ".pth" ]) found = [] for training in os.listdir(f'./training/'): - if not os.path.isdir(f'./training/{training}/') or not os.path.isdir(f'./training/{training}/models/'): + if not os.path.isdir(f'./training/{training}/') or not os.path.isdir(f'./training/{training}/finetunes/') or not os.path.isdir(f'./training/{training}/finetunes/models/'): continue - models = sorted([ int(d[:-8]) for d in os.listdir(f'./training/{training}/models/') if d[-8:] == "_gpt.pth" ]) - found = found + [ f'./training/{training}/models/{d}_gpt.pth' for d in models ] + models = sorted([ int(d[:-8]) for d in os.listdir(f'./training/{training}/finetunes/models/') if d[-8:] == "_gpt.pth" ]) + found = found + [ f'./training/{training}/finetunes/models/{d}_gpt.pth' for d in models ] if len(found) > 0 or len(additionals) > 0: base = ["auto"] + base @@ -1545,10 +1528,10 @@ def get_autoregressive_models(dir="./models/finetunes/", prefixed=False): return res def get_dataset_list(dir="./training/"): - return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.txt" in os.listdir(os.path.join(dir, d)) ]) + return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "train.txt" in os.listdir(os.path.join(dir, d)) ]) def get_training_list(dir="./training/"): - return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.yaml" in os.listdir(os.path.join(dir, d)) ]) + return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "train.yaml" in os.listdir(os.path.join(dir, d)) ]) def do_gc(): gc.collect() @@ -1734,35 +1717,38 @@ def setup_args(): return args -def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, prune_nonfinal_outputs, device_override, sample_batch_size, concurrency_count, autocalculate_voice_chunk_duration_size, output_volume, autoregressive_model, vocoder_model, whisper_backend, whisper_model, training_default_halfp, training_default_bnb ): +def update_args( **kwargs ): global args - args.listen = listen - args.share = share - args.check_for_updates = check_for_updates - args.models_from_local_only = models_from_local_only - args.low_vram = low_vram - args.force_cpu_for_conditioning_latents = force_cpu_for_conditioning_latents - args.defer_tts_load = defer_tts_load - args.prune_nonfinal_outputs = prune_nonfinal_outputs - args.device_override = device_override - args.sample_batch_size = sample_batch_size - args.embed_output_metadata = embed_output_metadata - args.latents_lean_and_mean = latents_lean_and_mean - args.voice_fixer = voice_fixer - args.voice_fixer_use_cuda = voice_fixer_use_cuda - args.concurrency_count = concurrency_count - args.output_sample_rate = 44000 - args.autocalculate_voice_chunk_duration_size = autocalculate_voice_chunk_duration_size - args.output_volume = output_volume - - args.autoregressive_model = autoregressive_model - args.vocoder_model = vocoder_model - args.whisper_backend = whisper_backend - args.whisper_model = whisper_model + settings = {} + settings.update(kwargs) - args.training_default_halfp = training_default_halfp - args.training_default_bnb = training_default_bnb + args.listen = settings['listen'] + args.share = settings['share'] + args.check_for_updates = settings['check_for_updates'] + args.models_from_local_only = settings['models_from_local_only'] + args.low_vram = settings['low_vram'] + args.force_cpu_for_conditioning_latents = settings['force_cpu_for_conditioning_latents'] + args.defer_tts_load = settings['defer_tts_load'] + args.prune_nonfinal_outputs = settings['prune_nonfinal_outputs'] + args.device_override = settings['device_override'] + args.sample_batch_size = settings['sample_batch_size'] + args.embed_output_metadata = settings['embed_output_metadata'] + args.latents_lean_and_mean = settings['latents_lean_and_mean'] + args.voice_fixer = settings['voice_fixer'] + args.voice_fixer_use_cuda = settings['voice_fixer_use_cuda'] + args.concurrency_count = settings['concurrency_count'] + args.output_sample_rate = 44000 + args.autocalculate_voice_chunk_duration_size = settings['autocalculate_voice_chunk_duration_size'] + args.output_volume = settings['output_volume'] + + args.autoregressive_model = settings['autoregressive_model'] + args.vocoder_model = settings['vocoder_model'] + args.whisper_backend = settings['whisper_backend'] + args.whisper_model = settings['whisper_model'] + + args.training_default_halfp = settings['training_default_halfp'] + args.training_default_bnb = settings['training_default_bnb'] save_args_settings() @@ -1801,37 +1787,49 @@ def save_args_settings(): with open(f'./config/exec.json', 'w', encoding="utf-8") as f: f.write(json.dumps(settings, indent='\t') ) - +# super kludgy )`; +def set_generate_settings_arg_order(args): + global GENERATE_SETTINGS_ARGS + GENERATE_SETTINGS_ARGS = args def import_generate_settings(file="./config/generate.json"): + global GENERATE_SETTINGS_ARGS + + defaults = { + 'text': None, + 'delimiter': None, + 'emotion': None, + 'prompt': None, + 'voice': None, + 'mic_audio': None, + 'voice_latents_chunks': None, + 'candidates': None, + 'seed': None, + 'num_autoregressive_samples': 16, + 'diffusion_iterations': 30, + 'temperature': 0.8, + 'diffusion_sampler': "DDIM", + 'breathing_room': 8 , + 'cvvp_weight': 0.0, + 'top_p': 0.8, + 'diffusion_temperature': 1.0, + 'length_penalty': 1.0, + 'repetition_penalty': 2.0, + 'cond_free_k': 2.0, + 'experimentals': None, + } + settings, _ = read_generate_settings(file, read_latents=False) - if settings is None: - return None + res = [] + if GENERATE_SETTINGS_ARGS is not None: + for k in GENERATE_SETTINGS_ARGS: + res.append(defaults[k] if not settings or settings[k] is None else settings[k]) + else: + for k in defaults: + res.append(defaults[k] if not settings or settings[k] is None else settings[k]) - return ( - None if 'text' not in settings else settings['text'], - None if 'delimiter' not in settings else settings['delimiter'], - None if 'emotion' not in settings else settings['emotion'], - None if 'prompt' not in settings else settings['prompt'], - None if 'voice' not in settings else settings['voice'], - None, - None, - None if 'seed' not in settings else settings['seed'], - None if 'candidates' not in settings else settings['candidates'], - None if 'num_autoregressive_samples' not in settings else settings['num_autoregressive_samples'], - None if 'diffusion_iterations' not in settings else settings['diffusion_iterations'], - 0.8 if 'temperature' not in settings else settings['temperature'], - "DDIM" if 'diffusion_sampler' not in settings else settings['diffusion_sampler'], - 8 if 'breathing_room' not in settings else settings['breathing_room'], - 0.0 if 'cvvp_weight' not in settings else settings['cvvp_weight'], - 0.8 if 'top_p' not in settings else settings['top_p'], - 1.0 if 'diffusion_temperature' not in settings else settings['diffusion_temperature'], - 1.0 if 'length_penalty' not in settings else settings['length_penalty'], - 2.0 if 'repetition_penalty' not in settings else settings['repetition_penalty'], - 2.0 if 'cond_free_k' not in settings else settings['cond_free_k'], - None if 'experimentals' not in settings else settings['experimentals'], - ) + return tuple(res) def reset_generation_settings(): @@ -1955,10 +1953,10 @@ def deduce_autoregressive_model(voice=None): voice = get_current_voice() if voice: - dir = f'./training/{voice}-finetune/models/' - if os.path.exists(f'./training/finetunes/{voice}.pth'): - return f'./training/finetunes/{voice}.pth' + if os.path.exists(f'./models/finetunes/{voice}.pth'): + return f'./models/finetunes/{voice}.pth' + dir = f'./training/{voice}/finetune/models/' if os.path.isdir(dir): counts = sorted([ int(d[:-8]) for d in os.listdir(dir) if d[-8:] == "_gpt.pth" ]) names = [ f'{dir}/{d}_gpt.pth' for d in counts ] diff --git a/src/webui.py b/src/webui.py index c4e4a6e..5429394 100755 --- a/src/webui.py +++ b/src/webui.py @@ -4,6 +4,7 @@ import time import json import base64 import re +import inspect import urllib.request import torch @@ -22,7 +23,38 @@ from utils import * args = setup_args() -def run_generation( +GENERATE_SETTINGS = {} +TRANSCRIBE_SETTINGS = {} +EXEC_SETTINGS = {} +TRAINING_SETTINGS = {} + +PRESETS = { + 'Ultra Fast': {'num_autoregressive_samples': 16, 'diffusion_iterations': 30, 'cond_free': False}, + 'Fast': {'num_autoregressive_samples': 96, 'diffusion_iterations': 80}, + 'Standard': {'num_autoregressive_samples': 256, 'diffusion_iterations': 200}, + 'High Quality': {'num_autoregressive_samples': 256, 'diffusion_iterations': 400}, +} + +HISTORY_HEADERS = { + "Name": "", + "Samples": "num_autoregressive_samples", + "Iterations": "diffusion_iterations", + "Temp.": "temperature", + "Sampler": "diffusion_sampler", + "CVVP": "cvvp_weight", + "Top P": "top_p", + "Diff. Temp.": "diffusion_temperature", + "Len Pen": "length_penalty", + "Rep Pen": "repetition_penalty", + "Cond-Free K": "cond_free_k", + "Time": "time", + "Datetime": "datetime", + "Model": "model", + "Model Hash": "model_hash", +} + +# can't use *args OR **kwargs if I want to retain the ability to use progress +def generate_proxy( text, delimiter, emotion, @@ -30,8 +62,8 @@ def run_generation( voice, mic_audio, voice_latents_chunks, - seed, candidates, + seed, num_autoregressive_samples, diffusion_iterations, temperature, @@ -43,47 +75,20 @@ def run_generation( length_penalty, repetition_penalty, cond_free_k, - experimental_checkboxes, + experimentals, progress=gr.Progress(track_tqdm=True) ): - if not text: - raise gr.Error("Please provide text.") - if not voice: - raise gr.Error("Please provide a voice.") + kwargs = locals() try: - sample, outputs, stats = generate( - text=text, - delimiter=delimiter, - emotion=emotion, - prompt=prompt, - voice=voice, - mic_audio=mic_audio, - voice_latents_chunks=voice_latents_chunks, - seed=seed, - candidates=candidates, - num_autoregressive_samples=num_autoregressive_samples, - diffusion_iterations=diffusion_iterations, - temperature=temperature, - diffusion_sampler=diffusion_sampler, - breathing_room=breathing_room, - cvvp_weight=cvvp_weight, - top_p=top_p, - diffusion_temperature=diffusion_temperature, - length_penalty=length_penalty, - repetition_penalty=repetition_penalty, - cond_free_k=cond_free_k, - experimental_checkboxes=experimental_checkboxes, - progress=progress - ) + sample, outputs, stats = generate(**kwargs) except Exception as e: message = str(e) if message == "Kill signal detected": unload_tts() - raise gr.Error(message) + raise e - return ( outputs[0], gr.update(value=sample, visible=sample is not None), @@ -91,14 +96,8 @@ def run_generation( gr.update(value=stats, visible=True), ) + def update_presets(value): - PRESETS = { - 'Ultra Fast': {'num_autoregressive_samples': 16, 'diffusion_iterations': 30, 'cond_free': False}, - 'Fast': {'num_autoregressive_samples': 96, 'diffusion_iterations': 80}, - 'Standard': {'num_autoregressive_samples': 256, 'diffusion_iterations': 200}, - 'High Quality': {'num_autoregressive_samples': 256, 'diffusion_iterations': 400}, - } - if value in PRESETS: preset = PRESETS[value] return (gr.update(value=preset['num_autoregressive_samples']), gr.update(value=preset['diffusion_iterations'])) @@ -117,24 +116,6 @@ def get_training_configs(): def update_training_configs(): return gr.update(choices=get_training_list()) -history_headers = { - "Name": "", - "Samples": "num_autoregressive_samples", - "Iterations": "diffusion_iterations", - "Temp.": "temperature", - "Sampler": "diffusion_sampler", - "CVVP": "cvvp_weight", - "Top P": "top_p", - "Diff. Temp.": "diffusion_temperature", - "Len Pen": "length_penalty", - "Rep Pen": "repetition_penalty", - "Cond-Free K": "cond_free_k", - "Time": "time", - "Datetime": "datetime", - "Model": "model", - "Model Hash": "model_hash", -} - def history_view_results( voice ): results = [] files = [] @@ -148,10 +129,10 @@ def history_view_results( voice ): continue values = [] - for k in history_headers: + for k in HISTORY_HEADERS: v = file if k != "Name": - v = metadata[history_headers[k]] if history_headers[k] in metadata else '?' + v = metadata[HISTORY_HEADERS[k]] if HISTORY_HEADERS[k] in metadata else '?' values.append(v) @@ -193,181 +174,55 @@ def read_generate_settings_proxy(file, saveAs='.temp'): def prepare_dataset_proxy( voice, language, skip_existings, progress=gr.Progress(track_tqdm=True) ): return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, skip_existings=skip_existings, progress=progress ) -def optimize_training_settings_proxy( *args, **kwargs ): - tup = optimize_training_settings(*args, **kwargs) +def update_args_proxy( *args ): + kwargs = {} + keys = list(EXEC_SETTINGS.keys()) + for i in range(len(args)): + k = keys[i] + v = args[i] + kwargs[k] = v - return ( - gr.update(value=tup[0]), - gr.update(value=tup[1]), - gr.update(value=tup[2]), - gr.update(value=tup[3]), - gr.update(value=tup[4]), - gr.update(value=tup[5]), - gr.update(value=tup[6]), - gr.update(value=tup[7]), - gr.update(value=tup[8]), - "\n".join(tup[9]) - ) + update_args(**kwargs) +def optimize_training_settings_proxy( *args ): + kwargs = {} + keys = list(TRAINING_SETTINGS.keys()) + for i in range(len(args)): + k = keys[i] + v = args[i] + kwargs[k] = v + + settings, messages = optimize_training_settings(**kwargs) + output = list(settings.values()) + return output[:-1] + ["\n".join(messages)] def import_training_settings_proxy( voice ): - indir = f'./training/{voice}/' - outdir = f'./training/{voice}-finetune/' - - in_config_path = f"{indir}/train.yaml" - out_config_path = None - out_configs = [] - if os.path.isdir(outdir): - out_configs = sorted([d[:-5] for d in os.listdir(outdir) if d[-5:] == ".yaml" ]) - if len(out_configs) > 0: - out_config_path = f'{outdir}/{out_configs[-1]}.yaml' - - config_path = out_config_path if out_config_path else in_config_path - messages = [] - with open(config_path, 'r') as file: - config = yaml.safe_load(file) - messages.append(f"Importing from: {config_path}") - - dataset_path = f"./training/{voice}/train.txt" - with open(dataset_path, 'r', encoding="utf-8") as f: - lines = len(f.readlines()) - messages.append(f"Basing epoch size to {lines} lines") - - batch_size = config['datasets']['train']['batch_size'] - gradient_accumulation_size = config['train']['mega_batch_factor'] - - iterations = config['train']['niter'] - steps_per_iteration = int(lines / batch_size) - epochs = int(iterations / steps_per_iteration) - - - learning_rate = config['steps']['gpt_train']['optimizer_params']['lr'] - text_ce_lr_weight = config['steps']['gpt_train']['losses']['text_ce']['weight'] - learning_rate_schedule = [ int(x / steps_per_iteration) for x in config['train']['gen_lr_steps'] ] - - - print_rate = int(config['logger']['print_freq'] / steps_per_iteration) - save_rate = int(config['logger']['save_checkpoint_freq'] / steps_per_iteration) - validation_rate = int(config['train']['val_freq'] / steps_per_iteration) - - half_p = config['fp16'] - bnb = True - - statedir = f'{outdir}/training_state/' - resumes = [] - resume_path = None - source_model = get_halfp_model_path() if half_p else get_model_path('autoregressive.pth') - - if "pretrain_model_gpt" in config['path']: - source_model = config['path']['pretrain_model_gpt'] - elif "resume_state" in config['path']: - resume_path = config['path']['resume_state'] + injson = f'./training/{voice}/train.json' + statedir = f'./training/{voice}/training_state/' + with open(injson, 'r', encoding="utf-8") as f: + settings = json.loads(f.read()) if os.path.isdir(statedir): resumes = sorted([int(d[:-6]) for d in os.listdir(statedir) if d[-6:] == ".state" ]) - if len(resumes) > 0: - resume_path = f'{statedir}/{resumes[-1]}.state' - messages.append(f"Latest resume found: {resume_path}") + if len(resumes) > 0: + settings['resume_state'] = f'{statedir}/{resumes[-1]}.state' + messages.append(f"Found most recent training state: {settings['resume_state']}") + output = list(settings.values()) + messages.append(f"Imported training settings: {injson}") + return output[:-1] + ["\n".join(messages)] +def save_training_settings_proxy( *args ): + kwargs = {} + keys = list(TRAINING_SETTINGS.keys()) + for i in range(len(args)): + k = keys[i] + v = args[i] + kwargs[k] = v - if "ext" in config and "bitsandbytes" in config["ext"]: - bnb = config["ext"]["bitsandbytes"] - - workers = config['datasets']['train']['n_workers'] - - messages = "\n".join(messages) - - return ( - epochs, - learning_rate, - text_ce_lr_weight, - learning_rate_schedule, - batch_size, - gradient_accumulation_size, - print_rate, - save_rate, - validation_rate, - resume_path, - half_p, - bnb, - workers, - source_model, - messages - ) - - -def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, validation_rate, resume_path, half_p, bnb, workers, source_model, voice ): - name = f"{voice}-finetune" - dataset_name = f"{voice}-train" - dataset_path = f"./training/{voice}/train.txt" - validation_name = f"{voice}-val" - validation_path = f"./training/{voice}/validation.txt" - - with open(dataset_path, 'r', encoding="utf-8") as f: - lines = len(f.readlines()) - - messages = [] - - iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size) - messages.append(f"For {epochs} epochs with {lines} lines, iterating for {iterations} steps") - - print_rate = int(print_rate * iterations / epochs) - save_rate = int(save_rate * iterations / epochs) - validation_rate = int(validation_rate * iterations / epochs) - - validation_batch_size = int(batch_size / gradient_accumulation_size) - - if iterations % save_rate != 0: - adjustment = int(iterations / save_rate) * save_rate - messages.append(f"Iteration rate is not evenly divisible by save rate, adjusting: {iterations} => {adjustment}") - iterations = adjustment - - if not os.path.exists(validation_path): - validation_rate = iterations - validation_path = dataset_path - messages.append("Validation not found, disabling validation...") - else: - with open(validation_path, 'r', encoding="utf-8") as f: - validation_lines = len(f.readlines()) - - - if validation_lines < validation_batch_size: - validation_batch_size = validation_lines - messages.append(f"Batch size exceeds validation dataset size, clamping validation batch size to {validation_lines}") - - if not learning_rate_schedule: - learning_rate_schedule = EPOCH_SCHEDULE - elif isinstance(learning_rate_schedule,str): - learning_rate_schedule = json.loads(learning_rate_schedule) - - learning_rate_schedule = schedule_learning_rate( iterations / epochs, learning_rate_schedule ) - - messages.append(save_training_settings( - iterations=iterations, - batch_size=batch_size, - learning_rate=learning_rate, - text_ce_lr_weight=text_ce_lr_weight, - learning_rate_schedule=learning_rate_schedule, - gradient_accumulation_size=gradient_accumulation_size, - print_rate=print_rate, - save_rate=save_rate, - validation_rate=validation_rate, - name=name, - dataset_name=dataset_name, - dataset_path=dataset_path, - validation_name=validation_name, - validation_path=validation_path, - validation_batch_size=validation_batch_size, - output_name=f"{voice}/train.yaml", - resume_path=resume_path, - half_p=half_p, - bnb=bnb, - workers=workers, - source_model=source_model, - )) + settings, messages = save_training_settings(**kwargs) return "\n".join(messages) def update_voices(): @@ -406,60 +261,68 @@ def setup_gradio(): autoregressive_models = get_autoregressive_models() dataset_list = get_dataset_list() + GENERATE_SETTINGS_ARGS = list(inspect.signature(generate_proxy).parameters.keys())[:-1] + for i in range(len(GENERATE_SETTINGS_ARGS)): + arg = GENERATE_SETTINGS_ARGS[i] + GENERATE_SETTINGS[arg] = None + set_generate_settings_arg_order(GENERATE_SETTINGS_ARGS) + with gr.Blocks() as ui: with gr.Tab("Generate"): with gr.Row(): with gr.Column(): - text = gr.Textbox(lines=4, label="Input Prompt") + GENERATE_SETTINGS["text"] = gr.Textbox(lines=4, label="Input Prompt") with gr.Row(): with gr.Column(): - delimiter = gr.Textbox(lines=1, label="Line Delimiter", placeholder="\\n") + GENERATE_SETTINGS["delimiter"] = gr.Textbox(lines=1, label="Line Delimiter", placeholder="\\n") - emotion = gr.Radio( ["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom", "None"], value="None", label="Emotion", type="value", interactive=True ) - prompt = gr.Textbox(lines=1, label="Custom Emotion") - voice = gr.Dropdown(choices=voice_list_with_defaults, label="Voice", type="value", value=voice_list_with_defaults[0]) # it'd be very cash money if gradio was able to default to the first value in the list without this shit - mic_audio = gr.Audio( label="Microphone Source", source="microphone", type="filepath", visible=False ) - voice_latents_chunks = gr.Number(label="Voice Chunks", precision=0, value=0) + GENERATE_SETTINGS["emotion"] = gr.Radio( ["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom", "None"], value="None", label="Emotion", type="value", interactive=True ) + GENERATE_SETTINGS["prompt"] = gr.Textbox(lines=1, label="Custom Emotion", visible=False) + GENERATE_SETTINGS["voice"] = gr.Dropdown(choices=voice_list_with_defaults, label="Voice", type="value", value=voice_list_with_defaults[0]) # it'd be very cash money if gradio was able to default to the first value in the list without this shit + GENERATE_SETTINGS["mic_audio"] = gr.Audio( label="Microphone Source", source="microphone", type="filepath", visible=False ) + GENERATE_SETTINGS["voice_latents_chunks"] = gr.Number(label="Voice Chunks", precision=0, value=0) with gr.Row(): refresh_voices = gr.Button(value="Refresh Voice List") recompute_voice_latents = gr.Button(value="(Re)Compute Voice Latents") - voice.change( + GENERATE_SETTINGS["voice"].change( fn=update_baseline_for_latents_chunks, - inputs=voice, - outputs=voice_latents_chunks + inputs=GENERATE_SETTINGS["voice"], + outputs=GENERATE_SETTINGS["voice_latents_chunks"] ) - voice.change( + GENERATE_SETTINGS["voice"].change( fn=lambda value: gr.update(visible=value == "microphone"), - inputs=voice, - outputs=mic_audio, + inputs=GENERATE_SETTINGS["voice"], + outputs=GENERATE_SETTINGS["mic_audio"], ) with gr.Column(): - candidates = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates") - seed = gr.Number(value=0, precision=0, label="Seed") + GENERATE_SETTINGS["candidates"] = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates") + GENERATE_SETTINGS["seed"] = gr.Number(value=0, precision=0, label="Seed") preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value" ) - num_autoregressive_samples = gr.Slider(value=128, minimum=2, maximum=512, step=1, label="Samples") - diffusion_iterations = gr.Slider(value=128, minimum=0, maximum=512, step=1, label="Iterations") - temperature = gr.Slider(value=0.2, minimum=0, maximum=1, step=0.1, label="Temperature") + GENERATE_SETTINGS["num_autoregressive_samples"] = gr.Slider(value=128, minimum=2, maximum=512, step=1, label="Samples") + GENERATE_SETTINGS["diffusion_iterations"] = gr.Slider(value=128, minimum=0, maximum=512, step=1, label="Iterations") + + GENERATE_SETTINGS["temperature"] = gr.Slider(value=0.2, minimum=0, maximum=1, step=0.1, label="Temperature") + show_experimental_settings = gr.Checkbox(label="Show Experimental Settings") reset_generation_settings_button = gr.Button(value="Reset to Default") with gr.Column(visible=False) as col: experimental_column = col - experimental_checkboxes = gr.CheckboxGroup(["Half Precision", "Conditioning-Free"], value=["Conditioning-Free"], label="Experimental Flags") - breathing_room = gr.Slider(value=8, minimum=1, maximum=32, step=1, label="Pause Size") - diffusion_sampler = gr.Radio( + GENERATE_SETTINGS["experimentals"] = gr.CheckboxGroup(["Half Precision", "Conditioning-Free"], value=["Conditioning-Free"], label="Experimental Flags") + GENERATE_SETTINGS["breathing_room"] = gr.Slider(value=8, minimum=1, maximum=32, step=1, label="Pause Size") + GENERATE_SETTINGS["diffusion_sampler"] = gr.Radio( ["P", "DDIM"], # + ["K_Euler_A", "DPM++2M"], value="DDIM", label="Diffusion Samplers", type="value" ) - cvvp_weight = gr.Slider(value=0, minimum=0, maximum=1, label="CVVP Weight") - top_p = gr.Slider(value=0.8, minimum=0, maximum=1, label="Top P") - diffusion_temperature = gr.Slider(value=1.0, minimum=0, maximum=1, label="Diffusion Temperature") - length_penalty = gr.Slider(value=1.0, minimum=0, maximum=8, label="Length Penalty") - repetition_penalty = gr.Slider(value=2.0, minimum=0, maximum=8, label="Repetition Penalty") - cond_free_k = gr.Slider(value=2.0, minimum=0, maximum=4, label="Conditioning-Free K") + GENERATE_SETTINGS["cvvp_weight"] = gr.Slider(value=0, minimum=0, maximum=1, label="CVVP Weight") + GENERATE_SETTINGS["top_p"] = gr.Slider(value=0.8, minimum=0, maximum=1, label="Top P") + GENERATE_SETTINGS["diffusion_temperature"] = gr.Slider(value=1.0, minimum=0, maximum=1, label="Diffusion Temperature") + GENERATE_SETTINGS["length_penalty"] = gr.Slider(value=1.0, minimum=0, maximum=8, label="Length Penalty") + GENERATE_SETTINGS["repetition_penalty"] = gr.Slider(value=2.0, minimum=0, maximum=8, label="Repetition Penalty") + GENERATE_SETTINGS["cond_free_k"] = gr.Slider(value=2.0, minimum=0, maximum=4, label="Conditioning-Free K") with gr.Column(): with gr.Row(): submit = gr.Button(value="Generate") @@ -483,7 +346,7 @@ def setup_gradio(): with gr.Tab("History"): with gr.Row(): with gr.Column(): - history_info = gr.Dataframe(label="Results", headers=list(history_headers.keys())) + history_info = gr.Dataframe(label="Results", headers=list(HISTORY_HEADERS.keys())) with gr.Row(): with gr.Column(): history_voices = gr.Dropdown(choices=result_voices, label="Voice", type="value", value=result_voices[0] if len(result_voices) > 0 else "") @@ -521,51 +384,40 @@ def setup_gradio(): with gr.Tab("Generate Configuration"): with gr.Row(): with gr.Column(): - training_settings = [ - gr.Number(label="Epochs", value=500, precision=0), - ] + TRAINING_SETTINGS["epochs"] = gr.Number(label="Epochs", value=500, precision=0) with gr.Row(): with gr.Column(): - training_settings = training_settings + [ - gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6), - gr.Slider(label="Text_CE LR Ratio", value=0.01, minimum=0, maximum=1), - ] - training_settings = training_settings + [ - gr.Textbox(label="Learning Rate Schedule", placeholder=str(EPOCH_SCHEDULE)), - ] + TRAINING_SETTINGS["learning_rate"] = gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6) + TRAINING_SETTINGS["text_ce_lr_weight"] = gr.Slider(label="Text_CE LR Ratio", value=0.01, minimum=0, maximum=1) + + TRAINING_SETTINGS["learning_rate_schedule"] = gr.Textbox(label="Learning Rate Schedule", placeholder=str(EPOCH_SCHEDULE)) with gr.Row(): - training_settings = training_settings + [ - gr.Number(label="Batch Size", value=128, precision=0), - gr.Number(label="Gradient Accumulation Size", value=4, precision=0), - ] + TRAINING_SETTINGS["batch_size"] = gr.Number(label="Batch Size", value=128, precision=0) + TRAINING_SETTINGS["gradient_accumulation_size"] = gr.Number(label="Gradient Accumulation Size", value=4, precision=0) with gr.Row(): - training_settings = training_settings + [ - gr.Number(label="Print Frequency (in epochs)", value=5, precision=0), - gr.Number(label="Save Frequency (in epochs)", value=5, precision=0), - gr.Number(label="Validation Frequency (in epochs)", value=5, precision=0), - ] - training_settings = training_settings + [ - gr.Textbox(label="Resume State Path", placeholder="./training/${voice}-finetune/training_state/${last_state}.state"), - ] + TRAINING_SETTINGS["print_rate"] = gr.Number(label="Print Frequency (in epochs)", value=5, precision=0) + TRAINING_SETTINGS["save_rate"] = gr.Number(label="Save Frequency (in epochs)", value=5, precision=0) + TRAINING_SETTINGS["validation_rate"] = gr.Number(label="Validation Frequency (in epochs)", value=5, precision=0) with gr.Row(): - training_halfp = gr.Checkbox(label="Half Precision", value=args.training_default_halfp) - training_bnb = gr.Checkbox(label="BitsAndBytes", value=args.training_default_bnb) + TRAINING_SETTINGS["half_p"] = gr.Checkbox(label="Half Precision", value=args.training_default_halfp) + TRAINING_SETTINGS["bitsandbytes"] = gr.Checkbox(label="BitsAndBytes", value=args.training_default_bnb) - training_workers = gr.Number(label="Worker Processes", value=2, precision=0) - - source_model = gr.Dropdown( choices=autoregressive_models, label="Source Model", type="value", value=autoregressive_models[0] ) - dataset_list_dropdown = gr.Dropdown( choices=dataset_list, label="Dataset", type="value", value=dataset_list[0] if len(dataset_list) else "" ) - training_settings = training_settings + [ training_halfp, training_bnb, training_workers, source_model, dataset_list_dropdown ] + TRAINING_SETTINGS["workers"] = gr.Number(label="Worker Processes", value=2, precision=0) + TRAINING_SETTINGS["gpus"] = gr.Number(label="GPUs", value=get_device_count(), precision=0) + TRAINING_SETTINGS["source_model"] = gr.Dropdown( choices=autoregressive_models, label="Source Model", type="value", value=autoregressive_models[0] ) + TRAINING_SETTINGS["resume_state"] = gr.Textbox(label="Resume State Path", placeholder="./training/${voice}/training_state/${last_state}.state") + + TRAINING_SETTINGS["voice"] = gr.Dropdown( choices=dataset_list, label="Dataset", type="value", value=dataset_list[0] if len(dataset_list) else "" ) with gr.Row(): - refresh_dataset_list = gr.Button(value="Refresh Dataset List") - import_dataset_button = gr.Button(value="Reuse/Import Dataset") + training_refresh_dataset = gr.Button(value="Refresh Dataset List") + training_import_settings = gr.Button(value="Reuse/Import Dataset") with gr.Column(): - save_yaml_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8) + training_configuration_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8) with gr.Row(): - optimize_yaml_button = gr.Button(value="Validate Training Configuration") - save_yaml_button = gr.Button(value="Save Training Configuration") + training_optimize_configuration = gr.Button(value="Validate Training Configuration") + training_save_configuration = gr.Button(value="Save Training Configuration") with gr.Tab("Run Training"): with gr.Row(): with gr.Column(): @@ -588,9 +440,7 @@ def setup_gradio(): training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8) verbose_training = gr.Checkbox(label="Verbose Console Output", value=True) - with gr.Row(): - training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1) - training_gpu_count = gr.Number(label="GPUs", value=get_device_count()) + training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1) with gr.Row(): start_training_button = gr.Button(value="Train") stop_training_button = gr.Button(value="Stop") @@ -599,43 +449,40 @@ def setup_gradio(): with gr.Row(): exec_inputs = [] with gr.Column(): - exec_inputs = exec_inputs + [ - gr.Textbox(label="Listen", value=args.listen, placeholder="127.0.0.1:7860/"), - gr.Checkbox(label="Public Share Gradio", value=args.share), - gr.Checkbox(label="Check For Updates", value=args.check_for_updates), - gr.Checkbox(label="Only Load Models Locally", value=args.models_from_local_only), - gr.Checkbox(label="Low VRAM", value=args.low_vram), - gr.Checkbox(label="Embed Output Metadata", value=args.embed_output_metadata), - gr.Checkbox(label="Slimmer Computed Latents", value=args.latents_lean_and_mean), - gr.Checkbox(label="Use Voice Fixer on Generated Output", value=args.voice_fixer), - gr.Checkbox(label="Use CUDA for Voice Fixer", value=args.voice_fixer_use_cuda), - gr.Checkbox(label="Force CPU for Conditioning Latents", value=args.force_cpu_for_conditioning_latents), - gr.Checkbox(label="Do Not Load TTS On Startup", value=args.defer_tts_load), - gr.Checkbox(label="Delete Non-Final Output", value=args.prune_nonfinal_outputs), - gr.Textbox(label="Device Override", value=args.device_override), - ] + EXEC_SETTINGS['listen'] = gr.Textbox(label="Listen", value=args.listen, placeholder="127.0.0.1:7860/") + EXEC_SETTINGS['share'] = gr.Checkbox(label="Public Share Gradio", value=args.share) + EXEC_SETTINGS['check_for_updates'] = gr.Checkbox(label="Check For Updates", value=args.check_for_updates) + EXEC_SETTINGS['models_from_local_only'] = gr.Checkbox(label="Only Load Models Locally", value=args.models_from_local_only) + EXEC_SETTINGS['low_vram'] = gr.Checkbox(label="Low VRAM", value=args.low_vram) + EXEC_SETTINGS['embed_output_metadata'] = gr.Checkbox(label="Embed Output Metadata", value=args.embed_output_metadata) + EXEC_SETTINGS['latents_lean_and_mean'] = gr.Checkbox(label="Slimmer Computed Latents", value=args.latents_lean_and_mean) + EXEC_SETTINGS['voice_fixer'] = gr.Checkbox(label="Use Voice Fixer on Generated Output", value=args.voice_fixer) + EXEC_SETTINGS['voice_fixer_use_cuda'] = gr.Checkbox(label="Use CUDA for Voice Fixer", value=args.voice_fixer_use_cuda) + EXEC_SETTINGS['force_cpu_for_conditioning_latents'] = gr.Checkbox(label="Force CPU for Conditioning Latents", value=args.force_cpu_for_conditioning_latents) + EXEC_SETTINGS['defer_tts_load'] = gr.Checkbox(label="Do Not Load TTS On Startup", value=args.defer_tts_load) + EXEC_SETTINGS['prune_nonfinal_outputs'] = gr.Checkbox(label="Delete Non-Final Output", value=args.prune_nonfinal_outputs) + EXEC_SETTINGS['device_override'] = gr.Textbox(label="Device Override", value=args.device_override) with gr.Column(): - exec_inputs = exec_inputs + [ - gr.Number(label="Sample Batch Size", precision=0, value=args.sample_batch_size), - gr.Number(label="Gradio Concurrency Count", precision=0, value=args.concurrency_count), - gr.Number(label="Auto-Calculate Voice Chunk Duration (in seconds)", precision=0, value=args.autocalculate_voice_chunk_duration_size), - gr.Slider(label="Output Volume", minimum=0, maximum=2, value=args.output_volume), - ] + EXEC_SETTINGS['sample_batch_size'] = gr.Number(label="Sample Batch Size", precision=0, value=args.sample_batch_size) + EXEC_SETTINGS['concurrency_count'] = gr.Number(label="Gradio Concurrency Count", precision=0, value=args.concurrency_count) + EXEC_SETTINGS['autocalculate_voice_chunk_duration_size'] = gr.Number(label="Auto-Calculate Voice Chunk Duration (in seconds)", precision=0, value=args.autocalculate_voice_chunk_duration_size) + EXEC_SETTINGS['output_volume'] = gr.Slider(label="Output Volume", minimum=0, maximum=2, value=args.output_volume) - autoregressive_model_dropdown = gr.Dropdown(choices=autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else autoregressive_models[0]) + EXEC_SETTINGS['autoregressive_model'] = gr.Dropdown(choices=autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else autoregressive_models[0]) - vocoder_models = gr.Dropdown(VOCODERS, label="Vocoder", value=args.vocoder_model if args.vocoder_model else VOCODERS[-1]) - whisper_backend = gr.Dropdown(WHISPER_BACKENDS, label="Whisper Backends", value=args.whisper_backend) - whisper_model_dropdown = gr.Dropdown(WHISPER_MODELS, label="Whisper Model", value=args.whisper_model) + EXEC_SETTINGS['vocoder_model'] = gr.Dropdown(VOCODERS, label="Vocoder", value=args.vocoder_model if args.vocoder_model else VOCODERS[-1]) + EXEC_SETTINGS['whisper_backend'] = gr.Dropdown(WHISPER_BACKENDS, label="Whisper Backends", value=args.whisper_backend) + EXEC_SETTINGS['whisper_model'] = gr.Dropdown(WHISPER_MODELS, label="Whisper Model", value=args.whisper_model) - exec_inputs = exec_inputs + [ autoregressive_model_dropdown, vocoder_models, whisper_backend, whisper_model_dropdown, training_halfp, training_bnb ] + EXEC_SETTINGS['training_default_halfp'] = TRAINING_SETTINGS['half_p'] + EXEC_SETTINGS['training_default_bnb'] = TRAINING_SETTINGS['bitsandbytes'] with gr.Row(): autoregressive_models_update_button = gr.Button(value="Refresh Model List") gr.Button(value="Check for Updates").click(check_for_updates) gr.Button(value="(Re)Load TTS").click( reload_tts, - inputs=autoregressive_model_dropdown, + inputs=EXEC_SETTINGS['autoregressive_model'], outputs=None ) # kill_button = gr.Button(value="Close UI") @@ -648,49 +495,26 @@ def setup_gradio(): autoregressive_models_update_button.click( update_model_list_proxy, - inputs=autoregressive_model_dropdown, - outputs=autoregressive_model_dropdown, + inputs=EXEC_SETTINGS['autoregressive_model'], + outputs=EXEC_SETTINGS['autoregressive_model'], ) - for i in exec_inputs: - i.change( fn=update_args, inputs=exec_inputs ) + exec_inputs = list(EXEC_SETTINGS.values()) + for k in EXEC_SETTINGS: + EXEC_SETTINGS[k].change( fn=update_args_proxy, inputs=exec_inputs ) - autoregressive_model_dropdown.change( + EXEC_SETTINGS['autoregressive_model'].change( fn=update_autoregressive_model, - inputs=autoregressive_model_dropdown, + inputs=EXEC_SETTINGS['autoregressive_model'], outputs=None ) - vocoder_models.change( + EXEC_SETTINGS['vocoder_model'].change( fn=update_vocoder_model, - inputs=vocoder_models, + inputs=EXEC_SETTINGS['vocoder_model'], outputs=None ) - input_settings = [ - text, - delimiter, - emotion, - prompt, - voice, - mic_audio, - voice_latents_chunks, - seed, - candidates, - num_autoregressive_samples, - diffusion_iterations, - temperature, - diffusion_sampler, - breathing_room, - cvvp_weight, - top_p, - diffusion_temperature, - length_penalty, - repetition_penalty, - cond_free_k, - experimental_checkboxes, - ] - history_voices.change( fn=history_view_results, inputs=history_voices, @@ -734,45 +558,46 @@ def setup_gradio(): preset.change(fn=update_presets, inputs=preset, outputs=[ - num_autoregressive_samples, - diffusion_iterations, + GENERATE_SETTINGS['num_autoregressive_samples'], + GENERATE_SETTINGS['diffusion_iterations'], ], ) recompute_voice_latents.click(compute_latents_proxy, inputs=[ - voice, - voice_latents_chunks, + GENERATE_SETTINGS['voice'], + GENERATE_SETTINGS['voice_latents_chunks'], ], - outputs=voice, + outputs=GENERATE_SETTINGS['voice'], ) - emotion.change( + GENERATE_SETTINGS['emotion'].change( fn=lambda value: gr.update(visible=value == "Custom"), - inputs=emotion, - outputs=prompt + inputs=GENERATE_SETTINGS['emotion'], + outputs=GENERATE_SETTINGS['prompt'] ) - mic_audio.change(fn=lambda value: gr.update(value="microphone"), - inputs=mic_audio, - outputs=voice + GENERATE_SETTINGS['mic_audio'].change(fn=lambda value: gr.update(value="microphone"), + inputs=GENERATE_SETTINGS['mic_audio'], + outputs=GENERATE_SETTINGS['voice'] ) refresh_voices.click(update_voices, inputs=None, outputs=[ - voice, + GENERATE_SETTINGS['voice'], dataset_settings[0], history_voices ] ) + generate_settings = list(GENERATE_SETTINGS.values()) submit.click( lambda: (gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)), outputs=[source_sample, candidates_list, generation_results], ) - submit_event = submit.click(run_generation, - inputs=input_settings, + submit_event = submit.click(generate_proxy, + inputs=generate_settings, outputs=[output_audio, source_sample, candidates_list, generation_results], api_name="generate", ) @@ -780,13 +605,13 @@ def setup_gradio(): copy_button.click(import_generate_settings, inputs=audio_in, # JSON elements cannot be used as inputs - outputs=input_settings + outputs=generate_settings ) reset_generation_settings_button.click( fn=reset_generation_settings, inputs=None, - outputs=input_settings + outputs=generate_settings ) history_copy_settings_button.click(history_copy_settings, @@ -794,7 +619,7 @@ def setup_gradio(): history_voices, history_results_list, ], - outputs=input_settings + outputs=generate_settings ) refresh_configs.click( @@ -806,7 +631,6 @@ def setup_gradio(): inputs=[ training_configs, verbose_training, - training_gpu_count, training_keep_x_past_datasets, ], outputs=[ @@ -855,38 +679,28 @@ def setup_gradio(): ], outputs=prepare_dataset_output #console_output ) - refresh_dataset_list.click( + + training_refresh_dataset.click( lambda: gr.update(choices=get_dataset_list()), inputs=None, - outputs=dataset_list_dropdown, + outputs=TRAINING_SETTINGS["voice"], ) - optimize_yaml_button.click(optimize_training_settings_proxy, + training_settings = list(TRAINING_SETTINGS.values()) + training_optimize_configuration.click(optimize_training_settings_proxy, inputs=training_settings, - outputs=training_settings[1:10] + [save_yaml_output] #console_output + outputs=training_settings[:-1] + [training_configuration_output] #console_output ) - import_dataset_button.click(import_training_settings_proxy, - inputs=dataset_list_dropdown, - outputs=training_settings[:14] + [save_yaml_output] #console_output + training_import_settings.click(import_training_settings_proxy, + inputs=TRAINING_SETTINGS['voice'], + outputs=training_settings[:-1] + [training_configuration_output] #console_output ) - save_yaml_button.click(save_training_settings_proxy, + training_save_configuration.click(save_training_settings_proxy, inputs=training_settings, - outputs=save_yaml_output #console_output + outputs=training_configuration_output #console_output ) - """ - def kill_process(): - ui.close() - exit() - - kill_button.click( - kill_process, - inputs=None, - outputs=None - ) - """ - if os.path.isfile('./config/generate.json'): - ui.load(import_generate_settings, inputs=None, outputs=input_settings) + ui.load(import_generate_settings, inputs=None, outputs=generate_settings) if args.check_for_updates: ui.load(check_for_updates)