1
0
Fork 0

a bit of UI cleanup, import multiple audio files at once, actually shows progress when importing voices, hides audio metadata / latents if no generated settings are detected, preparing datasets shows its progress, saving a training YAML shows a message when done, training now works within the web UI, training output shows to web UI, provided notebook is cleaned up and uses a venv, etc.

remotes/1708699347150643056/master
mrq 2023-02-18 02:07:22 +07:00
parent c75d0bc5da
commit d5c1433268
8 changed files with 329 additions and 265 deletions

@ -3,10 +3,7 @@
"nbformat_minor":0,
"metadata":{
"colab":{
"private_outputs":true,
"provenance":[
]
"private_outputs":true
},
"kernelspec":{
"name":"python3",
@ -40,41 +37,62 @@
"source":[
"!git clone https://git.ecker.tech/mrq/ai-voice-cloning/\n",
"%cd ai-voice-cloning\n",
"!apt install python3.8-venv\n",
"!python -m venv venv\n",
"!source ./venv/bin/activate\n",
"!git clone https://git.ecker.tech/mrq/DL-Art-School dlas\n",
"!python -m pip install --upgrade pip\n",
"!pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116\n",
"!python -m pip install -r ./requirements.txt\n",
"!git clone https://git.ecker.tech/mrq/DL-Art-School dlas\n",
"!python -m pip install -r ./dlas/requirements.txt"
"!python -m pip install -r ./dlas/requirements.txt\n",
"!python -m pip install -r ./requirements.txt"
]
},
{
"cell_type":"markdown",
"source":[
"# Restart Runtime Before Proceeding"
"# Update Repos"
],
"metadata":{
"id":"TXFyLVLA48S5"
"id":"IzrGt5IcHlAD"
}
},
{
"cell_type":"code",
"source":[
"# colab requires the runtime to restart before use\n",
"exit()"
"%cd /content/ai-voice-cloning/dlas\n",
"!git reset --hard HEAD\n",
"!git pull\n",
"%cd ..\n",
"!git reset --hard HEAD\n",
"!git pull\n",
"!python -m pip install ffmpeg ffmpeg-python"
],
"metadata":{
"id":"FVUOtSASCSJ8"
"id":"3DktoOXSHmtw"
},
"execution_count":null,
"outputs":[
]
},
{
"cell_type":"markdown",
"source":[
"# Mount Drive"
],
"metadata":{
"id":"2Y4t9zDIZMTg"
}
},
{
"cell_type":"code",
"source":[
"from google.colab import drive\n",
"drive.mount('/content/drive')"
"drive.mount('/content/drive')\n",
"\n",
"%cd /content/ai-voice-cloning\n",
"!rm -r ./training\n",
"!ln -s /content/drive/MyDrive/training/"
],
"metadata":{
"id":"SGt9gyvubveT"
@ -97,6 +115,8 @@
"cell_type":"code",
"source":[
"%cd /content/ai-voice-cloning\n",
"!python -m venv venv\n",
"!source ./venv/bin/activate\n",
"\n",
"import os\n",
"import sys\n",
@ -117,7 +137,7 @@
"\n",
"webui = setup_gradio()\n",
"tts = setup_tortoise()\n",
"webui.launch(share=True, prevent_thread_lock=True, debug=True, height=1000)\n",
"webui.launch(share=True, prevent_thread_lock=True, height=1000)\n",
"webui.block_thread()"
],
"metadata":{
@ -140,8 +160,9 @@
{
"cell_type":"code",
"source":[
"# This is in case you can't get training through the web UI\n",
"%cd /content/ai-voice-cloning\n",
"!python ./src/train.py -opt ./training/finetune.yaml"
"!python ./dlas/codes/train.py -opt ./training/finetune.yaml"
],
"metadata":{
"id":"-KayB8klA5tY"
@ -167,8 +188,9 @@
"!apt install -y p7zip-full\n",
"from datetime import datetime\n",
"timestamp = datetime.now().strftime('%m-%d-%Y_%H:%M:%S')\n",
"!mkdir -p \"../{timestamp}\"\n",
"!mv ./results/* \"../{timestamp}/.\"\n",
"!mkdir -p \"../{timestamp}/results\"\n",
"!mv ./results/* \"../{timestamp}/results/.\"\n",
"!mv ./training/* \"../{timestamp}/training/.\"\n",
"!7z a -t7z -m0=lzma2 -mx=9 -mfb=64 -md=32m -ms=on \"../{timestamp}.7z\" \"../{timestamp}/\"\n",
"!ls ~/\n",
"!echo \"Finished zipping, archive is available at {timestamp}.7z\""

@ -25,32 +25,37 @@ from utils import util, options as option
# this is effectively just copy pasted and cleaned up from the __main__ section of training.py
# I'll clean it up better
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
args = parser.parse_args()
opt = option.parse(args.opt, is_train=True)
if args.launcher != 'none':
# export CUDA_VISIBLE_DEVICES for running in distributed mode.
if 'gpu_ids' in opt.keys():
gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
trainer = tr.Trainer()
#### distributed training settings
if args.launcher == 'none': # disabled distributed training
opt['dist'] = False
trainer.rank = -1
if len(opt['gpu_ids']) == 1:
torch.cuda.set_device(opt['gpu_ids'][0])
print('Disabled distributed training.')
else:
opt['dist'] = True
init_dist('nccl')
trainer.world_size = torch.distributed.get_world_size()
trainer.rank = torch.distributed.get_rank()
torch.cuda.set_device(torch.distributed.get_rank())
trainer.init(args.opt, opt, args.launcher)
trainer.do_training()
def train(yaml, launcher='none'):
opt = option.parse(yaml, is_train=True)
if launcher != 'none':
# export CUDA_VISIBLE_DEVICES for running in distributed mode.
if 'gpu_ids' in opt.keys():
gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
trainer = tr.Trainer()
#### distributed training settings
if launcher == 'none': # disabled distributed training
opt['dist'] = False
trainer.rank = -1
if len(opt['gpu_ids']) == 1:
torch.cuda.set_device(opt['gpu_ids'][0])
print('Disabled distributed training.')
else:
opt['dist'] = True
init_dist('nccl')
trainer.world_size = torch.distributed.get_world_size()
trainer.rank = torch.distributed.get_rank()
torch.cuda.set_device(torch.distributed.get_rank())
trainer.init(yaml, opt, launcher)
trainer.do_training()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
args = parser.parse_args()
train(args.opt, args.launcher)

@ -1,5 +1,4 @@
import os
if 'XDG_CACHE_HOME' not in os.environ:
os.environ['XDG_CACHE_HOME'] = os.path.realpath(os.path.join(os.getcwd(), './models/'))
@ -15,7 +14,9 @@ import json
import base64
import re
import urllib.request
import signal
import tqdm
import torch
import torchaudio
import music_tag
@ -90,6 +91,8 @@ def setup_args():
parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once")
parser.add_argument("--output-sample-rate", type=int, default=default_arguments['output-sample-rate'], help="Sample rate to resample the output to (from 24KHz)")
parser.add_argument("--output-volume", type=float, default=default_arguments['output-volume'], help="Adjusts volume of output")
parser.add_argument("--os", default="unix", help="Specifies which OS, easily")
args = parser.parse_args()
args.embed_output_metadata = not args.no_embed_output_metadata
@ -427,20 +430,37 @@ def generate(
import subprocess
training_process = None
def run_training(config_path):
print("Unloading TTS to save VRAM.")
global tts
del tts
tts = None
cmd = ["python", "./src/train.py", "-opt", config_path]
global training_process
torch.multiprocessing.freeze_support()
cmd = [f'train.{"bat" if args.os == "windows" else "sh"}', config_path]
print("Spawning process: ", " ".join(cmd))
subprocess.run(cmd, env=os.environ.copy(), shell=True)
"""
from train import train
train(config)
"""
training_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
buffer=[]
for line in iter(training_process.stdout.readline, ""):
buffer.append(line)
yield "".join(buffer)
training_process.stdout.close()
return_code = training_process.wait()
training_process = None
if return_code:
raise subprocess.CalledProcessError(return_code, cmd)
def stop_training():
if training_process is None:
return "No training in progress"
training_process.kill()
training_process = None
return "Training cancelled"
def setup_voicefixer(restart=False):
global voicefixer
@ -485,19 +505,23 @@ def save_training_settings( batch_size=None, learning_rate=None, print_rate=None
"validation_name": validation_name if validation_name else "finetune",
"validation_path": validation_path if validation_path else "./training/finetune/train.txt",
}
outfile = f'./training/{settings["name"]}.yaml'
with open(f'./training/.template.yaml', 'r', encoding="utf-8") as f:
with open(f'./models/.template.yaml', 'r', encoding="utf-8") as f:
yaml = f.read()
for k in settings:
yaml = yaml.replace(f"${{{k}}}", str(settings[k]))
with open(f'./training/{settings["name"]}.yaml', 'w', encoding="utf-8") as f:
with open(outfile, 'w', encoding="utf-8") as f:
f.write(yaml)
def prepare_dataset( files, outdir, language=None ):
return f"Training settings saved to: {outfile}"
def prepare_dataset( files, outdir, language=None, progress=None ):
global whisper_model
if whisper_model is None:
notify_progress(f"Loading Whisper model: {args.whisper_model}", progress)
whisper_model = whisper.load_model(args.whisper_model)
os.makedirs(outdir, exist_ok=True)
@ -506,7 +530,7 @@ def prepare_dataset( files, outdir, language=None ):
results = {}
transcription = []
for file in files:
for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
print(f"Transcribing file: {file}")
result = whisper_model.transcribe(file, language=language if language else "English")
@ -517,7 +541,7 @@ def prepare_dataset( files, outdir, language=None ):
waveform, sampling_rate = torchaudio.load(file)
num_channels, num_frames = waveform.shape
for segment in result['segments']:
for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
start = int(segment['start'] * sampling_rate)
end = int(segment['end'] * sampling_rate)
@ -535,66 +559,74 @@ def prepare_dataset( files, outdir, language=None ):
with open(f'{outdir}/train.txt', 'w', encoding="utf-8") as f:
f.write("\n".join(transcription))
return f"Processed dataset to: {outdir}"
def reset_generation_settings():
with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
f.write(json.dumps({}, indent='\t') )
return import_generate_settings()
def import_voice(file, saveAs = None):
def import_voices(files, saveAs=None, progress=None):
global args
j, latents = read_generate_settings(file, read_latents=True)
if j is not None and saveAs is None:
saveAs = j['voice']
if saveAs is None or saveAs == "":
raise Exception("Specify a voice name")
if not isinstance(files, list):
files = [files]
outdir = f'{get_voice_dir()}/{saveAs}/'
os.makedirs(outdir, exist_ok=True)
if latents:
with open(f'{outdir}/cond_latents.pth', 'wb') as f:
f.write(latents)
latents = f'{outdir}/cond_latents.pth'
print(f"Imported latents to {latents}")
else:
filename = file.name
if filename[-4:] != ".wav":
raise Exception("Please convert to a WAV first")
path = f"{outdir}/{os.path.basename(filename)}"
waveform, sampling_rate = torchaudio.load(filename)
if args.voice_fixer and voicefixer is not None:
# resample to best bandwidth since voicefixer will do it anyways through librosa
if sampling_rate != 44100:
print(f"Resampling imported voice sample: {path}")
resampler = torchaudio.transforms.Resample(
sampling_rate,
44100,
lowpass_filter_width=16,
rolloff=0.85,
resampling_method="kaiser_window",
beta=8.555504641634386,
)
waveform = resampler(waveform)
sampling_rate = 44100
torchaudio.save(path, waveform, sampling_rate)
print(f"Running 'voicefixer' on voice sample: {path}")
voicefixer.restore(
input = path,
output = path,
cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
#mode=mode,
)
for file in enumerate_progress(files, desc="Importing voice files", progress=progress):
j, latents = read_generate_settings(file, read_latents=True)
if j is not None and saveAs is None:
saveAs = j['voice']
if saveAs is None or saveAs == "":
raise Exception("Specify a voice name")
outdir = f'{get_voice_dir()}/{saveAs}/'
os.makedirs(outdir, exist_ok=True)
if latents:
print(f"Importing latents to {latents}")
with open(f'{outdir}/cond_latents.pth', 'wb') as f:
f.write(latents)
latents = f'{outdir}/cond_latents.pth'
print(f"Imported latents to {latents}")
else:
torchaudio.save(path, waveform, sampling_rate)
print(f"Imported voice to {path}")
filename = file.name
if filename[-4:] != ".wav":
raise Exception("Please convert to a WAV first")
path = f"{outdir}/{os.path.basename(filename)}"
print(f"Importing voice to {path}")
waveform, sampling_rate = torchaudio.load(filename)
if args.voice_fixer and voicefixer is not None:
# resample to best bandwidth since voicefixer will do it anyways through librosa
if sampling_rate != 44100:
print(f"Resampling imported voice sample: {path}")
resampler = torchaudio.transforms.Resample(
sampling_rate,
44100,
lowpass_filter_width=16,
rolloff=0.85,
resampling_method="kaiser_window",
beta=8.555504641634386,
)
waveform = resampler(waveform)
sampling_rate = 44100
torchaudio.save(path, waveform, sampling_rate)
print(f"Running 'voicefixer' on voice sample: {path}")
voicefixer.restore(
input = path,
output = path,
cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
#mode=mode,
)
else:
torchaudio.save(path, waveform, sampling_rate)
print(f"Imported voice to {path}")
def import_generate_settings(file="./config/generate.json"):
settings, _ = read_generate_settings(file, read_latents=False)
@ -759,4 +791,21 @@ def read_generate_settings(file, read_latents=True, read_json=True):
return (
j,
latents,
)
)
def enumerate_progress(iterable, desc=None, progress=None, verbose=None):
if verbose and desc is not None:
print(desc)
if progress is None:
return tqdm(iterable, disable=not verbose)
return progress.tqdm(iterable, desc=f'{progress.msg_prefix} {desc}' if hasattr(progress, 'msg_prefix') else desc, track_tqdm=True)
def notify_progress(message, progress=None, verbose=True):
if verbose:
print(message)
if progress is None:
return
progress(0, desc=message)

@ -135,6 +135,21 @@ def get_training_configs():
def update_training_configs():
return gr.update(choices=get_training_configs())
history_headers = {
"Name": "",
"Samples": "num_autoregressive_samples",
"Iterations": "diffusion_iterations",
"Temp.": "temperature",
"Sampler": "diffusion_sampler",
"CVVP": "cvvp_weight",
"Top P": "top_p",
"Diff. Temp.": "diffusion_temperature",
"Len Pen": "length_penalty",
"Rep Pen": "repetition_penalty",
"Cond-Free K": "cond_free_k",
"Time": "time",
}
def history_view_results( voice ):
results = []
files = []
@ -148,7 +163,7 @@ def history_view_results( voice ):
continue
values = []
for k in headers:
for k in history_headers:
v = file
if k != "Name":
v = metadata[headers[k]]
@ -163,6 +178,10 @@ def history_view_results( voice ):
gr.Dropdown.update(choices=sorted(files))
)
def import_voices_proxy(files, name, progress=gr.Progress(track_tqdm=True)):
import_voices(files, name, progress)
return gr.update()
def read_generate_settings_proxy(file, saveAs='.temp'):
j, latents = read_generate_settings(file)
@ -175,13 +194,14 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
latents = f'{outdir}/cond_latents.pth'
return (
j,
gr.update(value=j, visible=j is not None),
gr.update(visible=j is not None),
gr.update(value=latents, visible=latents is not None),
None if j is None else j['voice']
)
def prepare_dataset_proxy( voice, language ):
return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language )
def prepare_dataset_proxy( voice, language, progress=gr.Progress(track_tqdm=True) ):
return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, progress=progress )
def update_voices():
return (
@ -222,52 +242,18 @@ def setup_gradio():
with gr.Column():
delimiter = gr.Textbox(lines=1, label="Line Delimiter", placeholder="\\n")
emotion = gr.Radio(
["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom"],
value="Custom",
label="Emotion",
type="value",
interactive=True
)
emotion = gr.Radio( ["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom"], value="Custom", label="Emotion", type="value", interactive=True )
prompt = gr.Textbox(lines=1, label="Custom Emotion + Prompt (if selected)")
voice = gr.Dropdown(
get_voice_list(),
label="Voice",
type="value",
)
mic_audio = gr.Audio(
label="Microphone Source",
source="microphone",
type="filepath",
)
voice = gr.Dropdown(get_voice_list(), label="Voice", type="value")
mic_audio = gr.Audio( label="Microphone Source", source="microphone", type="filepath" )
refresh_voices = gr.Button(value="Refresh Voice List")
voice_latents_chunks = gr.Slider(label="Voice Chunks", minimum=1, maximum=64, value=1, step=1)
recompute_voice_latents = gr.Button(value="(Re)Compute Voice Latents")
recompute_voice_latents.click(compute_latents,
inputs=[
voice,
voice_latents_chunks,
],
outputs=voice,
)
prompt.change(fn=lambda value: gr.update(value="Custom"),
inputs=prompt,
outputs=emotion
)
mic_audio.change(fn=lambda value: gr.update(value="microphone"),
inputs=mic_audio,
outputs=voice
)
with gr.Column():
candidates = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates")
seed = gr.Number(value=0, precision=0, label="Seed")
preset = gr.Radio(
["Ultra Fast", "Fast", "Standard", "High Quality"],
label="Preset",
type="value",
)
preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value" )
num_autoregressive_samples = gr.Slider(value=128, minimum=0, maximum=512, step=1, label="Samples")
diffusion_iterations = gr.Slider(value=128, minimum=0, maximum=512, step=1, label="Iterations")
@ -275,19 +261,7 @@ def setup_gradio():
breathing_room = gr.Slider(value=8, minimum=1, maximum=32, step=1, label="Pause Size")
diffusion_sampler = gr.Radio(
["P", "DDIM"], # + ["K_Euler_A", "DPM++2M"],
value="P",
label="Diffusion Samplers",
type="value",
)
preset.change(fn=update_presets,
inputs=preset,
outputs=[
num_autoregressive_samples,
diffusion_iterations,
],
)
value="P", label="Diffusion Samplers", type="value" )
show_experimental_settings = gr.Checkbox(label="Show Experimental Settings")
reset_generation_settings_button = gr.Button(value="Reset to Default")
with gr.Column(visible=False) as col:
@ -300,12 +274,6 @@ def setup_gradio():
length_penalty = gr.Slider(value=1.0, minimum=0, maximum=8, label="Length Penalty")
repetition_penalty = gr.Slider(value=2.0, minimum=0, maximum=8, label="Repetition Penalty")
cond_free_k = gr.Slider(value=2.0, minimum=0, maximum=4, label="Conditioning-Free K")
show_experimental_settings.change(
fn=lambda x: gr.update(visible=x),
inputs=show_experimental_settings,
outputs=experimental_column
)
with gr.Column():
submit = gr.Button(value="Generate")
stop = gr.Button(value="Stop")
@ -315,33 +283,13 @@ def setup_gradio():
output_audio = gr.Audio(label="Output")
candidates_list = gr.Dropdown(label="Candidates", type="value", visible=False)
output_pick = gr.Button(value="Select Candidate", visible=False)
with gr.Tab("History"):
with gr.Row():
with gr.Column():
headers = {
"Name": "",
"Samples": "num_autoregressive_samples",
"Iterations": "diffusion_iterations",
"Temp.": "temperature",
"Sampler": "diffusion_sampler",
"CVVP": "cvvp_weight",
"Top P": "top_p",
"Diff. Temp.": "diffusion_temperature",
"Len Pen": "length_penalty",
"Rep Pen": "repetition_penalty",
"Cond-Free K": "cond_free_k",
"Time": "time",
}
history_info = gr.Dataframe(label="Results", headers=list(headers.keys()))
history_info = gr.Dataframe(label="Results", headers=list(history_headers.keys()))
with gr.Row():
with gr.Column():
history_voices = gr.Dropdown(
get_voice_list("./results/"),
label="Voice",
type="value",
)
history_voices = gr.Dropdown(choices=get_voice_list("./results/"), label="Voice", type="value")
history_view_results_button = gr.Button(value="View Files")
with gr.Column():
history_results_list = gr.Dropdown(label="Results",type="value", interactive=True)
@ -349,51 +297,16 @@ def setup_gradio():
with gr.Column():
history_audio = gr.Audio()
history_copy_settings_button = gr.Button(value="Copy Settings")
history_view_results_button.click(
fn=history_view_results,
inputs=history_voices,
outputs=[
history_info,
history_results_list,
]
)
history_view_result_button.click(
fn=lambda voice, file: f"./results/{voice}/{file}",
inputs=[
history_voices,
history_results_list,
],
outputs=history_audio
)
with gr.Tab("Utilities"):
with gr.Row():
with gr.Column():
audio_in = gr.File(type="file", label="Audio Input", file_types=["audio"])
copy_button = gr.Button(value="Copy Settings")
audio_in = gr.Files(type="file", label="Audio Input", file_types=["audio"])
import_voice_name = gr.Textbox(label="Voice Name")
import_voice_button = gr.Button(value="Import Voice")
with gr.Column():
metadata_out = gr.JSON(label="Audio Metadata")
latents_out = gr.File(type="binary", label="Voice Latents")
audio_in.upload(
fn=read_generate_settings_proxy,
inputs=audio_in,
outputs=[
metadata_out,
latents_out,
import_voice_name
]
)
import_voice_button.click(
fn=import_voice,
inputs=[
audio_in,
import_voice_name,
]
)
metadata_out = gr.JSON(label="Audio Metadata", visible=False)
copy_button = gr.Button(value="Copy Settings", visible=False)
latents_out = gr.File(type="binary", label="Voice Latents", visible=False)
with gr.Tab("Training"):
with gr.Tab("Prepare Dataset"):
with gr.Row():
@ -402,16 +315,9 @@ def setup_gradio():
gr.Dropdown( get_voice_list(), label="Dataset Source", type="value" ),
gr.Textbox(label="Language", placeholder="English")
]
dataset_voices = dataset_settings[0]
with gr.Column():
prepare_dataset_button = gr.Button(value="Prepare")
prepare_dataset_button.click(
prepare_dataset_proxy,
inputs=dataset_settings,
outputs=None
)
with gr.Column():
prepare_dataset_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
with gr.Tab("Generate Configuration"):
with gr.Row():
with gr.Column():
@ -421,8 +327,6 @@ def setup_gradio():
gr.Number(label="Print Frequency", value=50),
gr.Number(label="Save Frequency", value=50),
]
save_yaml_button = gr.Button(value="Save Training Configuration")
with gr.Column():
training_settings = training_settings + [
gr.Textbox(label="Training Name", placeholder="finetune"),
gr.Textbox(label="Dataset Name", placeholder="finetune"),
@ -430,24 +334,18 @@ def setup_gradio():
gr.Textbox(label="Validation Name", placeholder="finetune"),
gr.Textbox(label="Validation Path", placeholder="./training/finetune/train.txt"),
]
save_yaml_button.click(save_training_settings,
inputs=training_settings,
outputs=None
)
with gr.Tab("Train"):
with gr.Column():
save_yaml_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
save_yaml_button = gr.Button(value="Save Training Configuration")
with gr.Tab("Run Training"):
with gr.Row():
with gr.Column():
training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_configs())
refresh_configs = gr.Button(value="Refresh Configurations")
train = gr.Button(value="Train")
refresh_configs.click(update_training_configs,inputs=None,outputs=training_configs)
train.click(run_training,
inputs=training_configs,
outputs=None
)
start_training_button = gr.Button(value="Train")
stop_training_button = gr.Button(value="Stop")
with gr.Column():
training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
with gr.Tab("Settings"):
with gr.Row():
exec_inputs = []
@ -465,23 +363,22 @@ def setup_gradio():
gr.Checkbox(label="Force CPU for Conditioning Latents", value=args.force_cpu_for_conditioning_latents),
gr.Checkbox(label="Defer TTS Load", value=args.defer_tts_load),
gr.Textbox(label="Device Override", value=args.device_override),
gr.Dropdown(label="Whisper Model", value=args.whisper_model, choices=["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large"]),
]
gr.Button(value="Check for Updates").click(check_for_updates)
gr.Button(value="Reload TTS").click(reload_tts)
with gr.Column():
exec_inputs = exec_inputs + [
gr.Number(label="Sample Batch Size", precision=0, value=args.sample_batch_size),
gr.Number(label="Concurrency Count", precision=0, value=args.concurrency_count),
gr.Number(label="Ouptut Sample Rate", precision=0, value=args.output_sample_rate),
gr.Slider(label="Ouptut Volume", minimum=0, maximum=2, value=args.output_volume),
gr.Dropdown(label="Whisper Model", value=args.whisper_model, choices=["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large"]),
]
gr.Button(value="Check for Updates").click(check_for_updates)
gr.Button(value="Reload TTS").click(reload_tts)
for i in exec_inputs:
i.change(
fn=export_exec_settings,
inputs=exec_inputs
)
i.change( fn=export_exec_settings, inputs=exec_inputs )
# console_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
input_settings = [
text,
@ -507,11 +404,76 @@ def setup_gradio():
experimental_checkboxes,
]
history_view_results_button.click(
fn=history_view_results,
inputs=history_voices,
outputs=[
history_info,
history_results_list,
]
)
history_view_result_button.click(
fn=lambda voice, file: f"./results/{voice}/{file}",
inputs=[
history_voices,
history_results_list,
],
outputs=history_audio
)
audio_in.upload(
fn=read_generate_settings_proxy,
inputs=audio_in,
outputs=[
metadata_out,
copy_button,
latents_out,
import_voice_name
]
)
import_voice_button.click(
fn=import_voices_proxy,
inputs=[
audio_in,
import_voice_name,
],
outputs=import_voice_name #console_output
)
show_experimental_settings.change(
fn=lambda x: gr.update(visible=x),
inputs=show_experimental_settings,
outputs=experimental_column
)
preset.change(fn=update_presets,
inputs=preset,
outputs=[
num_autoregressive_samples,
diffusion_iterations,
],
)
recompute_voice_latents.click(compute_latents,
inputs=[
voice,
voice_latents_chunks,
],
outputs=voice,
)
prompt.change(fn=lambda value: gr.update(value="Custom"),
inputs=prompt,
outputs=emotion
)
mic_audio.change(fn=lambda value: gr.update(value="microphone"),
inputs=mic_audio,
outputs=voice
)
refresh_voices.click(update_voices,
inputs=None,
outputs=[
voice,
dataset_voices,
dataset_settings[0],
history_voices
]
)
@ -552,6 +514,25 @@ def setup_gradio():
outputs=input_settings
)
refresh_configs.click(update_training_configs,inputs=None,outputs=training_configs)
start_training_button.click(run_training,
inputs=training_configs,
outputs=training_output #console_output
)
stop_training_button.click(stop_training,
inputs=None,
outputs=training_output #console_output
)
prepare_dataset_button.click(
prepare_dataset_proxy,
inputs=dataset_settings,
outputs=prepare_dataset_output #console_output
)
save_yaml_button.click(save_training_settings,
inputs=training_settings,
outputs=save_yaml_output #console_output
)
if os.path.isfile('./config/generate.json'):
ui.load(import_generate_settings, inputs=None, outputs=input_settings)

@ -0,0 +1,4 @@
call .\venv\Scripts\activate.bat
python ./src/train.py -opt "%1"
deactivate
pause

@ -0,0 +1,3 @@
source ./venv/bin/activate
python3 ./src/train.py -opt "$1"
deactivate