forked from mrq/ai-voice-cloning
a bit of UI cleanup, import multiple audio files at once, actually shows progress when importing voices, hides audio metadata / latents if no generated settings are detected, preparing datasets shows its progress, saving a training YAML shows a message when done, training now works within the web UI, training output shows to web UI, provided notebook is cleaned up and uses a venv, etc.
This commit is contained in:
parent
c75d0bc5da
commit
d5c1433268
|
@ -3,10 +3,7 @@
|
|||
"nbformat_minor":0,
|
||||
"metadata":{
|
||||
"colab":{
|
||||
"private_outputs":true,
|
||||
"provenance":[
|
||||
|
||||
]
|
||||
"private_outputs":true
|
||||
},
|
||||
"kernelspec":{
|
||||
"name":"python3",
|
||||
|
@ -40,41 +37,62 @@
|
|||
"source":[
|
||||
"!git clone https://git.ecker.tech/mrq/ai-voice-cloning/\n",
|
||||
"%cd ai-voice-cloning\n",
|
||||
"!apt install python3.8-venv\n",
|
||||
"!python -m venv venv\n",
|
||||
"!source ./venv/bin/activate\n",
|
||||
"!git clone https://git.ecker.tech/mrq/DL-Art-School dlas\n",
|
||||
"!python -m pip install --upgrade pip\n",
|
||||
"!pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116\n",
|
||||
"!python -m pip install -r ./requirements.txt\n",
|
||||
"!git clone https://git.ecker.tech/mrq/DL-Art-School dlas\n",
|
||||
"!python -m pip install -r ./dlas/requirements.txt"
|
||||
"!python -m pip install -r ./dlas/requirements.txt\n",
|
||||
"!python -m pip install -r ./requirements.txt"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type":"markdown",
|
||||
"source":[
|
||||
"# Restart Runtime Before Proceeding"
|
||||
"# Update Repos"
|
||||
],
|
||||
"metadata":{
|
||||
"id":"TXFyLVLA48S5"
|
||||
"id":"IzrGt5IcHlAD"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type":"code",
|
||||
"source":[
|
||||
"# colab requires the runtime to restart before use\n",
|
||||
"exit()"
|
||||
"%cd /content/ai-voice-cloning/dlas\n",
|
||||
"!git reset --hard HEAD\n",
|
||||
"!git pull\n",
|
||||
"%cd ..\n",
|
||||
"!git reset --hard HEAD\n",
|
||||
"!git pull\n",
|
||||
"!python -m pip install ffmpeg ffmpeg-python"
|
||||
],
|
||||
"metadata":{
|
||||
"id":"FVUOtSASCSJ8"
|
||||
"id":"3DktoOXSHmtw"
|
||||
},
|
||||
"execution_count":null,
|
||||
"outputs":[
|
||||
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type":"markdown",
|
||||
"source":[
|
||||
"# Mount Drive"
|
||||
],
|
||||
"metadata":{
|
||||
"id":"2Y4t9zDIZMTg"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type":"code",
|
||||
"source":[
|
||||
"from google.colab import drive\n",
|
||||
"drive.mount('/content/drive')"
|
||||
"drive.mount('/content/drive')\n",
|
||||
"\n",
|
||||
"%cd /content/ai-voice-cloning\n",
|
||||
"!rm -r ./training\n",
|
||||
"!ln -s /content/drive/MyDrive/training/"
|
||||
],
|
||||
"metadata":{
|
||||
"id":"SGt9gyvubveT"
|
||||
|
@ -97,6 +115,8 @@
|
|||
"cell_type":"code",
|
||||
"source":[
|
||||
"%cd /content/ai-voice-cloning\n",
|
||||
"!python -m venv venv\n",
|
||||
"!source ./venv/bin/activate\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"import sys\n",
|
||||
|
@ -117,7 +137,7 @@
|
|||
"\n",
|
||||
"webui = setup_gradio()\n",
|
||||
"tts = setup_tortoise()\n",
|
||||
"webui.launch(share=True, prevent_thread_lock=True, debug=True, height=1000)\n",
|
||||
"webui.launch(share=True, prevent_thread_lock=True, height=1000)\n",
|
||||
"webui.block_thread()"
|
||||
],
|
||||
"metadata":{
|
||||
|
@ -140,8 +160,9 @@
|
|||
{
|
||||
"cell_type":"code",
|
||||
"source":[
|
||||
"# This is in case you can't get training through the web UI\n",
|
||||
"%cd /content/ai-voice-cloning\n",
|
||||
"!python ./src/train.py -opt ./training/finetune.yaml"
|
||||
"!python ./dlas/codes/train.py -opt ./training/finetune.yaml"
|
||||
],
|
||||
"metadata":{
|
||||
"id":"-KayB8klA5tY"
|
||||
|
@ -167,8 +188,9 @@
|
|||
"!apt install -y p7zip-full\n",
|
||||
"from datetime import datetime\n",
|
||||
"timestamp = datetime.now().strftime('%m-%d-%Y_%H:%M:%S')\n",
|
||||
"!mkdir -p \"../{timestamp}\"\n",
|
||||
"!mv ./results/* \"../{timestamp}/.\"\n",
|
||||
"!mkdir -p \"../{timestamp}/results\"\n",
|
||||
"!mv ./results/* \"../{timestamp}/results/.\"\n",
|
||||
"!mv ./training/* \"../{timestamp}/training/.\"\n",
|
||||
"!7z a -t7z -m0=lzma2 -mx=9 -mfb=64 -md=32m -ms=on \"../{timestamp}.7z\" \"../{timestamp}/\"\n",
|
||||
"!ls ~/\n",
|
||||
"!echo \"Finished zipping, archive is available at {timestamp}.7z\""
|
||||
|
|
59
src/train.py
59
src/train.py
|
@ -25,32 +25,37 @@ from utils import util, options as option
|
|||
# this is effectively just copy pasted and cleaned up from the __main__ section of training.py
|
||||
# I'll clean it up better
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
args = parser.parse_args()
|
||||
opt = option.parse(args.opt, is_train=True)
|
||||
if args.launcher != 'none':
|
||||
# export CUDA_VISIBLE_DEVICES for running in distributed mode.
|
||||
if 'gpu_ids' in opt.keys():
|
||||
gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
|
||||
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
|
||||
trainer = tr.Trainer()
|
||||
def train(yaml, launcher='none'):
|
||||
opt = option.parse(yaml, is_train=True)
|
||||
if launcher != 'none':
|
||||
# export CUDA_VISIBLE_DEVICES for running in distributed mode.
|
||||
if 'gpu_ids' in opt.keys():
|
||||
gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
|
||||
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
|
||||
trainer = tr.Trainer()
|
||||
|
||||
#### distributed training settings
|
||||
if args.launcher == 'none': # disabled distributed training
|
||||
opt['dist'] = False
|
||||
trainer.rank = -1
|
||||
if len(opt['gpu_ids']) == 1:
|
||||
torch.cuda.set_device(opt['gpu_ids'][0])
|
||||
print('Disabled distributed training.')
|
||||
else:
|
||||
opt['dist'] = True
|
||||
init_dist('nccl')
|
||||
trainer.world_size = torch.distributed.get_world_size()
|
||||
trainer.rank = torch.distributed.get_rank()
|
||||
torch.cuda.set_device(torch.distributed.get_rank())
|
||||
#### distributed training settings
|
||||
if launcher == 'none': # disabled distributed training
|
||||
opt['dist'] = False
|
||||
trainer.rank = -1
|
||||
if len(opt['gpu_ids']) == 1:
|
||||
torch.cuda.set_device(opt['gpu_ids'][0])
|
||||
print('Disabled distributed training.')
|
||||
else:
|
||||
opt['dist'] = True
|
||||
init_dist('nccl')
|
||||
trainer.world_size = torch.distributed.get_world_size()
|
||||
trainer.rank = torch.distributed.get_rank()
|
||||
torch.cuda.set_device(torch.distributed.get_rank())
|
||||
|
||||
trainer.init(args.opt, opt, args.launcher)
|
||||
trainer.do_training()
|
||||
trainer.init(yaml, opt, launcher)
|
||||
trainer.do_training()
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
args = parser.parse_args()
|
||||
|
||||
train(args.opt, args.launcher)
|
169
src/utils.py
169
src/utils.py
|
@ -1,5 +1,4 @@
|
|||
import os
|
||||
|
||||
if 'XDG_CACHE_HOME' not in os.environ:
|
||||
os.environ['XDG_CACHE_HOME'] = os.path.realpath(os.path.join(os.getcwd(), './models/'))
|
||||
|
||||
|
@ -15,7 +14,9 @@ import json
|
|||
import base64
|
||||
import re
|
||||
import urllib.request
|
||||
import signal
|
||||
|
||||
import tqdm
|
||||
import torch
|
||||
import torchaudio
|
||||
import music_tag
|
||||
|
@ -90,6 +91,8 @@ def setup_args():
|
|||
parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once")
|
||||
parser.add_argument("--output-sample-rate", type=int, default=default_arguments['output-sample-rate'], help="Sample rate to resample the output to (from 24KHz)")
|
||||
parser.add_argument("--output-volume", type=float, default=default_arguments['output-volume'], help="Adjusts volume of output")
|
||||
|
||||
parser.add_argument("--os", default="unix", help="Specifies which OS, easily")
|
||||
args = parser.parse_args()
|
||||
|
||||
args.embed_output_metadata = not args.no_embed_output_metadata
|
||||
|
@ -427,20 +430,37 @@ def generate(
|
|||
|
||||
import subprocess
|
||||
|
||||
training_process = None
|
||||
def run_training(config_path):
|
||||
print("Unloading TTS to save VRAM.")
|
||||
global tts
|
||||
del tts
|
||||
tts = None
|
||||
|
||||
cmd = ["python", "./src/train.py", "-opt", config_path]
|
||||
global training_process
|
||||
torch.multiprocessing.freeze_support()
|
||||
|
||||
cmd = [f'train.{"bat" if args.os == "windows" else "sh"}', config_path]
|
||||
print("Spawning process: ", " ".join(cmd))
|
||||
subprocess.run(cmd, env=os.environ.copy(), shell=True)
|
||||
"""
|
||||
from train import train
|
||||
train(config)
|
||||
"""
|
||||
training_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
|
||||
buffer=[]
|
||||
for line in iter(training_process.stdout.readline, ""):
|
||||
buffer.append(line)
|
||||
yield "".join(buffer)
|
||||
|
||||
training_process.stdout.close()
|
||||
return_code = training_process.wait()
|
||||
training_process = None
|
||||
if return_code:
|
||||
raise subprocess.CalledProcessError(return_code, cmd)
|
||||
|
||||
|
||||
def stop_training():
|
||||
if training_process is None:
|
||||
return "No training in progress"
|
||||
training_process.kill()
|
||||
training_process = None
|
||||
return "Training cancelled"
|
||||
|
||||
def setup_voicefixer(restart=False):
|
||||
global voicefixer
|
||||
|
@ -485,19 +505,23 @@ def save_training_settings( batch_size=None, learning_rate=None, print_rate=None
|
|||
"validation_name": validation_name if validation_name else "finetune",
|
||||
"validation_path": validation_path if validation_path else "./training/finetune/train.txt",
|
||||
}
|
||||
outfile = f'./training/{settings["name"]}.yaml'
|
||||
|
||||
with open(f'./training/.template.yaml', 'r', encoding="utf-8") as f:
|
||||
with open(f'./models/.template.yaml', 'r', encoding="utf-8") as f:
|
||||
yaml = f.read()
|
||||
|
||||
for k in settings:
|
||||
yaml = yaml.replace(f"${{{k}}}", str(settings[k]))
|
||||
|
||||
with open(f'./training/{settings["name"]}.yaml', 'w', encoding="utf-8") as f:
|
||||
|
||||
with open(outfile, 'w', encoding="utf-8") as f:
|
||||
f.write(yaml)
|
||||
|
||||
def prepare_dataset( files, outdir, language=None ):
|
||||
return f"Training settings saved to: {outfile}"
|
||||
|
||||
def prepare_dataset( files, outdir, language=None, progress=None ):
|
||||
global whisper_model
|
||||
if whisper_model is None:
|
||||
notify_progress(f"Loading Whisper model: {args.whisper_model}", progress)
|
||||
whisper_model = whisper.load_model(args.whisper_model)
|
||||
|
||||
os.makedirs(outdir, exist_ok=True)
|
||||
|
@ -506,7 +530,7 @@ def prepare_dataset( files, outdir, language=None ):
|
|||
results = {}
|
||||
transcription = []
|
||||
|
||||
for file in files:
|
||||
for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
|
||||
print(f"Transcribing file: {file}")
|
||||
|
||||
result = whisper_model.transcribe(file, language=language if language else "English")
|
||||
|
@ -517,7 +541,7 @@ def prepare_dataset( files, outdir, language=None ):
|
|||
waveform, sampling_rate = torchaudio.load(file)
|
||||
num_channels, num_frames = waveform.shape
|
||||
|
||||
for segment in result['segments']:
|
||||
for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
|
||||
start = int(segment['start'] * sampling_rate)
|
||||
end = int(segment['end'] * sampling_rate)
|
||||
|
||||
|
@ -535,66 +559,74 @@ def prepare_dataset( files, outdir, language=None ):
|
|||
with open(f'{outdir}/train.txt', 'w', encoding="utf-8") as f:
|
||||
f.write("\n".join(transcription))
|
||||
|
||||
return f"Processed dataset to: {outdir}"
|
||||
|
||||
def reset_generation_settings():
|
||||
with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
|
||||
f.write(json.dumps({}, indent='\t') )
|
||||
return import_generate_settings()
|
||||
|
||||
def import_voice(file, saveAs = None):
|
||||
def import_voices(files, saveAs=None, progress=None):
|
||||
global args
|
||||
|
||||
j, latents = read_generate_settings(file, read_latents=True)
|
||||
|
||||
if j is not None and saveAs is None:
|
||||
saveAs = j['voice']
|
||||
if saveAs is None or saveAs == "":
|
||||
raise Exception("Specify a voice name")
|
||||
if not isinstance(files, list):
|
||||
files = [files]
|
||||
|
||||
outdir = f'{get_voice_dir()}/{saveAs}/'
|
||||
os.makedirs(outdir, exist_ok=True)
|
||||
if latents:
|
||||
with open(f'{outdir}/cond_latents.pth', 'wb') as f:
|
||||
f.write(latents)
|
||||
latents = f'{outdir}/cond_latents.pth'
|
||||
print(f"Imported latents to {latents}")
|
||||
else:
|
||||
filename = file.name
|
||||
if filename[-4:] != ".wav":
|
||||
raise Exception("Please convert to a WAV first")
|
||||
for file in enumerate_progress(files, desc="Importing voice files", progress=progress):
|
||||
j, latents = read_generate_settings(file, read_latents=True)
|
||||
|
||||
if j is not None and saveAs is None:
|
||||
saveAs = j['voice']
|
||||
if saveAs is None or saveAs == "":
|
||||
raise Exception("Specify a voice name")
|
||||
|
||||
path = f"{outdir}/{os.path.basename(filename)}"
|
||||
waveform, sampling_rate = torchaudio.load(filename)
|
||||
outdir = f'{get_voice_dir()}/{saveAs}/'
|
||||
os.makedirs(outdir, exist_ok=True)
|
||||
|
||||
if args.voice_fixer and voicefixer is not None:
|
||||
# resample to best bandwidth since voicefixer will do it anyways through librosa
|
||||
if sampling_rate != 44100:
|
||||
print(f"Resampling imported voice sample: {path}")
|
||||
resampler = torchaudio.transforms.Resample(
|
||||
sampling_rate,
|
||||
44100,
|
||||
lowpass_filter_width=16,
|
||||
rolloff=0.85,
|
||||
resampling_method="kaiser_window",
|
||||
beta=8.555504641634386,
|
||||
)
|
||||
waveform = resampler(waveform)
|
||||
sampling_rate = 44100
|
||||
|
||||
torchaudio.save(path, waveform, sampling_rate)
|
||||
|
||||
print(f"Running 'voicefixer' on voice sample: {path}")
|
||||
voicefixer.restore(
|
||||
input = path,
|
||||
output = path,
|
||||
cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
|
||||
#mode=mode,
|
||||
)
|
||||
if latents:
|
||||
print(f"Importing latents to {latents}")
|
||||
with open(f'{outdir}/cond_latents.pth', 'wb') as f:
|
||||
f.write(latents)
|
||||
latents = f'{outdir}/cond_latents.pth'
|
||||
print(f"Imported latents to {latents}")
|
||||
else:
|
||||
torchaudio.save(path, waveform, sampling_rate)
|
||||
filename = file.name
|
||||
if filename[-4:] != ".wav":
|
||||
raise Exception("Please convert to a WAV first")
|
||||
|
||||
path = f"{outdir}/{os.path.basename(filename)}"
|
||||
print(f"Importing voice to {path}")
|
||||
|
||||
print(f"Imported voice to {path}")
|
||||
waveform, sampling_rate = torchaudio.load(filename)
|
||||
|
||||
if args.voice_fixer and voicefixer is not None:
|
||||
# resample to best bandwidth since voicefixer will do it anyways through librosa
|
||||
if sampling_rate != 44100:
|
||||
print(f"Resampling imported voice sample: {path}")
|
||||
resampler = torchaudio.transforms.Resample(
|
||||
sampling_rate,
|
||||
44100,
|
||||
lowpass_filter_width=16,
|
||||
rolloff=0.85,
|
||||
resampling_method="kaiser_window",
|
||||
beta=8.555504641634386,
|
||||
)
|
||||
waveform = resampler(waveform)
|
||||
sampling_rate = 44100
|
||||
|
||||
torchaudio.save(path, waveform, sampling_rate)
|
||||
|
||||
print(f"Running 'voicefixer' on voice sample: {path}")
|
||||
voicefixer.restore(
|
||||
input = path,
|
||||
output = path,
|
||||
cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
|
||||
#mode=mode,
|
||||
)
|
||||
else:
|
||||
torchaudio.save(path, waveform, sampling_rate)
|
||||
|
||||
print(f"Imported voice to {path}")
|
||||
|
||||
def import_generate_settings(file="./config/generate.json"):
|
||||
settings, _ = read_generate_settings(file, read_latents=False)
|
||||
|
@ -759,4 +791,21 @@ def read_generate_settings(file, read_latents=True, read_json=True):
|
|||
return (
|
||||
j,
|
||||
latents,
|
||||
)
|
||||
)
|
||||
|
||||
def enumerate_progress(iterable, desc=None, progress=None, verbose=None):
|
||||
if verbose and desc is not None:
|
||||
print(desc)
|
||||
|
||||
if progress is None:
|
||||
return tqdm(iterable, disable=not verbose)
|
||||
return progress.tqdm(iterable, desc=f'{progress.msg_prefix} {desc}' if hasattr(progress, 'msg_prefix') else desc, track_tqdm=True)
|
||||
|
||||
def notify_progress(message, progress=None, verbose=True):
|
||||
if verbose:
|
||||
print(message)
|
||||
|
||||
if progress is None:
|
||||
return
|
||||
|
||||
progress(0, desc=message)
|
291
src/webui.py
291
src/webui.py
|
@ -135,6 +135,21 @@ def get_training_configs():
|
|||
def update_training_configs():
|
||||
return gr.update(choices=get_training_configs())
|
||||
|
||||
history_headers = {
|
||||
"Name": "",
|
||||
"Samples": "num_autoregressive_samples",
|
||||
"Iterations": "diffusion_iterations",
|
||||
"Temp.": "temperature",
|
||||
"Sampler": "diffusion_sampler",
|
||||
"CVVP": "cvvp_weight",
|
||||
"Top P": "top_p",
|
||||
"Diff. Temp.": "diffusion_temperature",
|
||||
"Len Pen": "length_penalty",
|
||||
"Rep Pen": "repetition_penalty",
|
||||
"Cond-Free K": "cond_free_k",
|
||||
"Time": "time",
|
||||
}
|
||||
|
||||
def history_view_results( voice ):
|
||||
results = []
|
||||
files = []
|
||||
|
@ -148,7 +163,7 @@ def history_view_results( voice ):
|
|||
continue
|
||||
|
||||
values = []
|
||||
for k in headers:
|
||||
for k in history_headers:
|
||||
v = file
|
||||
if k != "Name":
|
||||
v = metadata[headers[k]]
|
||||
|
@ -163,6 +178,10 @@ def history_view_results( voice ):
|
|||
gr.Dropdown.update(choices=sorted(files))
|
||||
)
|
||||
|
||||
def import_voices_proxy(files, name, progress=gr.Progress(track_tqdm=True)):
|
||||
import_voices(files, name, progress)
|
||||
return gr.update()
|
||||
|
||||
def read_generate_settings_proxy(file, saveAs='.temp'):
|
||||
j, latents = read_generate_settings(file)
|
||||
|
||||
|
@ -175,13 +194,14 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
|
|||
latents = f'{outdir}/cond_latents.pth'
|
||||
|
||||
return (
|
||||
j,
|
||||
gr.update(value=j, visible=j is not None),
|
||||
gr.update(visible=j is not None),
|
||||
gr.update(value=latents, visible=latents is not None),
|
||||
None if j is None else j['voice']
|
||||
)
|
||||
|
||||
def prepare_dataset_proxy( voice, language ):
|
||||
return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language )
|
||||
def prepare_dataset_proxy( voice, language, progress=gr.Progress(track_tqdm=True) ):
|
||||
return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, progress=progress )
|
||||
|
||||
def update_voices():
|
||||
return (
|
||||
|
@ -222,52 +242,18 @@ def setup_gradio():
|
|||
with gr.Column():
|
||||
delimiter = gr.Textbox(lines=1, label="Line Delimiter", placeholder="\\n")
|
||||
|
||||
emotion = gr.Radio(
|
||||
["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom"],
|
||||
value="Custom",
|
||||
label="Emotion",
|
||||
type="value",
|
||||
interactive=True
|
||||
)
|
||||
emotion = gr.Radio( ["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom"], value="Custom", label="Emotion", type="value", interactive=True )
|
||||
prompt = gr.Textbox(lines=1, label="Custom Emotion + Prompt (if selected)")
|
||||
voice = gr.Dropdown(
|
||||
get_voice_list(),
|
||||
label="Voice",
|
||||
type="value",
|
||||
)
|
||||
mic_audio = gr.Audio(
|
||||
label="Microphone Source",
|
||||
source="microphone",
|
||||
type="filepath",
|
||||
)
|
||||
voice = gr.Dropdown(get_voice_list(), label="Voice", type="value")
|
||||
mic_audio = gr.Audio( label="Microphone Source", source="microphone", type="filepath" )
|
||||
refresh_voices = gr.Button(value="Refresh Voice List")
|
||||
voice_latents_chunks = gr.Slider(label="Voice Chunks", minimum=1, maximum=64, value=1, step=1)
|
||||
recompute_voice_latents = gr.Button(value="(Re)Compute Voice Latents")
|
||||
recompute_voice_latents.click(compute_latents,
|
||||
inputs=[
|
||||
voice,
|
||||
voice_latents_chunks,
|
||||
],
|
||||
outputs=voice,
|
||||
)
|
||||
|
||||
prompt.change(fn=lambda value: gr.update(value="Custom"),
|
||||
inputs=prompt,
|
||||
outputs=emotion
|
||||
)
|
||||
mic_audio.change(fn=lambda value: gr.update(value="microphone"),
|
||||
inputs=mic_audio,
|
||||
outputs=voice
|
||||
)
|
||||
with gr.Column():
|
||||
candidates = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates")
|
||||
seed = gr.Number(value=0, precision=0, label="Seed")
|
||||
|
||||
preset = gr.Radio(
|
||||
["Ultra Fast", "Fast", "Standard", "High Quality"],
|
||||
label="Preset",
|
||||
type="value",
|
||||
)
|
||||
preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value" )
|
||||
num_autoregressive_samples = gr.Slider(value=128, minimum=0, maximum=512, step=1, label="Samples")
|
||||
diffusion_iterations = gr.Slider(value=128, minimum=0, maximum=512, step=1, label="Iterations")
|
||||
|
||||
|
@ -275,19 +261,7 @@ def setup_gradio():
|
|||
breathing_room = gr.Slider(value=8, minimum=1, maximum=32, step=1, label="Pause Size")
|
||||
diffusion_sampler = gr.Radio(
|
||||
["P", "DDIM"], # + ["K_Euler_A", "DPM++2M"],
|
||||
value="P",
|
||||
label="Diffusion Samplers",
|
||||
type="value",
|
||||
)
|
||||
|
||||
preset.change(fn=update_presets,
|
||||
inputs=preset,
|
||||
outputs=[
|
||||
num_autoregressive_samples,
|
||||
diffusion_iterations,
|
||||
],
|
||||
)
|
||||
|
||||
value="P", label="Diffusion Samplers", type="value" )
|
||||
show_experimental_settings = gr.Checkbox(label="Show Experimental Settings")
|
||||
reset_generation_settings_button = gr.Button(value="Reset to Default")
|
||||
with gr.Column(visible=False) as col:
|
||||
|
@ -300,12 +274,6 @@ def setup_gradio():
|
|||
length_penalty = gr.Slider(value=1.0, minimum=0, maximum=8, label="Length Penalty")
|
||||
repetition_penalty = gr.Slider(value=2.0, minimum=0, maximum=8, label="Repetition Penalty")
|
||||
cond_free_k = gr.Slider(value=2.0, minimum=0, maximum=4, label="Conditioning-Free K")
|
||||
|
||||
show_experimental_settings.change(
|
||||
fn=lambda x: gr.update(visible=x),
|
||||
inputs=show_experimental_settings,
|
||||
outputs=experimental_column
|
||||
)
|
||||
with gr.Column():
|
||||
submit = gr.Button(value="Generate")
|
||||
stop = gr.Button(value="Stop")
|
||||
|
@ -315,33 +283,13 @@ def setup_gradio():
|
|||
output_audio = gr.Audio(label="Output")
|
||||
candidates_list = gr.Dropdown(label="Candidates", type="value", visible=False)
|
||||
output_pick = gr.Button(value="Select Candidate", visible=False)
|
||||
|
||||
with gr.Tab("History"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
headers = {
|
||||
"Name": "",
|
||||
"Samples": "num_autoregressive_samples",
|
||||
"Iterations": "diffusion_iterations",
|
||||
"Temp.": "temperature",
|
||||
"Sampler": "diffusion_sampler",
|
||||
"CVVP": "cvvp_weight",
|
||||
"Top P": "top_p",
|
||||
"Diff. Temp.": "diffusion_temperature",
|
||||
"Len Pen": "length_penalty",
|
||||
"Rep Pen": "repetition_penalty",
|
||||
"Cond-Free K": "cond_free_k",
|
||||
"Time": "time",
|
||||
}
|
||||
history_info = gr.Dataframe(label="Results", headers=list(headers.keys()))
|
||||
history_info = gr.Dataframe(label="Results", headers=list(history_headers.keys()))
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
history_voices = gr.Dropdown(
|
||||
get_voice_list("./results/"),
|
||||
label="Voice",
|
||||
type="value",
|
||||
)
|
||||
|
||||
history_voices = gr.Dropdown(choices=get_voice_list("./results/"), label="Voice", type="value")
|
||||
history_view_results_button = gr.Button(value="View Files")
|
||||
with gr.Column():
|
||||
history_results_list = gr.Dropdown(label="Results",type="value", interactive=True)
|
||||
|
@ -349,51 +297,16 @@ def setup_gradio():
|
|||
with gr.Column():
|
||||
history_audio = gr.Audio()
|
||||
history_copy_settings_button = gr.Button(value="Copy Settings")
|
||||
|
||||
history_view_results_button.click(
|
||||
fn=history_view_results,
|
||||
inputs=history_voices,
|
||||
outputs=[
|
||||
history_info,
|
||||
history_results_list,
|
||||
]
|
||||
)
|
||||
history_view_result_button.click(
|
||||
fn=lambda voice, file: f"./results/{voice}/{file}",
|
||||
inputs=[
|
||||
history_voices,
|
||||
history_results_list,
|
||||
],
|
||||
outputs=history_audio
|
||||
)
|
||||
with gr.Tab("Utilities"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
audio_in = gr.File(type="file", label="Audio Input", file_types=["audio"])
|
||||
copy_button = gr.Button(value="Copy Settings")
|
||||
audio_in = gr.Files(type="file", label="Audio Input", file_types=["audio"])
|
||||
import_voice_name = gr.Textbox(label="Voice Name")
|
||||
import_voice_button = gr.Button(value="Import Voice")
|
||||
with gr.Column():
|
||||
metadata_out = gr.JSON(label="Audio Metadata")
|
||||
latents_out = gr.File(type="binary", label="Voice Latents")
|
||||
|
||||
audio_in.upload(
|
||||
fn=read_generate_settings_proxy,
|
||||
inputs=audio_in,
|
||||
outputs=[
|
||||
metadata_out,
|
||||
latents_out,
|
||||
import_voice_name
|
||||
]
|
||||
)
|
||||
|
||||
import_voice_button.click(
|
||||
fn=import_voice,
|
||||
inputs=[
|
||||
audio_in,
|
||||
import_voice_name,
|
||||
]
|
||||
)
|
||||
metadata_out = gr.JSON(label="Audio Metadata", visible=False)
|
||||
copy_button = gr.Button(value="Copy Settings", visible=False)
|
||||
latents_out = gr.File(type="binary", label="Voice Latents", visible=False)
|
||||
with gr.Tab("Training"):
|
||||
with gr.Tab("Prepare Dataset"):
|
||||
with gr.Row():
|
||||
|
@ -402,16 +315,9 @@ def setup_gradio():
|
|||
gr.Dropdown( get_voice_list(), label="Dataset Source", type="value" ),
|
||||
gr.Textbox(label="Language", placeholder="English")
|
||||
]
|
||||
dataset_voices = dataset_settings[0]
|
||||
|
||||
with gr.Column():
|
||||
prepare_dataset_button = gr.Button(value="Prepare")
|
||||
|
||||
prepare_dataset_button.click(
|
||||
prepare_dataset_proxy,
|
||||
inputs=dataset_settings,
|
||||
outputs=None
|
||||
)
|
||||
with gr.Column():
|
||||
prepare_dataset_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
||||
with gr.Tab("Generate Configuration"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
|
@ -421,8 +327,6 @@ def setup_gradio():
|
|||
gr.Number(label="Print Frequency", value=50),
|
||||
gr.Number(label="Save Frequency", value=50),
|
||||
]
|
||||
save_yaml_button = gr.Button(value="Save Training Configuration")
|
||||
with gr.Column():
|
||||
training_settings = training_settings + [
|
||||
gr.Textbox(label="Training Name", placeholder="finetune"),
|
||||
gr.Textbox(label="Dataset Name", placeholder="finetune"),
|
||||
|
@ -430,24 +334,18 @@ def setup_gradio():
|
|||
gr.Textbox(label="Validation Name", placeholder="finetune"),
|
||||
gr.Textbox(label="Validation Path", placeholder="./training/finetune/train.txt"),
|
||||
]
|
||||
|
||||
save_yaml_button.click(save_training_settings,
|
||||
inputs=training_settings,
|
||||
outputs=None
|
||||
)
|
||||
with gr.Tab("Train"):
|
||||
with gr.Column():
|
||||
save_yaml_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
||||
save_yaml_button = gr.Button(value="Save Training Configuration")
|
||||
with gr.Tab("Run Training"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_configs())
|
||||
refresh_configs = gr.Button(value="Refresh Configurations")
|
||||
train = gr.Button(value="Train")
|
||||
|
||||
refresh_configs.click(update_training_configs,inputs=None,outputs=training_configs)
|
||||
train.click(run_training,
|
||||
inputs=training_configs,
|
||||
outputs=None
|
||||
)
|
||||
|
||||
start_training_button = gr.Button(value="Train")
|
||||
stop_training_button = gr.Button(value="Stop")
|
||||
with gr.Column():
|
||||
training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
||||
with gr.Tab("Settings"):
|
||||
with gr.Row():
|
||||
exec_inputs = []
|
||||
|
@ -465,23 +363,22 @@ def setup_gradio():
|
|||
gr.Checkbox(label="Force CPU for Conditioning Latents", value=args.force_cpu_for_conditioning_latents),
|
||||
gr.Checkbox(label="Defer TTS Load", value=args.defer_tts_load),
|
||||
gr.Textbox(label="Device Override", value=args.device_override),
|
||||
gr.Dropdown(label="Whisper Model", value=args.whisper_model, choices=["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large"]),
|
||||
]
|
||||
gr.Button(value="Check for Updates").click(check_for_updates)
|
||||
gr.Button(value="Reload TTS").click(reload_tts)
|
||||
with gr.Column():
|
||||
exec_inputs = exec_inputs + [
|
||||
gr.Number(label="Sample Batch Size", precision=0, value=args.sample_batch_size),
|
||||
gr.Number(label="Concurrency Count", precision=0, value=args.concurrency_count),
|
||||
gr.Number(label="Ouptut Sample Rate", precision=0, value=args.output_sample_rate),
|
||||
gr.Slider(label="Ouptut Volume", minimum=0, maximum=2, value=args.output_volume),
|
||||
gr.Dropdown(label="Whisper Model", value=args.whisper_model, choices=["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large"]),
|
||||
]
|
||||
gr.Button(value="Check for Updates").click(check_for_updates)
|
||||
gr.Button(value="Reload TTS").click(reload_tts)
|
||||
|
||||
for i in exec_inputs:
|
||||
i.change(
|
||||
fn=export_exec_settings,
|
||||
inputs=exec_inputs
|
||||
)
|
||||
i.change( fn=export_exec_settings, inputs=exec_inputs )
|
||||
|
||||
# console_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
||||
|
||||
input_settings = [
|
||||
text,
|
||||
|
@ -507,11 +404,76 @@ def setup_gradio():
|
|||
experimental_checkboxes,
|
||||
]
|
||||
|
||||
history_view_results_button.click(
|
||||
fn=history_view_results,
|
||||
inputs=history_voices,
|
||||
outputs=[
|
||||
history_info,
|
||||
history_results_list,
|
||||
]
|
||||
)
|
||||
history_view_result_button.click(
|
||||
fn=lambda voice, file: f"./results/{voice}/{file}",
|
||||
inputs=[
|
||||
history_voices,
|
||||
history_results_list,
|
||||
],
|
||||
outputs=history_audio
|
||||
)
|
||||
audio_in.upload(
|
||||
fn=read_generate_settings_proxy,
|
||||
inputs=audio_in,
|
||||
outputs=[
|
||||
metadata_out,
|
||||
copy_button,
|
||||
latents_out,
|
||||
import_voice_name
|
||||
]
|
||||
)
|
||||
|
||||
import_voice_button.click(
|
||||
fn=import_voices_proxy,
|
||||
inputs=[
|
||||
audio_in,
|
||||
import_voice_name,
|
||||
],
|
||||
outputs=import_voice_name #console_output
|
||||
)
|
||||
show_experimental_settings.change(
|
||||
fn=lambda x: gr.update(visible=x),
|
||||
inputs=show_experimental_settings,
|
||||
outputs=experimental_column
|
||||
)
|
||||
preset.change(fn=update_presets,
|
||||
inputs=preset,
|
||||
outputs=[
|
||||
num_autoregressive_samples,
|
||||
diffusion_iterations,
|
||||
],
|
||||
)
|
||||
|
||||
recompute_voice_latents.click(compute_latents,
|
||||
inputs=[
|
||||
voice,
|
||||
voice_latents_chunks,
|
||||
],
|
||||
outputs=voice,
|
||||
)
|
||||
|
||||
prompt.change(fn=lambda value: gr.update(value="Custom"),
|
||||
inputs=prompt,
|
||||
outputs=emotion
|
||||
)
|
||||
mic_audio.change(fn=lambda value: gr.update(value="microphone"),
|
||||
inputs=mic_audio,
|
||||
outputs=voice
|
||||
)
|
||||
|
||||
refresh_voices.click(update_voices,
|
||||
inputs=None,
|
||||
outputs=[
|
||||
voice,
|
||||
dataset_voices,
|
||||
dataset_settings[0],
|
||||
history_voices
|
||||
]
|
||||
)
|
||||
|
@ -552,6 +514,25 @@ def setup_gradio():
|
|||
outputs=input_settings
|
||||
)
|
||||
|
||||
refresh_configs.click(update_training_configs,inputs=None,outputs=training_configs)
|
||||
start_training_button.click(run_training,
|
||||
inputs=training_configs,
|
||||
outputs=training_output #console_output
|
||||
)
|
||||
stop_training_button.click(stop_training,
|
||||
inputs=None,
|
||||
outputs=training_output #console_output
|
||||
)
|
||||
prepare_dataset_button.click(
|
||||
prepare_dataset_proxy,
|
||||
inputs=dataset_settings,
|
||||
outputs=prepare_dataset_output #console_output
|
||||
)
|
||||
save_yaml_button.click(save_training_settings,
|
||||
inputs=training_settings,
|
||||
outputs=save_yaml_output #console_output
|
||||
)
|
||||
|
||||
if os.path.isfile('./config/generate.json'):
|
||||
ui.load(import_generate_settings, inputs=None, outputs=input_settings)
|
||||
|
||||
|
|
4
train.bat
Executable file
4
train.bat
Executable file
|
@ -0,0 +1,4 @@
|
|||
call .\venv\Scripts\activate.bat
|
||||
python ./src/train.py -opt "%1"
|
||||
deactivate
|
||||
pause
|
3
train.sh
Executable file
3
train.sh
Executable file
|
@ -0,0 +1,3 @@
|
|||
source ./venv/bin/activate
|
||||
python3 ./src/train.py -opt "$1"
|
||||
deactivate
|
0
training/.gitkeep
Executable file
0
training/.gitkeep
Executable file
Loading…
Reference in New Issue
Block a user