forked from mrq/ai-voice-cloning
a bit of UI cleanup, import multiple audio files at once, actually shows progress when importing voices, hides audio metadata / latents if no generated settings are detected, preparing datasets shows its progress, saving a training YAML shows a message when done, training now works within the web UI, training output shows to web UI, provided notebook is cleaned up and uses a venv, etc.
This commit is contained in:
parent
c75d0bc5da
commit
d5c1433268
|
@ -3,10 +3,7 @@
|
||||||
"nbformat_minor":0,
|
"nbformat_minor":0,
|
||||||
"metadata":{
|
"metadata":{
|
||||||
"colab":{
|
"colab":{
|
||||||
"private_outputs":true,
|
"private_outputs":true
|
||||||
"provenance":[
|
|
||||||
|
|
||||||
]
|
|
||||||
},
|
},
|
||||||
"kernelspec":{
|
"kernelspec":{
|
||||||
"name":"python3",
|
"name":"python3",
|
||||||
|
@ -40,41 +37,62 @@
|
||||||
"source":[
|
"source":[
|
||||||
"!git clone https://git.ecker.tech/mrq/ai-voice-cloning/\n",
|
"!git clone https://git.ecker.tech/mrq/ai-voice-cloning/\n",
|
||||||
"%cd ai-voice-cloning\n",
|
"%cd ai-voice-cloning\n",
|
||||||
|
"!apt install python3.8-venv\n",
|
||||||
|
"!python -m venv venv\n",
|
||||||
|
"!source ./venv/bin/activate\n",
|
||||||
|
"!git clone https://git.ecker.tech/mrq/DL-Art-School dlas\n",
|
||||||
"!python -m pip install --upgrade pip\n",
|
"!python -m pip install --upgrade pip\n",
|
||||||
"!pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116\n",
|
"!pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116\n",
|
||||||
"!python -m pip install -r ./requirements.txt\n",
|
"!python -m pip install -r ./dlas/requirements.txt\n",
|
||||||
"!git clone https://git.ecker.tech/mrq/DL-Art-School dlas\n",
|
"!python -m pip install -r ./requirements.txt"
|
||||||
"!python -m pip install -r ./dlas/requirements.txt"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type":"markdown",
|
"cell_type":"markdown",
|
||||||
"source":[
|
"source":[
|
||||||
"# Restart Runtime Before Proceeding"
|
"# Update Repos"
|
||||||
],
|
],
|
||||||
"metadata":{
|
"metadata":{
|
||||||
"id":"TXFyLVLA48S5"
|
"id":"IzrGt5IcHlAD"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type":"code",
|
"cell_type":"code",
|
||||||
"source":[
|
"source":[
|
||||||
"# colab requires the runtime to restart before use\n",
|
"%cd /content/ai-voice-cloning/dlas\n",
|
||||||
"exit()"
|
"!git reset --hard HEAD\n",
|
||||||
|
"!git pull\n",
|
||||||
|
"%cd ..\n",
|
||||||
|
"!git reset --hard HEAD\n",
|
||||||
|
"!git pull\n",
|
||||||
|
"!python -m pip install ffmpeg ffmpeg-python"
|
||||||
],
|
],
|
||||||
"metadata":{
|
"metadata":{
|
||||||
"id":"FVUOtSASCSJ8"
|
"id":"3DktoOXSHmtw"
|
||||||
},
|
},
|
||||||
"execution_count":null,
|
"execution_count":null,
|
||||||
"outputs":[
|
"outputs":[
|
||||||
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type":"markdown",
|
||||||
|
"source":[
|
||||||
|
"# Mount Drive"
|
||||||
|
],
|
||||||
|
"metadata":{
|
||||||
|
"id":"2Y4t9zDIZMTg"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type":"code",
|
"cell_type":"code",
|
||||||
"source":[
|
"source":[
|
||||||
"from google.colab import drive\n",
|
"from google.colab import drive\n",
|
||||||
"drive.mount('/content/drive')"
|
"drive.mount('/content/drive')\n",
|
||||||
|
"\n",
|
||||||
|
"%cd /content/ai-voice-cloning\n",
|
||||||
|
"!rm -r ./training\n",
|
||||||
|
"!ln -s /content/drive/MyDrive/training/"
|
||||||
],
|
],
|
||||||
"metadata":{
|
"metadata":{
|
||||||
"id":"SGt9gyvubveT"
|
"id":"SGt9gyvubveT"
|
||||||
|
@ -97,6 +115,8 @@
|
||||||
"cell_type":"code",
|
"cell_type":"code",
|
||||||
"source":[
|
"source":[
|
||||||
"%cd /content/ai-voice-cloning\n",
|
"%cd /content/ai-voice-cloning\n",
|
||||||
|
"!python -m venv venv\n",
|
||||||
|
"!source ./venv/bin/activate\n",
|
||||||
"\n",
|
"\n",
|
||||||
"import os\n",
|
"import os\n",
|
||||||
"import sys\n",
|
"import sys\n",
|
||||||
|
@ -117,7 +137,7 @@
|
||||||
"\n",
|
"\n",
|
||||||
"webui = setup_gradio()\n",
|
"webui = setup_gradio()\n",
|
||||||
"tts = setup_tortoise()\n",
|
"tts = setup_tortoise()\n",
|
||||||
"webui.launch(share=True, prevent_thread_lock=True, debug=True, height=1000)\n",
|
"webui.launch(share=True, prevent_thread_lock=True, height=1000)\n",
|
||||||
"webui.block_thread()"
|
"webui.block_thread()"
|
||||||
],
|
],
|
||||||
"metadata":{
|
"metadata":{
|
||||||
|
@ -140,8 +160,9 @@
|
||||||
{
|
{
|
||||||
"cell_type":"code",
|
"cell_type":"code",
|
||||||
"source":[
|
"source":[
|
||||||
|
"# This is in case you can't get training through the web UI\n",
|
||||||
"%cd /content/ai-voice-cloning\n",
|
"%cd /content/ai-voice-cloning\n",
|
||||||
"!python ./src/train.py -opt ./training/finetune.yaml"
|
"!python ./dlas/codes/train.py -opt ./training/finetune.yaml"
|
||||||
],
|
],
|
||||||
"metadata":{
|
"metadata":{
|
||||||
"id":"-KayB8klA5tY"
|
"id":"-KayB8klA5tY"
|
||||||
|
@ -167,8 +188,9 @@
|
||||||
"!apt install -y p7zip-full\n",
|
"!apt install -y p7zip-full\n",
|
||||||
"from datetime import datetime\n",
|
"from datetime import datetime\n",
|
||||||
"timestamp = datetime.now().strftime('%m-%d-%Y_%H:%M:%S')\n",
|
"timestamp = datetime.now().strftime('%m-%d-%Y_%H:%M:%S')\n",
|
||||||
"!mkdir -p \"../{timestamp}\"\n",
|
"!mkdir -p \"../{timestamp}/results\"\n",
|
||||||
"!mv ./results/* \"../{timestamp}/.\"\n",
|
"!mv ./results/* \"../{timestamp}/results/.\"\n",
|
||||||
|
"!mv ./training/* \"../{timestamp}/training/.\"\n",
|
||||||
"!7z a -t7z -m0=lzma2 -mx=9 -mfb=64 -md=32m -ms=on \"../{timestamp}.7z\" \"../{timestamp}/\"\n",
|
"!7z a -t7z -m0=lzma2 -mx=9 -mfb=64 -md=32m -ms=on \"../{timestamp}.7z\" \"../{timestamp}/\"\n",
|
||||||
"!ls ~/\n",
|
"!ls ~/\n",
|
||||||
"!echo \"Finished zipping, archive is available at {timestamp}.7z\""
|
"!echo \"Finished zipping, archive is available at {timestamp}.7z\""
|
||||||
|
|
21
src/train.py
21
src/train.py
|
@ -25,12 +25,9 @@ from utils import util, options as option
|
||||||
# this is effectively just copy pasted and cleaned up from the __main__ section of training.py
|
# this is effectively just copy pasted and cleaned up from the __main__ section of training.py
|
||||||
# I'll clean it up better
|
# I'll clean it up better
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
def train(yaml, launcher='none'):
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml')
|
opt = option.parse(yaml, is_train=True)
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
if launcher != 'none':
|
||||||
args = parser.parse_args()
|
|
||||||
opt = option.parse(args.opt, is_train=True)
|
|
||||||
if args.launcher != 'none':
|
|
||||||
# export CUDA_VISIBLE_DEVICES for running in distributed mode.
|
# export CUDA_VISIBLE_DEVICES for running in distributed mode.
|
||||||
if 'gpu_ids' in opt.keys():
|
if 'gpu_ids' in opt.keys():
|
||||||
gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
|
gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
|
||||||
|
@ -39,7 +36,7 @@ if args.launcher != 'none':
|
||||||
trainer = tr.Trainer()
|
trainer = tr.Trainer()
|
||||||
|
|
||||||
#### distributed training settings
|
#### distributed training settings
|
||||||
if args.launcher == 'none': # disabled distributed training
|
if launcher == 'none': # disabled distributed training
|
||||||
opt['dist'] = False
|
opt['dist'] = False
|
||||||
trainer.rank = -1
|
trainer.rank = -1
|
||||||
if len(opt['gpu_ids']) == 1:
|
if len(opt['gpu_ids']) == 1:
|
||||||
|
@ -52,5 +49,13 @@ else:
|
||||||
trainer.rank = torch.distributed.get_rank()
|
trainer.rank = torch.distributed.get_rank()
|
||||||
torch.cuda.set_device(torch.distributed.get_rank())
|
torch.cuda.set_device(torch.distributed.get_rank())
|
||||||
|
|
||||||
trainer.init(args.opt, opt, args.launcher)
|
trainer.init(yaml, opt, launcher)
|
||||||
trainer.do_training()
|
trainer.do_training()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml')
|
||||||
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
train(args.opt, args.launcher)
|
79
src/utils.py
79
src/utils.py
|
@ -1,5 +1,4 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
if 'XDG_CACHE_HOME' not in os.environ:
|
if 'XDG_CACHE_HOME' not in os.environ:
|
||||||
os.environ['XDG_CACHE_HOME'] = os.path.realpath(os.path.join(os.getcwd(), './models/'))
|
os.environ['XDG_CACHE_HOME'] = os.path.realpath(os.path.join(os.getcwd(), './models/'))
|
||||||
|
|
||||||
|
@ -15,7 +14,9 @@ import json
|
||||||
import base64
|
import base64
|
||||||
import re
|
import re
|
||||||
import urllib.request
|
import urllib.request
|
||||||
|
import signal
|
||||||
|
|
||||||
|
import tqdm
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
import music_tag
|
import music_tag
|
||||||
|
@ -90,6 +91,8 @@ def setup_args():
|
||||||
parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once")
|
parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once")
|
||||||
parser.add_argument("--output-sample-rate", type=int, default=default_arguments['output-sample-rate'], help="Sample rate to resample the output to (from 24KHz)")
|
parser.add_argument("--output-sample-rate", type=int, default=default_arguments['output-sample-rate'], help="Sample rate to resample the output to (from 24KHz)")
|
||||||
parser.add_argument("--output-volume", type=float, default=default_arguments['output-volume'], help="Adjusts volume of output")
|
parser.add_argument("--output-volume", type=float, default=default_arguments['output-volume'], help="Adjusts volume of output")
|
||||||
|
|
||||||
|
parser.add_argument("--os", default="unix", help="Specifies which OS, easily")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
args.embed_output_metadata = not args.no_embed_output_metadata
|
args.embed_output_metadata = not args.no_embed_output_metadata
|
||||||
|
@ -427,20 +430,37 @@ def generate(
|
||||||
|
|
||||||
import subprocess
|
import subprocess
|
||||||
|
|
||||||
|
training_process = None
|
||||||
def run_training(config_path):
|
def run_training(config_path):
|
||||||
print("Unloading TTS to save VRAM.")
|
print("Unloading TTS to save VRAM.")
|
||||||
global tts
|
global tts
|
||||||
del tts
|
del tts
|
||||||
tts = None
|
tts = None
|
||||||
|
|
||||||
cmd = ["python", "./src/train.py", "-opt", config_path]
|
global training_process
|
||||||
|
torch.multiprocessing.freeze_support()
|
||||||
|
|
||||||
|
cmd = [f'train.{"bat" if args.os == "windows" else "sh"}', config_path]
|
||||||
print("Spawning process: ", " ".join(cmd))
|
print("Spawning process: ", " ".join(cmd))
|
||||||
subprocess.run(cmd, env=os.environ.copy(), shell=True)
|
training_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
|
||||||
"""
|
buffer=[]
|
||||||
from train import train
|
for line in iter(training_process.stdout.readline, ""):
|
||||||
train(config)
|
buffer.append(line)
|
||||||
"""
|
yield "".join(buffer)
|
||||||
|
|
||||||
|
training_process.stdout.close()
|
||||||
|
return_code = training_process.wait()
|
||||||
|
training_process = None
|
||||||
|
if return_code:
|
||||||
|
raise subprocess.CalledProcessError(return_code, cmd)
|
||||||
|
|
||||||
|
|
||||||
|
def stop_training():
|
||||||
|
if training_process is None:
|
||||||
|
return "No training in progress"
|
||||||
|
training_process.kill()
|
||||||
|
training_process = None
|
||||||
|
return "Training cancelled"
|
||||||
|
|
||||||
def setup_voicefixer(restart=False):
|
def setup_voicefixer(restart=False):
|
||||||
global voicefixer
|
global voicefixer
|
||||||
|
@ -485,19 +505,23 @@ def save_training_settings( batch_size=None, learning_rate=None, print_rate=None
|
||||||
"validation_name": validation_name if validation_name else "finetune",
|
"validation_name": validation_name if validation_name else "finetune",
|
||||||
"validation_path": validation_path if validation_path else "./training/finetune/train.txt",
|
"validation_path": validation_path if validation_path else "./training/finetune/train.txt",
|
||||||
}
|
}
|
||||||
|
outfile = f'./training/{settings["name"]}.yaml'
|
||||||
|
|
||||||
with open(f'./training/.template.yaml', 'r', encoding="utf-8") as f:
|
with open(f'./models/.template.yaml', 'r', encoding="utf-8") as f:
|
||||||
yaml = f.read()
|
yaml = f.read()
|
||||||
|
|
||||||
for k in settings:
|
for k in settings:
|
||||||
yaml = yaml.replace(f"${{{k}}}", str(settings[k]))
|
yaml = yaml.replace(f"${{{k}}}", str(settings[k]))
|
||||||
|
|
||||||
with open(f'./training/{settings["name"]}.yaml', 'w', encoding="utf-8") as f:
|
with open(outfile, 'w', encoding="utf-8") as f:
|
||||||
f.write(yaml)
|
f.write(yaml)
|
||||||
|
|
||||||
def prepare_dataset( files, outdir, language=None ):
|
return f"Training settings saved to: {outfile}"
|
||||||
|
|
||||||
|
def prepare_dataset( files, outdir, language=None, progress=None ):
|
||||||
global whisper_model
|
global whisper_model
|
||||||
if whisper_model is None:
|
if whisper_model is None:
|
||||||
|
notify_progress(f"Loading Whisper model: {args.whisper_model}", progress)
|
||||||
whisper_model = whisper.load_model(args.whisper_model)
|
whisper_model = whisper.load_model(args.whisper_model)
|
||||||
|
|
||||||
os.makedirs(outdir, exist_ok=True)
|
os.makedirs(outdir, exist_ok=True)
|
||||||
|
@ -506,7 +530,7 @@ def prepare_dataset( files, outdir, language=None ):
|
||||||
results = {}
|
results = {}
|
||||||
transcription = []
|
transcription = []
|
||||||
|
|
||||||
for file in files:
|
for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
|
||||||
print(f"Transcribing file: {file}")
|
print(f"Transcribing file: {file}")
|
||||||
|
|
||||||
result = whisper_model.transcribe(file, language=language if language else "English")
|
result = whisper_model.transcribe(file, language=language if language else "English")
|
||||||
|
@ -517,7 +541,7 @@ def prepare_dataset( files, outdir, language=None ):
|
||||||
waveform, sampling_rate = torchaudio.load(file)
|
waveform, sampling_rate = torchaudio.load(file)
|
||||||
num_channels, num_frames = waveform.shape
|
num_channels, num_frames = waveform.shape
|
||||||
|
|
||||||
for segment in result['segments']:
|
for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
|
||||||
start = int(segment['start'] * sampling_rate)
|
start = int(segment['start'] * sampling_rate)
|
||||||
end = int(segment['end'] * sampling_rate)
|
end = int(segment['end'] * sampling_rate)
|
||||||
|
|
||||||
|
@ -535,14 +559,20 @@ def prepare_dataset( files, outdir, language=None ):
|
||||||
with open(f'{outdir}/train.txt', 'w', encoding="utf-8") as f:
|
with open(f'{outdir}/train.txt', 'w', encoding="utf-8") as f:
|
||||||
f.write("\n".join(transcription))
|
f.write("\n".join(transcription))
|
||||||
|
|
||||||
|
return f"Processed dataset to: {outdir}"
|
||||||
|
|
||||||
def reset_generation_settings():
|
def reset_generation_settings():
|
||||||
with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
|
with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
|
||||||
f.write(json.dumps({}, indent='\t') )
|
f.write(json.dumps({}, indent='\t') )
|
||||||
return import_generate_settings()
|
return import_generate_settings()
|
||||||
|
|
||||||
def import_voice(file, saveAs = None):
|
def import_voices(files, saveAs=None, progress=None):
|
||||||
global args
|
global args
|
||||||
|
|
||||||
|
if not isinstance(files, list):
|
||||||
|
files = [files]
|
||||||
|
|
||||||
|
for file in enumerate_progress(files, desc="Importing voice files", progress=progress):
|
||||||
j, latents = read_generate_settings(file, read_latents=True)
|
j, latents = read_generate_settings(file, read_latents=True)
|
||||||
|
|
||||||
if j is not None and saveAs is None:
|
if j is not None and saveAs is None:
|
||||||
|
@ -552,7 +582,9 @@ def import_voice(file, saveAs = None):
|
||||||
|
|
||||||
outdir = f'{get_voice_dir()}/{saveAs}/'
|
outdir = f'{get_voice_dir()}/{saveAs}/'
|
||||||
os.makedirs(outdir, exist_ok=True)
|
os.makedirs(outdir, exist_ok=True)
|
||||||
|
|
||||||
if latents:
|
if latents:
|
||||||
|
print(f"Importing latents to {latents}")
|
||||||
with open(f'{outdir}/cond_latents.pth', 'wb') as f:
|
with open(f'{outdir}/cond_latents.pth', 'wb') as f:
|
||||||
f.write(latents)
|
f.write(latents)
|
||||||
latents = f'{outdir}/cond_latents.pth'
|
latents = f'{outdir}/cond_latents.pth'
|
||||||
|
@ -563,6 +595,8 @@ def import_voice(file, saveAs = None):
|
||||||
raise Exception("Please convert to a WAV first")
|
raise Exception("Please convert to a WAV first")
|
||||||
|
|
||||||
path = f"{outdir}/{os.path.basename(filename)}"
|
path = f"{outdir}/{os.path.basename(filename)}"
|
||||||
|
print(f"Importing voice to {path}")
|
||||||
|
|
||||||
waveform, sampling_rate = torchaudio.load(filename)
|
waveform, sampling_rate = torchaudio.load(filename)
|
||||||
|
|
||||||
if args.voice_fixer and voicefixer is not None:
|
if args.voice_fixer and voicefixer is not None:
|
||||||
|
@ -592,10 +626,8 @@ def import_voice(file, saveAs = None):
|
||||||
else:
|
else:
|
||||||
torchaudio.save(path, waveform, sampling_rate)
|
torchaudio.save(path, waveform, sampling_rate)
|
||||||
|
|
||||||
|
|
||||||
print(f"Imported voice to {path}")
|
print(f"Imported voice to {path}")
|
||||||
|
|
||||||
|
|
||||||
def import_generate_settings(file="./config/generate.json"):
|
def import_generate_settings(file="./config/generate.json"):
|
||||||
settings, _ = read_generate_settings(file, read_latents=False)
|
settings, _ = read_generate_settings(file, read_latents=False)
|
||||||
|
|
||||||
|
@ -760,3 +792,20 @@ def read_generate_settings(file, read_latents=True, read_json=True):
|
||||||
j,
|
j,
|
||||||
latents,
|
latents,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def enumerate_progress(iterable, desc=None, progress=None, verbose=None):
|
||||||
|
if verbose and desc is not None:
|
||||||
|
print(desc)
|
||||||
|
|
||||||
|
if progress is None:
|
||||||
|
return tqdm(iterable, disable=not verbose)
|
||||||
|
return progress.tqdm(iterable, desc=f'{progress.msg_prefix} {desc}' if hasattr(progress, 'msg_prefix') else desc, track_tqdm=True)
|
||||||
|
|
||||||
|
def notify_progress(message, progress=None, verbose=True):
|
||||||
|
if verbose:
|
||||||
|
print(message)
|
||||||
|
|
||||||
|
if progress is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
progress(0, desc=message)
|
291
src/webui.py
291
src/webui.py
|
@ -135,6 +135,21 @@ def get_training_configs():
|
||||||
def update_training_configs():
|
def update_training_configs():
|
||||||
return gr.update(choices=get_training_configs())
|
return gr.update(choices=get_training_configs())
|
||||||
|
|
||||||
|
history_headers = {
|
||||||
|
"Name": "",
|
||||||
|
"Samples": "num_autoregressive_samples",
|
||||||
|
"Iterations": "diffusion_iterations",
|
||||||
|
"Temp.": "temperature",
|
||||||
|
"Sampler": "diffusion_sampler",
|
||||||
|
"CVVP": "cvvp_weight",
|
||||||
|
"Top P": "top_p",
|
||||||
|
"Diff. Temp.": "diffusion_temperature",
|
||||||
|
"Len Pen": "length_penalty",
|
||||||
|
"Rep Pen": "repetition_penalty",
|
||||||
|
"Cond-Free K": "cond_free_k",
|
||||||
|
"Time": "time",
|
||||||
|
}
|
||||||
|
|
||||||
def history_view_results( voice ):
|
def history_view_results( voice ):
|
||||||
results = []
|
results = []
|
||||||
files = []
|
files = []
|
||||||
|
@ -148,7 +163,7 @@ def history_view_results( voice ):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
values = []
|
values = []
|
||||||
for k in headers:
|
for k in history_headers:
|
||||||
v = file
|
v = file
|
||||||
if k != "Name":
|
if k != "Name":
|
||||||
v = metadata[headers[k]]
|
v = metadata[headers[k]]
|
||||||
|
@ -163,6 +178,10 @@ def history_view_results( voice ):
|
||||||
gr.Dropdown.update(choices=sorted(files))
|
gr.Dropdown.update(choices=sorted(files))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def import_voices_proxy(files, name, progress=gr.Progress(track_tqdm=True)):
|
||||||
|
import_voices(files, name, progress)
|
||||||
|
return gr.update()
|
||||||
|
|
||||||
def read_generate_settings_proxy(file, saveAs='.temp'):
|
def read_generate_settings_proxy(file, saveAs='.temp'):
|
||||||
j, latents = read_generate_settings(file)
|
j, latents = read_generate_settings(file)
|
||||||
|
|
||||||
|
@ -175,13 +194,14 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
|
||||||
latents = f'{outdir}/cond_latents.pth'
|
latents = f'{outdir}/cond_latents.pth'
|
||||||
|
|
||||||
return (
|
return (
|
||||||
j,
|
gr.update(value=j, visible=j is not None),
|
||||||
|
gr.update(visible=j is not None),
|
||||||
gr.update(value=latents, visible=latents is not None),
|
gr.update(value=latents, visible=latents is not None),
|
||||||
None if j is None else j['voice']
|
None if j is None else j['voice']
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_dataset_proxy( voice, language ):
|
def prepare_dataset_proxy( voice, language, progress=gr.Progress(track_tqdm=True) ):
|
||||||
return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language )
|
return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, progress=progress )
|
||||||
|
|
||||||
def update_voices():
|
def update_voices():
|
||||||
return (
|
return (
|
||||||
|
@ -222,52 +242,18 @@ def setup_gradio():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
delimiter = gr.Textbox(lines=1, label="Line Delimiter", placeholder="\\n")
|
delimiter = gr.Textbox(lines=1, label="Line Delimiter", placeholder="\\n")
|
||||||
|
|
||||||
emotion = gr.Radio(
|
emotion = gr.Radio( ["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom"], value="Custom", label="Emotion", type="value", interactive=True )
|
||||||
["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom"],
|
|
||||||
value="Custom",
|
|
||||||
label="Emotion",
|
|
||||||
type="value",
|
|
||||||
interactive=True
|
|
||||||
)
|
|
||||||
prompt = gr.Textbox(lines=1, label="Custom Emotion + Prompt (if selected)")
|
prompt = gr.Textbox(lines=1, label="Custom Emotion + Prompt (if selected)")
|
||||||
voice = gr.Dropdown(
|
voice = gr.Dropdown(get_voice_list(), label="Voice", type="value")
|
||||||
get_voice_list(),
|
mic_audio = gr.Audio( label="Microphone Source", source="microphone", type="filepath" )
|
||||||
label="Voice",
|
|
||||||
type="value",
|
|
||||||
)
|
|
||||||
mic_audio = gr.Audio(
|
|
||||||
label="Microphone Source",
|
|
||||||
source="microphone",
|
|
||||||
type="filepath",
|
|
||||||
)
|
|
||||||
refresh_voices = gr.Button(value="Refresh Voice List")
|
refresh_voices = gr.Button(value="Refresh Voice List")
|
||||||
voice_latents_chunks = gr.Slider(label="Voice Chunks", minimum=1, maximum=64, value=1, step=1)
|
voice_latents_chunks = gr.Slider(label="Voice Chunks", minimum=1, maximum=64, value=1, step=1)
|
||||||
recompute_voice_latents = gr.Button(value="(Re)Compute Voice Latents")
|
recompute_voice_latents = gr.Button(value="(Re)Compute Voice Latents")
|
||||||
recompute_voice_latents.click(compute_latents,
|
|
||||||
inputs=[
|
|
||||||
voice,
|
|
||||||
voice_latents_chunks,
|
|
||||||
],
|
|
||||||
outputs=voice,
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt.change(fn=lambda value: gr.update(value="Custom"),
|
|
||||||
inputs=prompt,
|
|
||||||
outputs=emotion
|
|
||||||
)
|
|
||||||
mic_audio.change(fn=lambda value: gr.update(value="microphone"),
|
|
||||||
inputs=mic_audio,
|
|
||||||
outputs=voice
|
|
||||||
)
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
candidates = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates")
|
candidates = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates")
|
||||||
seed = gr.Number(value=0, precision=0, label="Seed")
|
seed = gr.Number(value=0, precision=0, label="Seed")
|
||||||
|
|
||||||
preset = gr.Radio(
|
preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value" )
|
||||||
["Ultra Fast", "Fast", "Standard", "High Quality"],
|
|
||||||
label="Preset",
|
|
||||||
type="value",
|
|
||||||
)
|
|
||||||
num_autoregressive_samples = gr.Slider(value=128, minimum=0, maximum=512, step=1, label="Samples")
|
num_autoregressive_samples = gr.Slider(value=128, minimum=0, maximum=512, step=1, label="Samples")
|
||||||
diffusion_iterations = gr.Slider(value=128, minimum=0, maximum=512, step=1, label="Iterations")
|
diffusion_iterations = gr.Slider(value=128, minimum=0, maximum=512, step=1, label="Iterations")
|
||||||
|
|
||||||
|
@ -275,19 +261,7 @@ def setup_gradio():
|
||||||
breathing_room = gr.Slider(value=8, minimum=1, maximum=32, step=1, label="Pause Size")
|
breathing_room = gr.Slider(value=8, minimum=1, maximum=32, step=1, label="Pause Size")
|
||||||
diffusion_sampler = gr.Radio(
|
diffusion_sampler = gr.Radio(
|
||||||
["P", "DDIM"], # + ["K_Euler_A", "DPM++2M"],
|
["P", "DDIM"], # + ["K_Euler_A", "DPM++2M"],
|
||||||
value="P",
|
value="P", label="Diffusion Samplers", type="value" )
|
||||||
label="Diffusion Samplers",
|
|
||||||
type="value",
|
|
||||||
)
|
|
||||||
|
|
||||||
preset.change(fn=update_presets,
|
|
||||||
inputs=preset,
|
|
||||||
outputs=[
|
|
||||||
num_autoregressive_samples,
|
|
||||||
diffusion_iterations,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
show_experimental_settings = gr.Checkbox(label="Show Experimental Settings")
|
show_experimental_settings = gr.Checkbox(label="Show Experimental Settings")
|
||||||
reset_generation_settings_button = gr.Button(value="Reset to Default")
|
reset_generation_settings_button = gr.Button(value="Reset to Default")
|
||||||
with gr.Column(visible=False) as col:
|
with gr.Column(visible=False) as col:
|
||||||
|
@ -300,12 +274,6 @@ def setup_gradio():
|
||||||
length_penalty = gr.Slider(value=1.0, minimum=0, maximum=8, label="Length Penalty")
|
length_penalty = gr.Slider(value=1.0, minimum=0, maximum=8, label="Length Penalty")
|
||||||
repetition_penalty = gr.Slider(value=2.0, minimum=0, maximum=8, label="Repetition Penalty")
|
repetition_penalty = gr.Slider(value=2.0, minimum=0, maximum=8, label="Repetition Penalty")
|
||||||
cond_free_k = gr.Slider(value=2.0, minimum=0, maximum=4, label="Conditioning-Free K")
|
cond_free_k = gr.Slider(value=2.0, minimum=0, maximum=4, label="Conditioning-Free K")
|
||||||
|
|
||||||
show_experimental_settings.change(
|
|
||||||
fn=lambda x: gr.update(visible=x),
|
|
||||||
inputs=show_experimental_settings,
|
|
||||||
outputs=experimental_column
|
|
||||||
)
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
submit = gr.Button(value="Generate")
|
submit = gr.Button(value="Generate")
|
||||||
stop = gr.Button(value="Stop")
|
stop = gr.Button(value="Stop")
|
||||||
|
@ -315,33 +283,13 @@ def setup_gradio():
|
||||||
output_audio = gr.Audio(label="Output")
|
output_audio = gr.Audio(label="Output")
|
||||||
candidates_list = gr.Dropdown(label="Candidates", type="value", visible=False)
|
candidates_list = gr.Dropdown(label="Candidates", type="value", visible=False)
|
||||||
output_pick = gr.Button(value="Select Candidate", visible=False)
|
output_pick = gr.Button(value="Select Candidate", visible=False)
|
||||||
|
|
||||||
with gr.Tab("History"):
|
with gr.Tab("History"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
headers = {
|
history_info = gr.Dataframe(label="Results", headers=list(history_headers.keys()))
|
||||||
"Name": "",
|
|
||||||
"Samples": "num_autoregressive_samples",
|
|
||||||
"Iterations": "diffusion_iterations",
|
|
||||||
"Temp.": "temperature",
|
|
||||||
"Sampler": "diffusion_sampler",
|
|
||||||
"CVVP": "cvvp_weight",
|
|
||||||
"Top P": "top_p",
|
|
||||||
"Diff. Temp.": "diffusion_temperature",
|
|
||||||
"Len Pen": "length_penalty",
|
|
||||||
"Rep Pen": "repetition_penalty",
|
|
||||||
"Cond-Free K": "cond_free_k",
|
|
||||||
"Time": "time",
|
|
||||||
}
|
|
||||||
history_info = gr.Dataframe(label="Results", headers=list(headers.keys()))
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
history_voices = gr.Dropdown(
|
history_voices = gr.Dropdown(choices=get_voice_list("./results/"), label="Voice", type="value")
|
||||||
get_voice_list("./results/"),
|
|
||||||
label="Voice",
|
|
||||||
type="value",
|
|
||||||
)
|
|
||||||
|
|
||||||
history_view_results_button = gr.Button(value="View Files")
|
history_view_results_button = gr.Button(value="View Files")
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
history_results_list = gr.Dropdown(label="Results",type="value", interactive=True)
|
history_results_list = gr.Dropdown(label="Results",type="value", interactive=True)
|
||||||
|
@ -349,51 +297,16 @@ def setup_gradio():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
history_audio = gr.Audio()
|
history_audio = gr.Audio()
|
||||||
history_copy_settings_button = gr.Button(value="Copy Settings")
|
history_copy_settings_button = gr.Button(value="Copy Settings")
|
||||||
|
|
||||||
history_view_results_button.click(
|
|
||||||
fn=history_view_results,
|
|
||||||
inputs=history_voices,
|
|
||||||
outputs=[
|
|
||||||
history_info,
|
|
||||||
history_results_list,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
history_view_result_button.click(
|
|
||||||
fn=lambda voice, file: f"./results/{voice}/{file}",
|
|
||||||
inputs=[
|
|
||||||
history_voices,
|
|
||||||
history_results_list,
|
|
||||||
],
|
|
||||||
outputs=history_audio
|
|
||||||
)
|
|
||||||
with gr.Tab("Utilities"):
|
with gr.Tab("Utilities"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
audio_in = gr.File(type="file", label="Audio Input", file_types=["audio"])
|
audio_in = gr.Files(type="file", label="Audio Input", file_types=["audio"])
|
||||||
copy_button = gr.Button(value="Copy Settings")
|
|
||||||
import_voice_name = gr.Textbox(label="Voice Name")
|
import_voice_name = gr.Textbox(label="Voice Name")
|
||||||
import_voice_button = gr.Button(value="Import Voice")
|
import_voice_button = gr.Button(value="Import Voice")
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
metadata_out = gr.JSON(label="Audio Metadata")
|
metadata_out = gr.JSON(label="Audio Metadata", visible=False)
|
||||||
latents_out = gr.File(type="binary", label="Voice Latents")
|
copy_button = gr.Button(value="Copy Settings", visible=False)
|
||||||
|
latents_out = gr.File(type="binary", label="Voice Latents", visible=False)
|
||||||
audio_in.upload(
|
|
||||||
fn=read_generate_settings_proxy,
|
|
||||||
inputs=audio_in,
|
|
||||||
outputs=[
|
|
||||||
metadata_out,
|
|
||||||
latents_out,
|
|
||||||
import_voice_name
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
import_voice_button.click(
|
|
||||||
fn=import_voice,
|
|
||||||
inputs=[
|
|
||||||
audio_in,
|
|
||||||
import_voice_name,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
with gr.Tab("Training"):
|
with gr.Tab("Training"):
|
||||||
with gr.Tab("Prepare Dataset"):
|
with gr.Tab("Prepare Dataset"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
@ -402,16 +315,9 @@ def setup_gradio():
|
||||||
gr.Dropdown( get_voice_list(), label="Dataset Source", type="value" ),
|
gr.Dropdown( get_voice_list(), label="Dataset Source", type="value" ),
|
||||||
gr.Textbox(label="Language", placeholder="English")
|
gr.Textbox(label="Language", placeholder="English")
|
||||||
]
|
]
|
||||||
dataset_voices = dataset_settings[0]
|
|
||||||
|
|
||||||
with gr.Column():
|
|
||||||
prepare_dataset_button = gr.Button(value="Prepare")
|
prepare_dataset_button = gr.Button(value="Prepare")
|
||||||
|
with gr.Column():
|
||||||
prepare_dataset_button.click(
|
prepare_dataset_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
||||||
prepare_dataset_proxy,
|
|
||||||
inputs=dataset_settings,
|
|
||||||
outputs=None
|
|
||||||
)
|
|
||||||
with gr.Tab("Generate Configuration"):
|
with gr.Tab("Generate Configuration"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
|
@ -421,8 +327,6 @@ def setup_gradio():
|
||||||
gr.Number(label="Print Frequency", value=50),
|
gr.Number(label="Print Frequency", value=50),
|
||||||
gr.Number(label="Save Frequency", value=50),
|
gr.Number(label="Save Frequency", value=50),
|
||||||
]
|
]
|
||||||
save_yaml_button = gr.Button(value="Save Training Configuration")
|
|
||||||
with gr.Column():
|
|
||||||
training_settings = training_settings + [
|
training_settings = training_settings + [
|
||||||
gr.Textbox(label="Training Name", placeholder="finetune"),
|
gr.Textbox(label="Training Name", placeholder="finetune"),
|
||||||
gr.Textbox(label="Dataset Name", placeholder="finetune"),
|
gr.Textbox(label="Dataset Name", placeholder="finetune"),
|
||||||
|
@ -430,24 +334,18 @@ def setup_gradio():
|
||||||
gr.Textbox(label="Validation Name", placeholder="finetune"),
|
gr.Textbox(label="Validation Name", placeholder="finetune"),
|
||||||
gr.Textbox(label="Validation Path", placeholder="./training/finetune/train.txt"),
|
gr.Textbox(label="Validation Path", placeholder="./training/finetune/train.txt"),
|
||||||
]
|
]
|
||||||
|
with gr.Column():
|
||||||
save_yaml_button.click(save_training_settings,
|
save_yaml_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
||||||
inputs=training_settings,
|
save_yaml_button = gr.Button(value="Save Training Configuration")
|
||||||
outputs=None
|
with gr.Tab("Run Training"):
|
||||||
)
|
|
||||||
with gr.Tab("Train"):
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_configs())
|
training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_configs())
|
||||||
refresh_configs = gr.Button(value="Refresh Configurations")
|
refresh_configs = gr.Button(value="Refresh Configurations")
|
||||||
train = gr.Button(value="Train")
|
start_training_button = gr.Button(value="Train")
|
||||||
|
stop_training_button = gr.Button(value="Stop")
|
||||||
refresh_configs.click(update_training_configs,inputs=None,outputs=training_configs)
|
with gr.Column():
|
||||||
train.click(run_training,
|
training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
||||||
inputs=training_configs,
|
|
||||||
outputs=None
|
|
||||||
)
|
|
||||||
|
|
||||||
with gr.Tab("Settings"):
|
with gr.Tab("Settings"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
exec_inputs = []
|
exec_inputs = []
|
||||||
|
@ -465,23 +363,22 @@ def setup_gradio():
|
||||||
gr.Checkbox(label="Force CPU for Conditioning Latents", value=args.force_cpu_for_conditioning_latents),
|
gr.Checkbox(label="Force CPU for Conditioning Latents", value=args.force_cpu_for_conditioning_latents),
|
||||||
gr.Checkbox(label="Defer TTS Load", value=args.defer_tts_load),
|
gr.Checkbox(label="Defer TTS Load", value=args.defer_tts_load),
|
||||||
gr.Textbox(label="Device Override", value=args.device_override),
|
gr.Textbox(label="Device Override", value=args.device_override),
|
||||||
gr.Dropdown(label="Whisper Model", value=args.whisper_model, choices=["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large"]),
|
|
||||||
]
|
]
|
||||||
gr.Button(value="Check for Updates").click(check_for_updates)
|
|
||||||
gr.Button(value="Reload TTS").click(reload_tts)
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
exec_inputs = exec_inputs + [
|
exec_inputs = exec_inputs + [
|
||||||
gr.Number(label="Sample Batch Size", precision=0, value=args.sample_batch_size),
|
gr.Number(label="Sample Batch Size", precision=0, value=args.sample_batch_size),
|
||||||
gr.Number(label="Concurrency Count", precision=0, value=args.concurrency_count),
|
gr.Number(label="Concurrency Count", precision=0, value=args.concurrency_count),
|
||||||
gr.Number(label="Ouptut Sample Rate", precision=0, value=args.output_sample_rate),
|
gr.Number(label="Ouptut Sample Rate", precision=0, value=args.output_sample_rate),
|
||||||
gr.Slider(label="Ouptut Volume", minimum=0, maximum=2, value=args.output_volume),
|
gr.Slider(label="Ouptut Volume", minimum=0, maximum=2, value=args.output_volume),
|
||||||
|
gr.Dropdown(label="Whisper Model", value=args.whisper_model, choices=["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large"]),
|
||||||
]
|
]
|
||||||
|
gr.Button(value="Check for Updates").click(check_for_updates)
|
||||||
|
gr.Button(value="Reload TTS").click(reload_tts)
|
||||||
|
|
||||||
for i in exec_inputs:
|
for i in exec_inputs:
|
||||||
i.change(
|
i.change( fn=export_exec_settings, inputs=exec_inputs )
|
||||||
fn=export_exec_settings,
|
|
||||||
inputs=exec_inputs
|
# console_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
||||||
)
|
|
||||||
|
|
||||||
input_settings = [
|
input_settings = [
|
||||||
text,
|
text,
|
||||||
|
@ -507,11 +404,76 @@ def setup_gradio():
|
||||||
experimental_checkboxes,
|
experimental_checkboxes,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
history_view_results_button.click(
|
||||||
|
fn=history_view_results,
|
||||||
|
inputs=history_voices,
|
||||||
|
outputs=[
|
||||||
|
history_info,
|
||||||
|
history_results_list,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
history_view_result_button.click(
|
||||||
|
fn=lambda voice, file: f"./results/{voice}/{file}",
|
||||||
|
inputs=[
|
||||||
|
history_voices,
|
||||||
|
history_results_list,
|
||||||
|
],
|
||||||
|
outputs=history_audio
|
||||||
|
)
|
||||||
|
audio_in.upload(
|
||||||
|
fn=read_generate_settings_proxy,
|
||||||
|
inputs=audio_in,
|
||||||
|
outputs=[
|
||||||
|
metadata_out,
|
||||||
|
copy_button,
|
||||||
|
latents_out,
|
||||||
|
import_voice_name
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
import_voice_button.click(
|
||||||
|
fn=import_voices_proxy,
|
||||||
|
inputs=[
|
||||||
|
audio_in,
|
||||||
|
import_voice_name,
|
||||||
|
],
|
||||||
|
outputs=import_voice_name #console_output
|
||||||
|
)
|
||||||
|
show_experimental_settings.change(
|
||||||
|
fn=lambda x: gr.update(visible=x),
|
||||||
|
inputs=show_experimental_settings,
|
||||||
|
outputs=experimental_column
|
||||||
|
)
|
||||||
|
preset.change(fn=update_presets,
|
||||||
|
inputs=preset,
|
||||||
|
outputs=[
|
||||||
|
num_autoregressive_samples,
|
||||||
|
diffusion_iterations,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
recompute_voice_latents.click(compute_latents,
|
||||||
|
inputs=[
|
||||||
|
voice,
|
||||||
|
voice_latents_chunks,
|
||||||
|
],
|
||||||
|
outputs=voice,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt.change(fn=lambda value: gr.update(value="Custom"),
|
||||||
|
inputs=prompt,
|
||||||
|
outputs=emotion
|
||||||
|
)
|
||||||
|
mic_audio.change(fn=lambda value: gr.update(value="microphone"),
|
||||||
|
inputs=mic_audio,
|
||||||
|
outputs=voice
|
||||||
|
)
|
||||||
|
|
||||||
refresh_voices.click(update_voices,
|
refresh_voices.click(update_voices,
|
||||||
inputs=None,
|
inputs=None,
|
||||||
outputs=[
|
outputs=[
|
||||||
voice,
|
voice,
|
||||||
dataset_voices,
|
dataset_settings[0],
|
||||||
history_voices
|
history_voices
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
@ -552,6 +514,25 @@ def setup_gradio():
|
||||||
outputs=input_settings
|
outputs=input_settings
|
||||||
)
|
)
|
||||||
|
|
||||||
|
refresh_configs.click(update_training_configs,inputs=None,outputs=training_configs)
|
||||||
|
start_training_button.click(run_training,
|
||||||
|
inputs=training_configs,
|
||||||
|
outputs=training_output #console_output
|
||||||
|
)
|
||||||
|
stop_training_button.click(stop_training,
|
||||||
|
inputs=None,
|
||||||
|
outputs=training_output #console_output
|
||||||
|
)
|
||||||
|
prepare_dataset_button.click(
|
||||||
|
prepare_dataset_proxy,
|
||||||
|
inputs=dataset_settings,
|
||||||
|
outputs=prepare_dataset_output #console_output
|
||||||
|
)
|
||||||
|
save_yaml_button.click(save_training_settings,
|
||||||
|
inputs=training_settings,
|
||||||
|
outputs=save_yaml_output #console_output
|
||||||
|
)
|
||||||
|
|
||||||
if os.path.isfile('./config/generate.json'):
|
if os.path.isfile('./config/generate.json'):
|
||||||
ui.load(import_generate_settings, inputs=None, outputs=input_settings)
|
ui.load(import_generate_settings, inputs=None, outputs=input_settings)
|
||||||
|
|
||||||
|
|
4
train.bat
Executable file
4
train.bat
Executable file
|
@ -0,0 +1,4 @@
|
||||||
|
call .\venv\Scripts\activate.bat
|
||||||
|
python ./src/train.py -opt "%1"
|
||||||
|
deactivate
|
||||||
|
pause
|
3
train.sh
Executable file
3
train.sh
Executable file
|
@ -0,0 +1,3 @@
|
||||||
|
source ./venv/bin/activate
|
||||||
|
python3 ./src/train.py -opt "$1"
|
||||||
|
deactivate
|
0
training/.gitkeep
Executable file
0
training/.gitkeep
Executable file
Loading…
Reference in New Issue
Block a user