From f8249aa826afe94c6a65ab5002950fc125c4a68a Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 17 Feb 2023 03:05:27 +0000
Subject: [PATCH] tab to generate the training YAML

---
 README.md               |  21 +++
 src/train.py            |  37 ++++
 src/utils.py            | 252 +++++++++++++++++++++++++-
 src/webui.py            | 389 ++++++++++------------------------------
 training/.template.yaml | 144 +++++++++++++++
 5 files changed, 545 insertions(+), 298 deletions(-)
 create mode 100755 src/train.py
 create mode 100755 training/.template.yaml

diff --git a/README.md b/README.md
index f93afab..31b68ec 100755
--- a/README.md
+++ b/README.md
@@ -221,6 +221,27 @@ If you want to reuse its generation settings, simply click `Copy Settings`.
 
 To import a voice, click `Import Voice`. Remember to click `Refresh Voice List` in the `Generate` panel afterwards, if it's a new voice.
 
+### Training
+
+This tab will contain a collection of sub-tabs pertaining to training.
+
+#### Configuration
+
+This will generate the YAML necessary to feed into training. For now, you can set:
+* `Batch Size`: size of the batches used for training; larger batches = faster training, at the cost of higher VRAM usage. Setting this to 1 will lead to problems.
+* `Learning Rate`: how large the updates made during training are; lower values = better results over the long term, while higher values will fry a model fast. For fine-tuning, the default *should* be fine, but in the future, a learning rate scheduler would be better (start with a higher learning rate, then step it down over enough steps/epochs).
+* `Print Frequency`: how often (in steps) to print training progress (I assume).
+* `Save Frequency`: how often (in steps) to save checkpoints.
+* `Training Name`: name to save the configuration under; the training script also uses it to name its output folder.
+* `Dataset Name`: **!**TODO**!**: fill
+* `Dataset Path`: path to the input training text file. For LJSpeech-esque datasets, this points to a text file formatted like:
+```
+wavs/LJ001-0001.wav|Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition|Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition
+wavs/LJ001-0002.wav|in being comparatively modern.|in being comparatively modern.
+```
+* `Validation Name`: **!**TODO**!**: fill
+* `Validation Path`: path to the validation set, formatted like the training dataset. I'm not entirely sure what best to use for this, so explicitly for testing, I just copied the training dataset text.
+
 ### Settings
 
 This tab (should) hold a bunch of other settings, from tunables that shouldn't be tampered with, to settings pertaining to the web UI itself.
diff --git a/src/train.py b/src/train.py
new file mode 100755
index 0000000..93bc7a6
--- /dev/null
+++ b/src/train.py
@@ -0,0 +1,37 @@
+import os
+import torch
+import argparse
+
+from ..dlas.codes import *
+from ..dlas.codes.utils import util, options as option
+
+parser = argparse.ArgumentParser()
+parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml')
+parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
+args = parser.parse_args()
+opt = option.parse(args.opt, is_train=True)
+if args.launcher != 'none':
+	# export CUDA_VISIBLE_DEVICES for running in distributed mode.
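+	# (setting CUDA_VISIBLE_DEVICES must happen before any CUDA context is created, i.e. before Trainer() is constructed below)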
+	if 'gpu_ids' in opt.keys():
+		gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
+		os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
+		print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
+trainer = Trainer()
+
+#### distributed training settings
+if args.launcher == 'none': # disabled distributed training
+	opt['dist'] = False
+	trainer.rank = -1
+	if len(opt['gpu_ids']) == 1:
+		torch.cuda.set_device(opt['gpu_ids'][0])
+	print('Disabled distributed training.')
+else:
+	opt['dist'] = True
+	init_dist('nccl')
+	trainer.world_size = torch.distributed.get_world_size()
+	trainer.rank = torch.distributed.get_rank()
+	torch.cuda.set_device(torch.distributed.get_rank())
+
+trainer.init(args.opt, opt, args.launcher)
+trainer.do_training()
\ No newline at end of file
diff --git a/src/utils.py b/src/utils.py
index 205a87b..47ab7dd 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -138,7 +138,7 @@ def generate(
 	try:
 		tts
 	except NameError:
-		raise gr.Error("TTS is still initializing...")
+		raise Exception("TTS is still initializing...")
 
 	if voice != "microphone":
 		voices = [voice]
@@ -147,7 +147,7 @@
 
 	if voice == "microphone":
 		if mic_audio is None:
-			raise gr.Error("Please provide audio from mic when choosing `microphone` as a voice input")
+			raise Exception("Please provide audio from mic when choosing `microphone` as a voice input")
 		mic = load_audio(mic_audio, tts.input_sample_rate)
 		voice_samples, conditioning_latents = [mic], None
 	elif voice == "random":
@@ -431,4 +431,250 @@ def setup_tortoise(restart=False):
 	print("Initializating TorToiSe...")
 	tts = TextToSpeech(minor_optimizations=not args.low_vram)
 	print("TorToiSe initialized, ready for generation.")
-	return tts
\ No newline at end of file
+	return tts
+
+def save_training_settings( batch_size=None, learning_rate=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None ):
+	settings = {
+		"batch_size": batch_size if batch_size else 128,
+		"learning_rate": learning_rate if learning_rate else 1e-5,
+		"print_rate": print_rate if print_rate else 50,
+		"save_rate": save_rate if save_rate else 50,
+		"name": name if name else "finetune",
+		"dataset_name": dataset_name if dataset_name else "finetune",
+		"dataset_path": dataset_path if dataset_path else "./experiments/finetune/train.txt",
+		"validation_name": validation_name if validation_name else "finetune",
+		"validation_path": validation_path if validation_path else "./experiments/finetune/val.txt",
+	}
+
+	with open(f'./training/.template.yaml', 'r', encoding="utf-8") as f:
+		yaml = f.read()
+
+	for k in settings:
+		print(f"${{{k}}} => {settings[k]}")
+		yaml = yaml.replace(f"${{{k}}}", str(settings[k]))
+
+	with open(f'./training/{settings["name"]}.yaml', 'w', encoding="utf-8") as f:
+		f.write(yaml)
+
+def reset_generation_settings():
+	with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
+		f.write(json.dumps({}, indent='\t') )
+	return import_generate_settings()
+
+def import_voice(file, saveAs = None):
+	global args
+
+	j, latents = read_generate_settings(file, read_latents=True)
+
+	if j is not None and saveAs is None:
+		saveAs = j['voice']
+	if saveAs is None or saveAs == "":
+		raise Exception("Specify a voice name")
+
+	outdir = f'{get_voice_dir()}/{saveAs}/'
+	os.makedirs(outdir, exist_ok=True)
+	if latents:
+		with open(f'{outdir}/cond_latents.pth', 'wb') as f:
+			f.write(latents)
+		latents = f'{outdir}/cond_latents.pth'
+		print(f"Imported latents to {latents}")
+	else:
+		filename = file.name
+		if 
filename[-4:] != ".wav": + raise Exception("Please convert to a WAV first") + + path = f"{outdir}/{os.path.basename(filename)}" + waveform, sampling_rate = torchaudio.load(filename) + + if args.voice_fixer: + # resample to best bandwidth since voicefixer will do it anyways through librosa + if sampling_rate != 44100: + print(f"Resampling imported voice sample: {path}") + resampler = torchaudio.transforms.Resample( + sampling_rate, + 44100, + lowpass_filter_width=16, + rolloff=0.85, + resampling_method="kaiser_window", + beta=8.555504641634386, + ) + waveform = resampler(waveform) + sampling_rate = 44100 + + torchaudio.save(path, waveform, sampling_rate) + + print(f"Running 'voicefixer' on voice sample: {path}") + voicefixer.restore( + input = path, + output = path, + cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda, + #mode=mode, + ) + else: + torchaudio.save(path, waveform, sampling_rate) + + + print(f"Imported voice to {path}") + + +def import_generate_settings(file="./config/generate.json"): + settings, _ = read_generate_settings(file, read_latents=False) + + if settings is None: + return None + + return ( + None if 'text' not in settings else settings['text'], + None if 'delimiter' not in settings else settings['delimiter'], + None if 'emotion' not in settings else settings['emotion'], + None if 'prompt' not in settings else settings['prompt'], + None if 'voice' not in settings else settings['voice'], + None, + None, + None if 'seed' not in settings else settings['seed'], + None if 'candidates' not in settings else settings['candidates'], + None if 'num_autoregressive_samples' not in settings else settings['num_autoregressive_samples'], + None if 'diffusion_iterations' not in settings else settings['diffusion_iterations'], + 0.8 if 'temperature' not in settings else settings['temperature'], + "DDIM" if 'diffusion_sampler' not in settings else settings['diffusion_sampler'], + 8 if 'breathing_room' not in settings else settings['breathing_room'], + 0.0 if 'cvvp_weight' not in settings else settings['cvvp_weight'], + 0.8 if 'top_p' not in settings else settings['top_p'], + 1.0 if 'diffusion_temperature' not in settings else settings['diffusion_temperature'], + 1.0 if 'length_penalty' not in settings else settings['length_penalty'], + 2.0 if 'repetition_penalty' not in settings else settings['repetition_penalty'], + 2.0 if 'cond_free_k' not in settings else settings['cond_free_k'], + None if 'experimentals' not in settings else settings['experimentals'], + ) + +def curl(url): + try: + req = urllib.request.Request(url, headers={'User-Agent': 'Python'}) + conn = urllib.request.urlopen(req) + data = conn.read() + data = data.decode() + data = json.loads(data) + conn.close() + return data + except Exception as e: + print(e) + return None + +def check_for_updates(): + if not os.path.isfile('./.git/FETCH_HEAD'): + print("Cannot check for updates: not from a git repo") + return False + + with open(f'./.git/FETCH_HEAD', 'r', encoding="utf-8") as f: + head = f.read() + + match = re.findall(r"^([a-f0-9]+).+?https:\/\/(.+?)\/(.+?)\/(.+?)\n", head) + if match is None or len(match) == 0: + print("Cannot check for updates: cannot parse FETCH_HEAD") + return False + + match = match[0] + + local = match[0] + host = match[1] + owner = match[2] + repo = match[3] + + res = curl(f"https://{host}/api/v1/repos/{owner}/{repo}/branches/") #this only works for gitea instances + + if res is None or len(res) == 0: + print("Cannot check for updates: cannot fetch from remote") + return False + + remote = 
res[0]["commit"]["id"] + + if remote != local: + print(f"New version found: {local[:8]} => {remote[:8]}") + return True + + return False + +def reload_tts(): + global tts + del tts + tts = setup_tortoise(restart=True) + +def cancel_generate(): + tortoise.api.STOP_SIGNAL = True + +def get_voice_list(dir=get_voice_dir()): + os.makedirs(dir, exist_ok=True) + return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ]) + ["microphone", "random"] + +def export_exec_settings( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume ): + global args + + args.listen = listen + args.share = share + args.check_for_updates = check_for_updates + args.models_from_local_only = models_from_local_only + args.low_vram = low_vram + args.force_cpu_for_conditioning_latents = force_cpu_for_conditioning_latents + args.device_override = device_override + args.sample_batch_size = sample_batch_size + args.embed_output_metadata = embed_output_metadata + args.latents_lean_and_mean = latents_lean_and_mean + args.voice_fixer = voice_fixer + args.voice_fixer_use_cuda = voice_fixer_use_cuda + args.concurrency_count = concurrency_count + args.output_sample_rate = output_sample_rate + args.output_volume = output_volume + + settings = { + 'listen': None if args.listen else args.listen, + 'share': args.share, + 'low-vram':args.low_vram, + 'check-for-updates':args.check_for_updates, + 'models-from-local-only':args.models_from_local_only, + 'force-cpu-for-conditioning-latents': args.force_cpu_for_conditioning_latents, + 'device-override': args.device_override, + 'sample-batch-size': args.sample_batch_size, + 'embed-output-metadata': args.embed_output_metadata, + 'latents-lean-and-mean': args.latents_lean_and_mean, + 'voice-fixer': args.voice_fixer, + 'voice-fixer-use-cuda': args.voice_fixer_use_cuda, + 'concurrency-count': args.concurrency_count, + 'output-sample-rate': args.output_sample_rate, + 'output-volume': args.output_volume, + } + + with open(f'./config/exec.json', 'w', encoding="utf-8") as f: + f.write(json.dumps(settings, indent='\t') ) + +def read_generate_settings(file, read_latents=True, read_json=True): + j = None + latents = None + + if file is not None: + if hasattr(file, 'name'): + file = file.name + + if file[-4:] == ".wav": + metadata = music_tag.load_file(file) + if 'lyrics' in metadata: + j = json.loads(str(metadata['lyrics'])) + elif file[-5:] == ".json": + with open(file, 'r') as f: + j = json.load(f) + + if j is None: + print("No metadata found in audio file to read") + else: + if 'latents' in j: + if read_latents: + latents = base64.b64decode(j['latents']) + del j['latents'] + + + if "time" in j: + j["time"] = "{:.3f}".format(j["time"]) + + return ( + j, + latents, + ) \ No newline at end of file diff --git a/src/webui.py b/src/webui.py index 820223b..fb8ce58 100755 --- a/src/webui.py +++ b/src/webui.py @@ -21,6 +21,71 @@ from utils import * args = setup_args() +def run_generation( + text, + delimiter, + emotion, + prompt, + voice, + mic_audio, + voice_latents_chunks, + seed, + candidates, + num_autoregressive_samples, + diffusion_iterations, + temperature, + diffusion_sampler, + breathing_room, + cvvp_weight, + top_p, + diffusion_temperature, + length_penalty, + repetition_penalty, + cond_free_k, + 
experimental_checkboxes, + progress=gr.Progress(track_tqdm=True) +): + try: + sample, outputs, stats = generate( + text, + delimiter, + emotion, + prompt, + voice, + mic_audio, + voice_latents_chunks, + seed, + candidates, + num_autoregressive_samples, + diffusion_iterations, + temperature, + diffusion_sampler, + breathing_room, + cvvp_weight, + top_p, + diffusion_temperature, + length_penalty, + repetition_penalty, + cond_free_k, + experimental_checkboxes, + progress + ) + except Exception as e: + message = str(e) + if message == "Kill signal detected": + reload_tts() + + raise gr.Error(message) + + + return ( + outputs[0], + gr.update(value=sample, visible=sample is not None), + gr.update(choices=outputs, value=outputs[0], visible=len(outputs) > 1, interactive=True), + gr.update(visible=len(outputs) > 1), + gr.update(value=stats, visible=True), + ) + def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)): global tts global args @@ -58,230 +123,6 @@ def update_presets(value): else: return (gr.update(), gr.update()) -def read_generate_settings(file, read_latents=True, read_json=True): - j = None - latents = None - - if file is not None: - if hasattr(file, 'name'): - file = file.name - - if file[-4:] == ".wav": - metadata = music_tag.load_file(file) - if 'lyrics' in metadata: - j = json.loads(str(metadata['lyrics'])) - elif file[-5:] == ".json": - with open(file, 'r') as f: - j = json.load(f) - - if j is None: - gr.Error("No metadata found in audio file to read") - else: - if 'latents' in j: - if read_latents: - latents = base64.b64decode(j['latents']) - del j['latents'] - - - if "time" in j: - j["time"] = "{:.3f}".format(j["time"]) - - return ( - j, - latents, - ) - -def import_voice(file, saveAs = None): - global args - - j, latents = read_generate_settings(file, read_latents=True) - - if j is not None and saveAs is None: - saveAs = j['voice'] - if saveAs is None or saveAs == "": - raise gr.Error("Specify a voice name") - - outdir = f'{get_voice_dir()}/{saveAs}/' - os.makedirs(outdir, exist_ok=True) - if latents: - with open(f'{outdir}/cond_latents.pth', 'wb') as f: - f.write(latents) - latents = f'{outdir}/cond_latents.pth' - print(f"Imported latents to {latents}") - else: - filename = file.name - if filename[-4:] != ".wav": - raise gr.Error("Please convert to a WAV first") - - path = f"{outdir}/{os.path.basename(filename)}" - waveform, sampling_rate = torchaudio.load(filename) - - if args.voice_fixer: - # resample to best bandwidth since voicefixer will do it anyways through librosa - if sampling_rate != 44100: - print(f"Resampling imported voice sample: {path}") - resampler = torchaudio.transforms.Resample( - sampling_rate, - 44100, - lowpass_filter_width=16, - rolloff=0.85, - resampling_method="kaiser_window", - beta=8.555504641634386, - ) - waveform = resampler(waveform) - sampling_rate = 44100 - - torchaudio.save(path, waveform, sampling_rate) - - print(f"Running 'voicefixer' on voice sample: {path}") - voicefixer.restore( - input = path, - output = path, - cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda, - #mode=mode, - ) - else: - torchaudio.save(path, waveform, sampling_rate) - - - print(f"Imported voice to {path}") - - -def import_generate_settings(file="./config/generate.json"): - settings, _ = read_generate_settings(file, read_latents=False) - - if settings is None: - return None - - return ( - None if 'text' not in settings else settings['text'], - None if 'delimiter' not in settings else settings['delimiter'], - None if 
'emotion' not in settings else settings['emotion'], - None if 'prompt' not in settings else settings['prompt'], - None if 'voice' not in settings else settings['voice'], - None, - None, - None if 'seed' not in settings else settings['seed'], - None if 'candidates' not in settings else settings['candidates'], - None if 'num_autoregressive_samples' not in settings else settings['num_autoregressive_samples'], - None if 'diffusion_iterations' not in settings else settings['diffusion_iterations'], - 0.8 if 'temperature' not in settings else settings['temperature'], - "DDIM" if 'diffusion_sampler' not in settings else settings['diffusion_sampler'], - 8 if 'breathing_room' not in settings else settings['breathing_room'], - 0.0 if 'cvvp_weight' not in settings else settings['cvvp_weight'], - 0.8 if 'top_p' not in settings else settings['top_p'], - 1.0 if 'diffusion_temperature' not in settings else settings['diffusion_temperature'], - 1.0 if 'length_penalty' not in settings else settings['length_penalty'], - 2.0 if 'repetition_penalty' not in settings else settings['repetition_penalty'], - 2.0 if 'cond_free_k' not in settings else settings['cond_free_k'], - None if 'experimentals' not in settings else settings['experimentals'], - ) - -def curl(url): - try: - req = urllib.request.Request(url, headers={'User-Agent': 'Python'}) - conn = urllib.request.urlopen(req) - data = conn.read() - data = data.decode() - data = json.loads(data) - conn.close() - return data - except Exception as e: - print(e) - return None - -def check_for_updates(): - if not os.path.isfile('./.git/FETCH_HEAD'): - print("Cannot check for updates: not from a git repo") - return False - - with open(f'./.git/FETCH_HEAD', 'r', encoding="utf-8") as f: - head = f.read() - - match = re.findall(r"^([a-f0-9]+).+?https:\/\/(.+?)\/(.+?)\/(.+?)\n", head) - if match is None or len(match) == 0: - print("Cannot check for updates: cannot parse FETCH_HEAD") - return False - - match = match[0] - - local = match[0] - host = match[1] - owner = match[2] - repo = match[3] - - res = curl(f"https://{host}/api/v1/repos/{owner}/{repo}/branches/") #this only works for gitea instances - - if res is None or len(res) == 0: - print("Cannot check for updates: cannot fetch from remote") - return False - - remote = res[0]["commit"]["id"] - - if remote != local: - print(f"New version found: {local[:8]} => {remote[:8]}") - return True - - return False - -def reload_tts(): - global tts - del tts - tts = setup_tortoise(restart=True) - -def cancel_generate(): - tortoise.api.STOP_SIGNAL = True - -def get_voice_list(dir=get_voice_dir()): - os.makedirs(dir, exist_ok=True) - return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ]) + ["microphone", "random"] - -def update_voices(): - return ( - gr.Dropdown.update(choices=get_voice_list()), - gr.Dropdown.update(choices=get_voice_list("./results/")), - ) - -def export_exec_settings( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume ): - global args - - args.listen = listen - args.share = share - args.check_for_updates = check_for_updates - args.models_from_local_only = models_from_local_only - args.low_vram = low_vram - args.force_cpu_for_conditioning_latents = force_cpu_for_conditioning_latents - args.device_override = 
device_override - args.sample_batch_size = sample_batch_size - args.embed_output_metadata = embed_output_metadata - args.latents_lean_and_mean = latents_lean_and_mean - args.voice_fixer = voice_fixer - args.voice_fixer_use_cuda = voice_fixer_use_cuda - args.concurrency_count = concurrency_count - args.output_sample_rate = output_sample_rate - args.output_volume = output_volume - - settings = { - 'listen': None if args.listen else args.listen, - 'share': args.share, - 'low-vram':args.low_vram, - 'check-for-updates':args.check_for_updates, - 'models-from-local-only':args.models_from_local_only, - 'force-cpu-for-conditioning-latents': args.force_cpu_for_conditioning_latents, - 'device-override': args.device_override, - 'sample-batch-size': args.sample_batch_size, - 'embed-output-metadata': args.embed_output_metadata, - 'latents-lean-and-mean': args.latents_lean_and_mean, - 'voice-fixer': args.voice_fixer, - 'voice-fixer-use-cuda': args.voice_fixer_use_cuda, - 'concurrency-count': args.concurrency_count, - 'output-sample-rate': args.output_sample_rate, - 'output-volume': args.output_volume, - } - - with open(f'./config/exec.json', 'w', encoding="utf-8") as f: - f.write(json.dumps(settings, indent='\t') ) - def setup_gradio(): global args global ui @@ -528,6 +369,29 @@ def setup_gradio(): import_voice_name, ] ) + with gr.Tab("Training"): + with gr.Tab("Configuration"): + with gr.Row(): + with gr.Column(): + training_settings = [ + gr.Slider(label="Batch Size", value=128), + gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6), + gr.Number(label="Print Frequency", value=50), + gr.Number(label="Save Frequency", value=50), + ] + with gr.Column(): + training_settings = training_settings + [ + gr.Textbox(label="Training Name", placeholder="finetune"), + gr.Textbox(label="Dataset Name", placeholder="finetune"), + gr.Textbox(label="Dataset Path", placeholder="./experiments/finetune/train.txt"), + gr.Textbox(label="Validation Name", placeholder="finetune"), + gr.Textbox(label="Validation Path", placeholder="./experiments/finetune/val.txt"), + ] + save_yaml_button = gr.Button(value="Save Training Configuration") + save_yaml_button.click(save_training_settings, + inputs=training_settings, + outputs=None + ) with gr.Tab("Settings"): with gr.Row(): exec_inputs = [] @@ -586,71 +450,15 @@ def setup_gradio(): ] # YUCK - def run_generation( - text, - delimiter, - emotion, - prompt, - voice, - mic_audio, - voice_latents_chunks, - seed, - candidates, - num_autoregressive_samples, - diffusion_iterations, - temperature, - diffusion_sampler, - breathing_room, - cvvp_weight, - top_p, - diffusion_temperature, - length_penalty, - repetition_penalty, - cond_free_k, - experimental_checkboxes, - progress=gr.Progress(track_tqdm=True) - ): - try: - sample, outputs, stats = generate( - text, - delimiter, - emotion, - prompt, - voice, - mic_audio, - voice_latents_chunks, - seed, - candidates, - num_autoregressive_samples, - diffusion_iterations, - temperature, - diffusion_sampler, - breathing_room, - cvvp_weight, - top_p, - diffusion_temperature, - length_penalty, - repetition_penalty, - cond_free_k, - experimental_checkboxes, - progress - ) - except Exception as e: - message = str(e) - if message == "Kill signal detected": - reload_tts() - - raise gr.Error(message) - - + def update_voices(): return ( - outputs[0], - gr.update(value=sample, visible=sample is not None), - gr.update(choices=outputs, value=outputs[0], visible=len(outputs) > 1, interactive=True), - gr.update(visible=len(outputs) > 1), 
- gr.update(value=stats, visible=True), + gr.Dropdown.update(choices=get_voice_list()), + gr.Dropdown.update(choices=get_voice_list("./results/")), ) + def history_copy_settings( voice, file ): + return import_generate_settings( f"./results/{voice}/{file}" ) + refresh_voices.click(update_voices, inputs=None, outputs=[ @@ -681,21 +489,12 @@ def setup_gradio(): outputs=input_settings ) - def reset_generation_settings(): - with open(f'./config/generate.json', 'w', encoding="utf-8") as f: - f.write(json.dumps({}, indent='\t') ) - return import_generate_settings() - reset_generation_settings_button.click( fn=reset_generation_settings, inputs=None, outputs=input_settings ) - def history_copy_settings( voice, file ): - settings = import_generate_settings( f"./results/{voice}/{file}" ) - return settings - history_copy_settings_button.click(history_copy_settings, inputs=[ history_voices, diff --git a/training/.template.yaml b/training/.template.yaml new file mode 100755 index 0000000..038437e --- /dev/null +++ b/training/.template.yaml @@ -0,0 +1,144 @@ +name: ${name} +model: extensibletrainer +scale: 1 +gpu_ids: [0] # <-- unless you have multiple gpus, use this +start_step: -1 +checkpointing_enabled: true # <-- Gradient checkpointing. Enable for huge GPU memory savings. Disable for distributed training. +fp16: false # might want to check this out +wandb: false # <-- enable to log to wandb. tensorboard logging is always enabled. +use_tb_logger: true + +datasets: + train: + name: ${dataset_name} + n_workers: 8 # idk what this does + batch_size: ${batch_size} # This leads to ~16GB of vram usage on my 3090. + mode: paired_voice_audio + path: ${dataset_path} + fetcher_mode: ['lj'] # CHANGEME if your dataset isn't in LJSpeech format + phase: train + max_wav_length: 255995 + max_text_length: 200 + sample_rate: 22050 + load_conditioning: True + num_conditioning_candidates: 2 + conditioning_length: 44000 + use_bpe_tokenizer: True + load_aligned_codes: False + val: + name: ${validation_name} + n_workers: 1 + batch_size: 32 # this could be higher probably + mode: paired_voice_audio + path: ${validation_path} + fetcher_mode: ['lj'] + phase: val # might be broken idk + max_wav_length: 255995 + max_text_length: 200 + sample_rate: 22050 + load_conditioning: True + num_conditioning_candidates: 2 + conditioning_length: 44000 + use_bpe_tokenizer: True + load_aligned_codes: False + +steps: + gpt_train: + training: gpt + loss_log_buffer: 500 # no idea what this does + + # Generally follows the recipe from the DALLE paper. 
+    optimizer: adamw # this should be adamw_zero if you're using distributed training
+    optimizer_params:
+      lr: !!float ${learning_rate} # CHANGEME: this was originally 1e-4; I reduced it to 1e-5 because it's fine-tuning, but **you should experiment with this value**
+      weight_decay: !!float 1e-2
+      beta1: 0.9
+      beta2: 0.96
+    clip_grad_eps: 4
+
+    injectors: # TODO: replace this entire sequence with the GptVoiceLatentInjector
+      paired_to_mel:
+        type: torch_mel_spectrogram
+        mel_norm_file: ./experiments/clips_mel_norms.pth
+        in: wav
+        out: paired_mel
+      paired_cond_to_mel:
+        type: for_each
+        subtype: torch_mel_spectrogram
+        mel_norm_file: ./experiments/clips_mel_norms.pth
+        in: conditioning
+        out: paired_conditioning_mel
+      to_codes:
+        type: discrete_token
+        in: paired_mel
+        out: paired_mel_codes
+        dvae_config: "./experiments/train_diffusion_vocoder_22k_level.yml" # EXTREMELY IMPORTANT
+      paired_fwd_text:
+        type: generator
+        generator: gpt
+        in: [paired_conditioning_mel, padded_text, text_lengths, paired_mel_codes, wav_lengths]
+        out: [loss_text_ce, loss_mel_ce, logits]
+    losses:
+      text_ce:
+        type: direct
+        weight: .01
+        key: loss_text_ce
+      mel_ce:
+        type: direct
+        weight: 1
+        key: loss_mel_ce
+
+networks:
+  gpt:
+    type: generator
+    which_model_G: unified_voice2 # none of the unified_voice*.py files actually match the tortoise inference code... 4 and 3 have "alignment_head" (wtf is that?), 2 lacks the types=1 parameter.
+    kwargs:
+      layers: 30 # WAS 8
+      model_dim: 1024 # WAS 512
+      heads: 16 # WAS 8
+      max_text_tokens: 402 # WAS 120
+      max_mel_tokens: 604 # WAS 250
+      max_conditioning_inputs: 2 # WAS 1
+      mel_length_compression: 1024
+      number_text_tokens: 256 # supposed to be 255 for newer unified_voice files
+      number_mel_codes: 8194
+      start_mel_token: 8192
+      stop_mel_token: 8193
+      start_text_token: 255
+      train_solo_embeddings: False # missing in uv3/4
+      use_mel_codes_as_input: True # ditto
+      checkpointing: True
+      #types: 1 # this is MISSING, but in my analysis 1 is equivalent to not having it.
+      #only_alignment_head: False # uv3/4
+
+path:
+  pretrain_model_gpt: './experiments/autoregressive.pth' # CHANGEME: copy this from tortoise cache
+  strict_load: true
+  #resume_state: ./experiments/train_imgnet_vqvae_stage1/training_state/0.state # <-- Set this to resume from a previous training state.
+
+# afaik all units here are measured in **steps** (i.e. one batch of batch_size is 1 unit)
+train: # CHANGEME: ALL OF THESE PARAMETERS SHOULD BE EXPERIMENTED WITH
+  niter: 50000
+  warmup_iter: -1
+  mega_batch_factor: 4 # <-- Gradient accumulation factor. If you are running OOM, increase this to [2,4,8].
+  val_freq: 500
+
+  default_lr_scheme: MultiStepLR
+  gen_lr_steps: [500, 1000, 1400, 1800] #[50000, 100000, 140000, 180000]
+  lr_gamma: 0.5
+
+eval:
+  output_state: gen
+  injectors:
+    gen_inj_eval:
+      type: generator
+      generator: generator
+      in: hq
+      out: [gen, codebook_commitment_loss]
+
+logger:
+  print_freq: ${print_rate}
+  save_checkpoint_freq: ${save_rate} # CHANGEME: you should especially increase this, since saving is really slow
+  visuals: [gen, mel]
+  visual_debug_rate: 500
+  is_mel_spectrogram: true
\ No newline at end of file
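
For reference, a minimal usage sketch of what this patch adds (assumptions: it is run from the repo root, `src/utils.py` is importable as `src.utils`, and the invocation of `src/train.py` at the end is hypothetical, since its relative DLAS imports imply running it as part of a package):

```python
# Minimal sketch, not part of the patch: fill training/.template.yaml's ${...}
# placeholders via the new save_training_settings() and write ./training/finetune.yaml.
from src.utils import save_training_settings  # assumed module path

save_training_settings(
	batch_size=128,                                    # -> ${batch_size}
	learning_rate=1e-5,                                # -> ${learning_rate}
	print_rate=50,                                     # -> ${print_rate}
	save_rate=50,                                      # -> ${save_rate}
	name="finetune",                                   # output file: ./training/finetune.yaml
	dataset_name="finetune",                           # -> ${dataset_name}
	dataset_path="./experiments/finetune/train.txt",   # -> ${dataset_path}, LJSpeech-style text file
	validation_name="finetune",                        # -> ${validation_name}
	validation_path="./experiments/finetune/val.txt",  # -> ${validation_path}
)
```

The generated YAML is then what gets fed to the new training script through its `-opt` flag, e.g. something along the lines of `python -m src.train -opt ./training/finetune.yaml` (hypothetical; the exact launch command depends on how the DLAS package is laid out).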