1
0

updated notebooks to use the new "main" setup

This commit is contained in:
mrq 2023-02-17 03:30:53 +00:00
parent f8249aa826
commit 9c0e4666d2
4 changed files with 92 additions and 45 deletions

View File

@ -71,16 +71,29 @@
{
"cell_type":"code",
"source":[
"%cd ai-voice-cloning\n",
"import src.webui as mrq\n",
"import sys\n",
"sys.argv = [\"\"]\n",
"%cd /content/ai-voice-cloning\n",
"\n",
"mrq.args = mrq.setup_args()\n",
"mrq.webui = mrq.setup_gradio()\n",
"mrq.webui.launch(share=True, prevent_thread_lock=True, height=1000)\n",
"mrq.tts = mrq.setup_tortoise()\n",
"mrq.webui.block_thread()"
"import os\n",
"import sys\n",
"\n",
"sys.argv = [\"\"]\n",
"sys.path.append('./src/')\n",
"\n",
"if 'TORTOISE_MODELS_DIR' not in os.environ:\n",
"\tos.environ['TORTOISE_MODELS_DIR'] = os.path.realpath(os.path.join(os.getcwd(), './models/tortoise/'))\n",
"\n",
"if 'TRANSFORMERS_CACHE' not in os.environ:\n",
"\tos.environ['TRANSFORMERS_CACHE'] = os.path.realpath(os.path.join(os.getcwd(), './models/transformers/'))\n",
"\n",
"from utils import *\n",
"from webui import *\n",
"\n",
"args = setup_args()\n",
"\n",
"webui = setup_gradio()\n",
"tts = setup_tortoise()\n",
"webui.launch(share=True, prevent_thread_lock=True, height=1000)\n",
"webui.block_thread()"
],
"metadata":{
"id":"c_EQZLTA19c7"
@ -102,6 +115,7 @@
{
"cell_type":"code",
"source":[
"%cd /content/ai-voice-cloning\n",
"!apt install -y p7zip-full\n",
"from datetime import datetime\n",
"timestamp = datetime.now().strftime('%m-%d-%Y_%H:%M:%S')\n",

View File

@ -1,35 +0,0 @@
import torch
import argparse
from ..dlas.codes import *
from ..dlas.codes.utils import util, options as option
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
args = parser.parse_args()
opt = option.parse(args.opt, is_train=True)
if args.launcher != 'none':
# export CUDA_VISIBLE_DEVICES for running in distributed mode.
if 'gpu_ids' in opt.keys():
gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
trainer = Trainer()
#### distributed training settings
if args.launcher == 'none': # disabled distributed training
opt['dist'] = False
trainer.rank = -1
if len(opt['gpu_ids']) == 1:
torch.cuda.set_device(opt['gpu_ids'][0])
print('Disabled distributed training.')
else:
opt['dist'] = True
init_dist('nccl')
trainer.world_size = torch.distributed.get_world_size()
trainer.rank = torch.distributed.get_rank()
torch.cuda.set_device(torch.distributed.get_rank())
trainer.init(args.opt, opt, args.launcher)
trainer.do_training()

View File

@ -450,7 +450,6 @@ def save_training_settings( batch_size=None, learning_rate=None, print_rate=None
yaml = f.read()
for k in settings:
print(f"${{{k}}} => {settings[k]}")
yaml = yaml.replace(f"${{{k}}}", str(settings[k]))
with open(f'./training/{settings["name"]}.yaml', 'w', encoding="utf-8") as f:

69
train.ipynb Executable file
View File

@ -0,0 +1,69 @@
{
"nbformat":4,
"nbformat_minor":0,
"metadata":{
"colab":{
"private_outputs":true,
"provenance":[
]
},
"kernelspec":{
"name":"python3",
"display_name":"Python 3"
},
"language_info":{
"name":"python"
},
"accelerator":"GPU",
"gpuClass":"standard"
},
"cells":[
{
"cell_type":"code",
"execution_count":null,
"metadata":{
"id":"AaKpV3rCI3Eo"
},
"outputs":[
],
"source":[
"!git clone https://git.ecker.tech/mrq/DL-Art-School\n",
"%cd DL-Art-School\n",
"!pip install -r requirements.txt"
]
},
{
"cell_type":"code",
"source":[
"from google.colab import drive\n",
"drive.mount('/content/drive')",
"%cd /content/DL-Art-School/\n",
"#!rm -r experiments\n",
"!ln -s /content/drive/MyDrive/experiments/\n",
],
"metadata":{
"id":"8eV92cjGI4XL"
},
"execution_count":null,
"outputs":[
]
},
{
"cell_type":"code",
"source":[
"%cd /content/DL-Art-School/\n",
"!python ./codes/train.py -opt ./experiments/ar.yml"
],
"metadata":{
"id":"7lcRGqglX2FC"
},
"execution_count":null,
"outputs":[
]
}
]
}