a bit of UI cleanup, import multiple audio files at once, actually shows progress when importing voices, hides audio metadata / latents if no generated settings are detected, preparing datasets shows its progress, saving a training YAML shows a message when done, training now works within the web UI, training output shows to web UI, provided notebook is cleaned up and uses a venv, etc.

This commit is contained in:
mrq 2023-02-18 02:07:22 +00:00
parent c75d0bc5da
commit d5c1433268
8 changed files with 323 additions and 259 deletions

View File

@ -3,10 +3,7 @@
"nbformat_minor":0, "nbformat_minor":0,
"metadata":{ "metadata":{
"colab":{ "colab":{
"private_outputs":true, "private_outputs":true
"provenance":[
]
}, },
"kernelspec":{ "kernelspec":{
"name":"python3", "name":"python3",
@ -40,41 +37,62 @@
"source":[ "source":[
"!git clone https://git.ecker.tech/mrq/ai-voice-cloning/\n", "!git clone https://git.ecker.tech/mrq/ai-voice-cloning/\n",
"%cd ai-voice-cloning\n", "%cd ai-voice-cloning\n",
"!apt install python3.8-venv\n",
"!python -m venv venv\n",
"!source ./venv/bin/activate\n",
"!git clone https://git.ecker.tech/mrq/DL-Art-School dlas\n",
"!python -m pip install --upgrade pip\n", "!python -m pip install --upgrade pip\n",
"!pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116\n", "!pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116\n",
"!python -m pip install -r ./requirements.txt\n", "!python -m pip install -r ./dlas/requirements.txt\n",
"!git clone https://git.ecker.tech/mrq/DL-Art-School dlas\n", "!python -m pip install -r ./requirements.txt"
"!python -m pip install -r ./dlas/requirements.txt"
] ]
}, },
{ {
"cell_type":"markdown", "cell_type":"markdown",
"source":[ "source":[
"# Restart Runtime Before Proceeding" "# Update Repos"
], ],
"metadata":{ "metadata":{
"id":"TXFyLVLA48S5" "id":"IzrGt5IcHlAD"
} }
}, },
{ {
"cell_type":"code", "cell_type":"code",
"source":[ "source":[
"# colab requires the runtime to restart before use\n", "%cd /content/ai-voice-cloning/dlas\n",
"exit()" "!git reset --hard HEAD\n",
"!git pull\n",
"%cd ..\n",
"!git reset --hard HEAD\n",
"!git pull\n",
"!python -m pip install ffmpeg ffmpeg-python"
], ],
"metadata":{ "metadata":{
"id":"FVUOtSASCSJ8" "id":"3DktoOXSHmtw"
}, },
"execution_count":null, "execution_count":null,
"outputs":[ "outputs":[
] ]
}, },
{
"cell_type":"markdown",
"source":[
"# Mount Drive"
],
"metadata":{
"id":"2Y4t9zDIZMTg"
}
},
{ {
"cell_type":"code", "cell_type":"code",
"source":[ "source":[
"from google.colab import drive\n", "from google.colab import drive\n",
"drive.mount('/content/drive')" "drive.mount('/content/drive')\n",
"\n",
"%cd /content/ai-voice-cloning\n",
"!rm -r ./training\n",
"!ln -s /content/drive/MyDrive/training/"
], ],
"metadata":{ "metadata":{
"id":"SGt9gyvubveT" "id":"SGt9gyvubveT"
@ -97,6 +115,8 @@
"cell_type":"code", "cell_type":"code",
"source":[ "source":[
"%cd /content/ai-voice-cloning\n", "%cd /content/ai-voice-cloning\n",
"!python -m venv venv\n",
"!source ./venv/bin/activate\n",
"\n", "\n",
"import os\n", "import os\n",
"import sys\n", "import sys\n",
@ -117,7 +137,7 @@
"\n", "\n",
"webui = setup_gradio()\n", "webui = setup_gradio()\n",
"tts = setup_tortoise()\n", "tts = setup_tortoise()\n",
"webui.launch(share=True, prevent_thread_lock=True, debug=True, height=1000)\n", "webui.launch(share=True, prevent_thread_lock=True, height=1000)\n",
"webui.block_thread()" "webui.block_thread()"
], ],
"metadata":{ "metadata":{
@ -140,8 +160,9 @@
{ {
"cell_type":"code", "cell_type":"code",
"source":[ "source":[
"# This is in case you can't get training through the web UI\n",
"%cd /content/ai-voice-cloning\n", "%cd /content/ai-voice-cloning\n",
"!python ./src/train.py -opt ./training/finetune.yaml" "!python ./dlas/codes/train.py -opt ./training/finetune.yaml"
], ],
"metadata":{ "metadata":{
"id":"-KayB8klA5tY" "id":"-KayB8klA5tY"
@ -167,8 +188,9 @@
"!apt install -y p7zip-full\n", "!apt install -y p7zip-full\n",
"from datetime import datetime\n", "from datetime import datetime\n",
"timestamp = datetime.now().strftime('%m-%d-%Y_%H:%M:%S')\n", "timestamp = datetime.now().strftime('%m-%d-%Y_%H:%M:%S')\n",
"!mkdir -p \"../{timestamp}\"\n", "!mkdir -p \"../{timestamp}/results\"\n",
"!mv ./results/* \"../{timestamp}/.\"\n", "!mv ./results/* \"../{timestamp}/results/.\"\n",
"!mv ./training/* \"../{timestamp}/training/.\"\n",
"!7z a -t7z -m0=lzma2 -mx=9 -mfb=64 -md=32m -ms=on \"../{timestamp}.7z\" \"../{timestamp}/\"\n", "!7z a -t7z -m0=lzma2 -mx=9 -mfb=64 -md=32m -ms=on \"../{timestamp}.7z\" \"../{timestamp}/\"\n",
"!ls ~/\n", "!ls ~/\n",
"!echo \"Finished zipping, archive is available at {timestamp}.7z\"" "!echo \"Finished zipping, archive is available at {timestamp}.7z\""

View File

@ -25,32 +25,37 @@ from utils import util, options as option
# this is effectively just copy pasted and cleaned up from the __main__ section of training.py # this is effectively just copy pasted and cleaned up from the __main__ section of training.py
# I'll clean it up better # I'll clean it up better
parser = argparse.ArgumentParser() def train(yaml, launcher='none'):
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml') opt = option.parse(yaml, is_train=True)
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') if launcher != 'none':
args = parser.parse_args()
opt = option.parse(args.opt, is_train=True)
if args.launcher != 'none':
# export CUDA_VISIBLE_DEVICES for running in distributed mode. # export CUDA_VISIBLE_DEVICES for running in distributed mode.
if 'gpu_ids' in opt.keys(): if 'gpu_ids' in opt.keys():
gpu_list = ','.join(str(x) for x in opt['gpu_ids']) gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
print('export CUDA_VISIBLE_DEVICES=' + gpu_list) print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
trainer = tr.Trainer() trainer = tr.Trainer()
#### distributed training settings #### distributed training settings
if args.launcher == 'none': # disabled distributed training if launcher == 'none': # disabled distributed training
opt['dist'] = False opt['dist'] = False
trainer.rank = -1 trainer.rank = -1
if len(opt['gpu_ids']) == 1: if len(opt['gpu_ids']) == 1:
torch.cuda.set_device(opt['gpu_ids'][0]) torch.cuda.set_device(opt['gpu_ids'][0])
print('Disabled distributed training.') print('Disabled distributed training.')
else: else:
opt['dist'] = True opt['dist'] = True
init_dist('nccl') init_dist('nccl')
trainer.world_size = torch.distributed.get_world_size() trainer.world_size = torch.distributed.get_world_size()
trainer.rank = torch.distributed.get_rank() trainer.rank = torch.distributed.get_rank()
torch.cuda.set_device(torch.distributed.get_rank()) torch.cuda.set_device(torch.distributed.get_rank())
trainer.init(args.opt, opt, args.launcher) trainer.init(yaml, opt, launcher)
trainer.do_training() trainer.do_training()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
args = parser.parse_args()
train(args.opt, args.launcher)

View File

@ -1,5 +1,4 @@
import os import os
if 'XDG_CACHE_HOME' not in os.environ: if 'XDG_CACHE_HOME' not in os.environ:
os.environ['XDG_CACHE_HOME'] = os.path.realpath(os.path.join(os.getcwd(), './models/')) os.environ['XDG_CACHE_HOME'] = os.path.realpath(os.path.join(os.getcwd(), './models/'))
@ -15,7 +14,9 @@ import json
import base64 import base64
import re import re
import urllib.request import urllib.request
import signal
import tqdm
import torch import torch
import torchaudio import torchaudio
import music_tag import music_tag
@ -90,6 +91,8 @@ def setup_args():
parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once") parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once")
parser.add_argument("--output-sample-rate", type=int, default=default_arguments['output-sample-rate'], help="Sample rate to resample the output to (from 24KHz)") parser.add_argument("--output-sample-rate", type=int, default=default_arguments['output-sample-rate'], help="Sample rate to resample the output to (from 24KHz)")
parser.add_argument("--output-volume", type=float, default=default_arguments['output-volume'], help="Adjusts volume of output") parser.add_argument("--output-volume", type=float, default=default_arguments['output-volume'], help="Adjusts volume of output")
parser.add_argument("--os", default="unix", help="Specifies which OS, easily")
args = parser.parse_args() args = parser.parse_args()
args.embed_output_metadata = not args.no_embed_output_metadata args.embed_output_metadata = not args.no_embed_output_metadata
@ -427,20 +430,37 @@ def generate(
import subprocess import subprocess
training_process = None
def run_training(config_path): def run_training(config_path):
print("Unloading TTS to save VRAM.") print("Unloading TTS to save VRAM.")
global tts global tts
del tts del tts
tts = None tts = None
cmd = ["python", "./src/train.py", "-opt", config_path] global training_process
torch.multiprocessing.freeze_support()
cmd = [f'train.{"bat" if args.os == "windows" else "sh"}', config_path]
print("Spawning process: ", " ".join(cmd)) print("Spawning process: ", " ".join(cmd))
subprocess.run(cmd, env=os.environ.copy(), shell=True) training_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
""" buffer=[]
from train import train for line in iter(training_process.stdout.readline, ""):
train(config) buffer.append(line)
""" yield "".join(buffer)
training_process.stdout.close()
return_code = training_process.wait()
training_process = None
if return_code:
raise subprocess.CalledProcessError(return_code, cmd)
def stop_training():
if training_process is None:
return "No training in progress"
training_process.kill()
training_process = None
return "Training cancelled"
def setup_voicefixer(restart=False): def setup_voicefixer(restart=False):
global voicefixer global voicefixer
@ -485,19 +505,23 @@ def save_training_settings( batch_size=None, learning_rate=None, print_rate=None
"validation_name": validation_name if validation_name else "finetune", "validation_name": validation_name if validation_name else "finetune",
"validation_path": validation_path if validation_path else "./training/finetune/train.txt", "validation_path": validation_path if validation_path else "./training/finetune/train.txt",
} }
outfile = f'./training/{settings["name"]}.yaml'
with open(f'./training/.template.yaml', 'r', encoding="utf-8") as f: with open(f'./models/.template.yaml', 'r', encoding="utf-8") as f:
yaml = f.read() yaml = f.read()
for k in settings: for k in settings:
yaml = yaml.replace(f"${{{k}}}", str(settings[k])) yaml = yaml.replace(f"${{{k}}}", str(settings[k]))
with open(f'./training/{settings["name"]}.yaml', 'w', encoding="utf-8") as f: with open(outfile, 'w', encoding="utf-8") as f:
f.write(yaml) f.write(yaml)
def prepare_dataset( files, outdir, language=None ): return f"Training settings saved to: {outfile}"
def prepare_dataset( files, outdir, language=None, progress=None ):
global whisper_model global whisper_model
if whisper_model is None: if whisper_model is None:
notify_progress(f"Loading Whisper model: {args.whisper_model}", progress)
whisper_model = whisper.load_model(args.whisper_model) whisper_model = whisper.load_model(args.whisper_model)
os.makedirs(outdir, exist_ok=True) os.makedirs(outdir, exist_ok=True)
@ -506,7 +530,7 @@ def prepare_dataset( files, outdir, language=None ):
results = {} results = {}
transcription = [] transcription = []
for file in files: for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
print(f"Transcribing file: {file}") print(f"Transcribing file: {file}")
result = whisper_model.transcribe(file, language=language if language else "English") result = whisper_model.transcribe(file, language=language if language else "English")
@ -517,7 +541,7 @@ def prepare_dataset( files, outdir, language=None ):
waveform, sampling_rate = torchaudio.load(file) waveform, sampling_rate = torchaudio.load(file)
num_channels, num_frames = waveform.shape num_channels, num_frames = waveform.shape
for segment in result['segments']: for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
start = int(segment['start'] * sampling_rate) start = int(segment['start'] * sampling_rate)
end = int(segment['end'] * sampling_rate) end = int(segment['end'] * sampling_rate)
@ -535,14 +559,20 @@ def prepare_dataset( files, outdir, language=None ):
with open(f'{outdir}/train.txt', 'w', encoding="utf-8") as f: with open(f'{outdir}/train.txt', 'w', encoding="utf-8") as f:
f.write("\n".join(transcription)) f.write("\n".join(transcription))
return f"Processed dataset to: {outdir}"
def reset_generation_settings(): def reset_generation_settings():
with open(f'./config/generate.json', 'w', encoding="utf-8") as f: with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
f.write(json.dumps({}, indent='\t') ) f.write(json.dumps({}, indent='\t') )
return import_generate_settings() return import_generate_settings()
def import_voice(file, saveAs = None): def import_voices(files, saveAs=None, progress=None):
global args global args
if not isinstance(files, list):
files = [files]
for file in enumerate_progress(files, desc="Importing voice files", progress=progress):
j, latents = read_generate_settings(file, read_latents=True) j, latents = read_generate_settings(file, read_latents=True)
if j is not None and saveAs is None: if j is not None and saveAs is None:
@ -552,7 +582,9 @@ def import_voice(file, saveAs = None):
outdir = f'{get_voice_dir()}/{saveAs}/' outdir = f'{get_voice_dir()}/{saveAs}/'
os.makedirs(outdir, exist_ok=True) os.makedirs(outdir, exist_ok=True)
if latents: if latents:
print(f"Importing latents to {latents}")
with open(f'{outdir}/cond_latents.pth', 'wb') as f: with open(f'{outdir}/cond_latents.pth', 'wb') as f:
f.write(latents) f.write(latents)
latents = f'{outdir}/cond_latents.pth' latents = f'{outdir}/cond_latents.pth'
@ -563,6 +595,8 @@ def import_voice(file, saveAs = None):
raise Exception("Please convert to a WAV first") raise Exception("Please convert to a WAV first")
path = f"{outdir}/{os.path.basename(filename)}" path = f"{outdir}/{os.path.basename(filename)}"
print(f"Importing voice to {path}")
waveform, sampling_rate = torchaudio.load(filename) waveform, sampling_rate = torchaudio.load(filename)
if args.voice_fixer and voicefixer is not None: if args.voice_fixer and voicefixer is not None:
@ -592,10 +626,8 @@ def import_voice(file, saveAs = None):
else: else:
torchaudio.save(path, waveform, sampling_rate) torchaudio.save(path, waveform, sampling_rate)
print(f"Imported voice to {path}") print(f"Imported voice to {path}")
def import_generate_settings(file="./config/generate.json"): def import_generate_settings(file="./config/generate.json"):
settings, _ = read_generate_settings(file, read_latents=False) settings, _ = read_generate_settings(file, read_latents=False)
@ -760,3 +792,20 @@ def read_generate_settings(file, read_latents=True, read_json=True):
j, j,
latents, latents,
) )
def enumerate_progress(iterable, desc=None, progress=None, verbose=None):
if verbose and desc is not None:
print(desc)
if progress is None:
return tqdm(iterable, disable=not verbose)
return progress.tqdm(iterable, desc=f'{progress.msg_prefix} {desc}' if hasattr(progress, 'msg_prefix') else desc, track_tqdm=True)
def notify_progress(message, progress=None, verbose=True):
if verbose:
print(message)
if progress is None:
return
progress(0, desc=message)

View File

@ -135,6 +135,21 @@ def get_training_configs():
def update_training_configs(): def update_training_configs():
return gr.update(choices=get_training_configs()) return gr.update(choices=get_training_configs())
history_headers = {
"Name": "",
"Samples": "num_autoregressive_samples",
"Iterations": "diffusion_iterations",
"Temp.": "temperature",
"Sampler": "diffusion_sampler",
"CVVP": "cvvp_weight",
"Top P": "top_p",
"Diff. Temp.": "diffusion_temperature",
"Len Pen": "length_penalty",
"Rep Pen": "repetition_penalty",
"Cond-Free K": "cond_free_k",
"Time": "time",
}
def history_view_results( voice ): def history_view_results( voice ):
results = [] results = []
files = [] files = []
@ -148,7 +163,7 @@ def history_view_results( voice ):
continue continue
values = [] values = []
for k in headers: for k in history_headers:
v = file v = file
if k != "Name": if k != "Name":
v = metadata[headers[k]] v = metadata[headers[k]]
@ -163,6 +178,10 @@ def history_view_results( voice ):
gr.Dropdown.update(choices=sorted(files)) gr.Dropdown.update(choices=sorted(files))
) )
def import_voices_proxy(files, name, progress=gr.Progress(track_tqdm=True)):
import_voices(files, name, progress)
return gr.update()
def read_generate_settings_proxy(file, saveAs='.temp'): def read_generate_settings_proxy(file, saveAs='.temp'):
j, latents = read_generate_settings(file) j, latents = read_generate_settings(file)
@ -175,13 +194,14 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
latents = f'{outdir}/cond_latents.pth' latents = f'{outdir}/cond_latents.pth'
return ( return (
j, gr.update(value=j, visible=j is not None),
gr.update(visible=j is not None),
gr.update(value=latents, visible=latents is not None), gr.update(value=latents, visible=latents is not None),
None if j is None else j['voice'] None if j is None else j['voice']
) )
def prepare_dataset_proxy( voice, language ): def prepare_dataset_proxy( voice, language, progress=gr.Progress(track_tqdm=True) ):
return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language ) return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, progress=progress )
def update_voices(): def update_voices():
return ( return (
@ -222,52 +242,18 @@ def setup_gradio():
with gr.Column(): with gr.Column():
delimiter = gr.Textbox(lines=1, label="Line Delimiter", placeholder="\\n") delimiter = gr.Textbox(lines=1, label="Line Delimiter", placeholder="\\n")
emotion = gr.Radio( emotion = gr.Radio( ["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom"], value="Custom", label="Emotion", type="value", interactive=True )
["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom"],
value="Custom",
label="Emotion",
type="value",
interactive=True
)
prompt = gr.Textbox(lines=1, label="Custom Emotion + Prompt (if selected)") prompt = gr.Textbox(lines=1, label="Custom Emotion + Prompt (if selected)")
voice = gr.Dropdown( voice = gr.Dropdown(get_voice_list(), label="Voice", type="value")
get_voice_list(), mic_audio = gr.Audio( label="Microphone Source", source="microphone", type="filepath" )
label="Voice",
type="value",
)
mic_audio = gr.Audio(
label="Microphone Source",
source="microphone",
type="filepath",
)
refresh_voices = gr.Button(value="Refresh Voice List") refresh_voices = gr.Button(value="Refresh Voice List")
voice_latents_chunks = gr.Slider(label="Voice Chunks", minimum=1, maximum=64, value=1, step=1) voice_latents_chunks = gr.Slider(label="Voice Chunks", minimum=1, maximum=64, value=1, step=1)
recompute_voice_latents = gr.Button(value="(Re)Compute Voice Latents") recompute_voice_latents = gr.Button(value="(Re)Compute Voice Latents")
recompute_voice_latents.click(compute_latents,
inputs=[
voice,
voice_latents_chunks,
],
outputs=voice,
)
prompt.change(fn=lambda value: gr.update(value="Custom"),
inputs=prompt,
outputs=emotion
)
mic_audio.change(fn=lambda value: gr.update(value="microphone"),
inputs=mic_audio,
outputs=voice
)
with gr.Column(): with gr.Column():
candidates = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates") candidates = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates")
seed = gr.Number(value=0, precision=0, label="Seed") seed = gr.Number(value=0, precision=0, label="Seed")
preset = gr.Radio( preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value" )
["Ultra Fast", "Fast", "Standard", "High Quality"],
label="Preset",
type="value",
)
num_autoregressive_samples = gr.Slider(value=128, minimum=0, maximum=512, step=1, label="Samples") num_autoregressive_samples = gr.Slider(value=128, minimum=0, maximum=512, step=1, label="Samples")
diffusion_iterations = gr.Slider(value=128, minimum=0, maximum=512, step=1, label="Iterations") diffusion_iterations = gr.Slider(value=128, minimum=0, maximum=512, step=1, label="Iterations")
@ -275,19 +261,7 @@ def setup_gradio():
breathing_room = gr.Slider(value=8, minimum=1, maximum=32, step=1, label="Pause Size") breathing_room = gr.Slider(value=8, minimum=1, maximum=32, step=1, label="Pause Size")
diffusion_sampler = gr.Radio( diffusion_sampler = gr.Radio(
["P", "DDIM"], # + ["K_Euler_A", "DPM++2M"], ["P", "DDIM"], # + ["K_Euler_A", "DPM++2M"],
value="P", value="P", label="Diffusion Samplers", type="value" )
label="Diffusion Samplers",
type="value",
)
preset.change(fn=update_presets,
inputs=preset,
outputs=[
num_autoregressive_samples,
diffusion_iterations,
],
)
show_experimental_settings = gr.Checkbox(label="Show Experimental Settings") show_experimental_settings = gr.Checkbox(label="Show Experimental Settings")
reset_generation_settings_button = gr.Button(value="Reset to Default") reset_generation_settings_button = gr.Button(value="Reset to Default")
with gr.Column(visible=False) as col: with gr.Column(visible=False) as col:
@ -300,12 +274,6 @@ def setup_gradio():
length_penalty = gr.Slider(value=1.0, minimum=0, maximum=8, label="Length Penalty") length_penalty = gr.Slider(value=1.0, minimum=0, maximum=8, label="Length Penalty")
repetition_penalty = gr.Slider(value=2.0, minimum=0, maximum=8, label="Repetition Penalty") repetition_penalty = gr.Slider(value=2.0, minimum=0, maximum=8, label="Repetition Penalty")
cond_free_k = gr.Slider(value=2.0, minimum=0, maximum=4, label="Conditioning-Free K") cond_free_k = gr.Slider(value=2.0, minimum=0, maximum=4, label="Conditioning-Free K")
show_experimental_settings.change(
fn=lambda x: gr.update(visible=x),
inputs=show_experimental_settings,
outputs=experimental_column
)
with gr.Column(): with gr.Column():
submit = gr.Button(value="Generate") submit = gr.Button(value="Generate")
stop = gr.Button(value="Stop") stop = gr.Button(value="Stop")
@ -315,33 +283,13 @@ def setup_gradio():
output_audio = gr.Audio(label="Output") output_audio = gr.Audio(label="Output")
candidates_list = gr.Dropdown(label="Candidates", type="value", visible=False) candidates_list = gr.Dropdown(label="Candidates", type="value", visible=False)
output_pick = gr.Button(value="Select Candidate", visible=False) output_pick = gr.Button(value="Select Candidate", visible=False)
with gr.Tab("History"): with gr.Tab("History"):
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
headers = { history_info = gr.Dataframe(label="Results", headers=list(history_headers.keys()))
"Name": "",
"Samples": "num_autoregressive_samples",
"Iterations": "diffusion_iterations",
"Temp.": "temperature",
"Sampler": "diffusion_sampler",
"CVVP": "cvvp_weight",
"Top P": "top_p",
"Diff. Temp.": "diffusion_temperature",
"Len Pen": "length_penalty",
"Rep Pen": "repetition_penalty",
"Cond-Free K": "cond_free_k",
"Time": "time",
}
history_info = gr.Dataframe(label="Results", headers=list(headers.keys()))
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
history_voices = gr.Dropdown( history_voices = gr.Dropdown(choices=get_voice_list("./results/"), label="Voice", type="value")
get_voice_list("./results/"),
label="Voice",
type="value",
)
history_view_results_button = gr.Button(value="View Files") history_view_results_button = gr.Button(value="View Files")
with gr.Column(): with gr.Column():
history_results_list = gr.Dropdown(label="Results",type="value", interactive=True) history_results_list = gr.Dropdown(label="Results",type="value", interactive=True)
@ -349,51 +297,16 @@ def setup_gradio():
with gr.Column(): with gr.Column():
history_audio = gr.Audio() history_audio = gr.Audio()
history_copy_settings_button = gr.Button(value="Copy Settings") history_copy_settings_button = gr.Button(value="Copy Settings")
history_view_results_button.click(
fn=history_view_results,
inputs=history_voices,
outputs=[
history_info,
history_results_list,
]
)
history_view_result_button.click(
fn=lambda voice, file: f"./results/{voice}/{file}",
inputs=[
history_voices,
history_results_list,
],
outputs=history_audio
)
with gr.Tab("Utilities"): with gr.Tab("Utilities"):
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
audio_in = gr.File(type="file", label="Audio Input", file_types=["audio"]) audio_in = gr.Files(type="file", label="Audio Input", file_types=["audio"])
copy_button = gr.Button(value="Copy Settings")
import_voice_name = gr.Textbox(label="Voice Name") import_voice_name = gr.Textbox(label="Voice Name")
import_voice_button = gr.Button(value="Import Voice") import_voice_button = gr.Button(value="Import Voice")
with gr.Column(): with gr.Column():
metadata_out = gr.JSON(label="Audio Metadata") metadata_out = gr.JSON(label="Audio Metadata", visible=False)
latents_out = gr.File(type="binary", label="Voice Latents") copy_button = gr.Button(value="Copy Settings", visible=False)
latents_out = gr.File(type="binary", label="Voice Latents", visible=False)
audio_in.upload(
fn=read_generate_settings_proxy,
inputs=audio_in,
outputs=[
metadata_out,
latents_out,
import_voice_name
]
)
import_voice_button.click(
fn=import_voice,
inputs=[
audio_in,
import_voice_name,
]
)
with gr.Tab("Training"): with gr.Tab("Training"):
with gr.Tab("Prepare Dataset"): with gr.Tab("Prepare Dataset"):
with gr.Row(): with gr.Row():
@ -402,16 +315,9 @@ def setup_gradio():
gr.Dropdown( get_voice_list(), label="Dataset Source", type="value" ), gr.Dropdown( get_voice_list(), label="Dataset Source", type="value" ),
gr.Textbox(label="Language", placeholder="English") gr.Textbox(label="Language", placeholder="English")
] ]
dataset_voices = dataset_settings[0]
with gr.Column():
prepare_dataset_button = gr.Button(value="Prepare") prepare_dataset_button = gr.Button(value="Prepare")
with gr.Column():
prepare_dataset_button.click( prepare_dataset_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
prepare_dataset_proxy,
inputs=dataset_settings,
outputs=None
)
with gr.Tab("Generate Configuration"): with gr.Tab("Generate Configuration"):
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
@ -421,8 +327,6 @@ def setup_gradio():
gr.Number(label="Print Frequency", value=50), gr.Number(label="Print Frequency", value=50),
gr.Number(label="Save Frequency", value=50), gr.Number(label="Save Frequency", value=50),
] ]
save_yaml_button = gr.Button(value="Save Training Configuration")
with gr.Column():
training_settings = training_settings + [ training_settings = training_settings + [
gr.Textbox(label="Training Name", placeholder="finetune"), gr.Textbox(label="Training Name", placeholder="finetune"),
gr.Textbox(label="Dataset Name", placeholder="finetune"), gr.Textbox(label="Dataset Name", placeholder="finetune"),
@ -430,24 +334,18 @@ def setup_gradio():
gr.Textbox(label="Validation Name", placeholder="finetune"), gr.Textbox(label="Validation Name", placeholder="finetune"),
gr.Textbox(label="Validation Path", placeholder="./training/finetune/train.txt"), gr.Textbox(label="Validation Path", placeholder="./training/finetune/train.txt"),
] ]
with gr.Column():
save_yaml_button.click(save_training_settings, save_yaml_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
inputs=training_settings, save_yaml_button = gr.Button(value="Save Training Configuration")
outputs=None with gr.Tab("Run Training"):
)
with gr.Tab("Train"):
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_configs()) training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_configs())
refresh_configs = gr.Button(value="Refresh Configurations") refresh_configs = gr.Button(value="Refresh Configurations")
train = gr.Button(value="Train") start_training_button = gr.Button(value="Train")
stop_training_button = gr.Button(value="Stop")
refresh_configs.click(update_training_configs,inputs=None,outputs=training_configs) with gr.Column():
train.click(run_training, training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
inputs=training_configs,
outputs=None
)
with gr.Tab("Settings"): with gr.Tab("Settings"):
with gr.Row(): with gr.Row():
exec_inputs = [] exec_inputs = []
@ -465,23 +363,22 @@ def setup_gradio():
gr.Checkbox(label="Force CPU for Conditioning Latents", value=args.force_cpu_for_conditioning_latents), gr.Checkbox(label="Force CPU for Conditioning Latents", value=args.force_cpu_for_conditioning_latents),
gr.Checkbox(label="Defer TTS Load", value=args.defer_tts_load), gr.Checkbox(label="Defer TTS Load", value=args.defer_tts_load),
gr.Textbox(label="Device Override", value=args.device_override), gr.Textbox(label="Device Override", value=args.device_override),
gr.Dropdown(label="Whisper Model", value=args.whisper_model, choices=["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large"]),
] ]
gr.Button(value="Check for Updates").click(check_for_updates)
gr.Button(value="Reload TTS").click(reload_tts)
with gr.Column(): with gr.Column():
exec_inputs = exec_inputs + [ exec_inputs = exec_inputs + [
gr.Number(label="Sample Batch Size", precision=0, value=args.sample_batch_size), gr.Number(label="Sample Batch Size", precision=0, value=args.sample_batch_size),
gr.Number(label="Concurrency Count", precision=0, value=args.concurrency_count), gr.Number(label="Concurrency Count", precision=0, value=args.concurrency_count),
gr.Number(label="Ouptut Sample Rate", precision=0, value=args.output_sample_rate), gr.Number(label="Ouptut Sample Rate", precision=0, value=args.output_sample_rate),
gr.Slider(label="Ouptut Volume", minimum=0, maximum=2, value=args.output_volume), gr.Slider(label="Ouptut Volume", minimum=0, maximum=2, value=args.output_volume),
gr.Dropdown(label="Whisper Model", value=args.whisper_model, choices=["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large"]),
] ]
gr.Button(value="Check for Updates").click(check_for_updates)
gr.Button(value="Reload TTS").click(reload_tts)
for i in exec_inputs: for i in exec_inputs:
i.change( i.change( fn=export_exec_settings, inputs=exec_inputs )
fn=export_exec_settings,
inputs=exec_inputs # console_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
)
input_settings = [ input_settings = [
text, text,
@ -507,11 +404,76 @@ def setup_gradio():
experimental_checkboxes, experimental_checkboxes,
] ]
history_view_results_button.click(
fn=history_view_results,
inputs=history_voices,
outputs=[
history_info,
history_results_list,
]
)
history_view_result_button.click(
fn=lambda voice, file: f"./results/{voice}/{file}",
inputs=[
history_voices,
history_results_list,
],
outputs=history_audio
)
audio_in.upload(
fn=read_generate_settings_proxy,
inputs=audio_in,
outputs=[
metadata_out,
copy_button,
latents_out,
import_voice_name
]
)
import_voice_button.click(
fn=import_voices_proxy,
inputs=[
audio_in,
import_voice_name,
],
outputs=import_voice_name #console_output
)
show_experimental_settings.change(
fn=lambda x: gr.update(visible=x),
inputs=show_experimental_settings,
outputs=experimental_column
)
preset.change(fn=update_presets,
inputs=preset,
outputs=[
num_autoregressive_samples,
diffusion_iterations,
],
)
recompute_voice_latents.click(compute_latents,
inputs=[
voice,
voice_latents_chunks,
],
outputs=voice,
)
prompt.change(fn=lambda value: gr.update(value="Custom"),
inputs=prompt,
outputs=emotion
)
mic_audio.change(fn=lambda value: gr.update(value="microphone"),
inputs=mic_audio,
outputs=voice
)
refresh_voices.click(update_voices, refresh_voices.click(update_voices,
inputs=None, inputs=None,
outputs=[ outputs=[
voice, voice,
dataset_voices, dataset_settings[0],
history_voices history_voices
] ]
) )
@ -552,6 +514,25 @@ def setup_gradio():
outputs=input_settings outputs=input_settings
) )
refresh_configs.click(update_training_configs,inputs=None,outputs=training_configs)
start_training_button.click(run_training,
inputs=training_configs,
outputs=training_output #console_output
)
stop_training_button.click(stop_training,
inputs=None,
outputs=training_output #console_output
)
prepare_dataset_button.click(
prepare_dataset_proxy,
inputs=dataset_settings,
outputs=prepare_dataset_output #console_output
)
save_yaml_button.click(save_training_settings,
inputs=training_settings,
outputs=save_yaml_output #console_output
)
if os.path.isfile('./config/generate.json'): if os.path.isfile('./config/generate.json'):
ui.load(import_generate_settings, inputs=None, outputs=input_settings) ui.load(import_generate_settings, inputs=None, outputs=input_settings)

4
train.bat Executable file
View File

@ -0,0 +1,4 @@
call .\venv\Scripts\activate.bat
python ./src/train.py -opt "%1"
deactivate
pause

3
train.sh Executable file
View File

@ -0,0 +1,3 @@
source ./venv/bin/activate
python3 ./src/train.py -opt "$1"
deactivate

0
training/.gitkeep Executable file
View File