updated notebooks to use the new "main" setup
This commit is contained in:
parent
f8249aa826
commit
9c0e4666d2
|
@ -71,16 +71,29 @@
|
||||||
{
|
{
|
||||||
"cell_type":"code",
|
"cell_type":"code",
|
||||||
"source":[
|
"source":[
|
||||||
"%cd ai-voice-cloning\n",
|
"%cd /content/ai-voice-cloning\n",
|
||||||
"import src.webui as mrq\n",
|
|
||||||
"import sys\n",
|
|
||||||
"sys.argv = [\"\"]\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"mrq.args = mrq.setup_args()\n",
|
"import os\n",
|
||||||
"mrq.webui = mrq.setup_gradio()\n",
|
"import sys\n",
|
||||||
"mrq.webui.launch(share=True, prevent_thread_lock=True, height=1000)\n",
|
"\n",
|
||||||
"mrq.tts = mrq.setup_tortoise()\n",
|
"sys.argv = [\"\"]\n",
|
||||||
"mrq.webui.block_thread()"
|
"sys.path.append('./src/')\n",
|
||||||
|
"\n",
|
||||||
|
"if 'TORTOISE_MODELS_DIR' not in os.environ:\n",
|
||||||
|
"\tos.environ['TORTOISE_MODELS_DIR'] = os.path.realpath(os.path.join(os.getcwd(), './models/tortoise/'))\n",
|
||||||
|
"\n",
|
||||||
|
"if 'TRANSFORMERS_CACHE' not in os.environ:\n",
|
||||||
|
"\tos.environ['TRANSFORMERS_CACHE'] = os.path.realpath(os.path.join(os.getcwd(), './models/transformers/'))\n",
|
||||||
|
"\n",
|
||||||
|
"from utils import *\n",
|
||||||
|
"from webui import *\n",
|
||||||
|
"\n",
|
||||||
|
"args = setup_args()\n",
|
||||||
|
"\n",
|
||||||
|
"webui = setup_gradio()\n",
|
||||||
|
"tts = setup_tortoise()\n",
|
||||||
|
"webui.launch(share=True, prevent_thread_lock=True, height=1000)\n",
|
||||||
|
"webui.block_thread()"
|
||||||
],
|
],
|
||||||
"metadata":{
|
"metadata":{
|
||||||
"id":"c_EQZLTA19c7"
|
"id":"c_EQZLTA19c7"
|
||||||
|
@ -102,6 +115,7 @@
|
||||||
{
|
{
|
||||||
"cell_type":"code",
|
"cell_type":"code",
|
||||||
"source":[
|
"source":[
|
||||||
|
"%cd /content/ai-voice-cloning\n",
|
||||||
"!apt install -y p7zip-full\n",
|
"!apt install -y p7zip-full\n",
|
||||||
"from datetime import datetime\n",
|
"from datetime import datetime\n",
|
||||||
"timestamp = datetime.now().strftime('%m-%d-%Y_%H:%M:%S')\n",
|
"timestamp = datetime.now().strftime('%m-%d-%Y_%H:%M:%S')\n",
|
||||||
|
|
35
src/train.py
35
src/train.py
|
@ -1,35 +0,0 @@
|
||||||
import torch
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
from ..dlas.codes import *
|
|
||||||
from ..dlas.codes.utils import util, options as option
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml')
|
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
|
||||||
args = parser.parse_args()
|
|
||||||
opt = option.parse(args.opt, is_train=True)
|
|
||||||
if args.launcher != 'none':
|
|
||||||
# export CUDA_VISIBLE_DEVICES for running in distributed mode.
|
|
||||||
if 'gpu_ids' in opt.keys():
|
|
||||||
gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
|
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
|
|
||||||
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
|
|
||||||
trainer = Trainer()
|
|
||||||
|
|
||||||
#### distributed training settings
|
|
||||||
if args.launcher == 'none': # disabled distributed training
|
|
||||||
opt['dist'] = False
|
|
||||||
trainer.rank = -1
|
|
||||||
if len(opt['gpu_ids']) == 1:
|
|
||||||
torch.cuda.set_device(opt['gpu_ids'][0])
|
|
||||||
print('Disabled distributed training.')
|
|
||||||
else:
|
|
||||||
opt['dist'] = True
|
|
||||||
init_dist('nccl')
|
|
||||||
trainer.world_size = torch.distributed.get_world_size()
|
|
||||||
trainer.rank = torch.distributed.get_rank()
|
|
||||||
torch.cuda.set_device(torch.distributed.get_rank())
|
|
||||||
|
|
||||||
trainer.init(args.opt, opt, args.launcher)
|
|
||||||
trainer.do_training()
|
|
|
@ -450,7 +450,6 @@ def save_training_settings( batch_size=None, learning_rate=None, print_rate=None
|
||||||
yaml = f.read()
|
yaml = f.read()
|
||||||
|
|
||||||
for k in settings:
|
for k in settings:
|
||||||
print(f"${{{k}}} => {settings[k]}")
|
|
||||||
yaml = yaml.replace(f"${{{k}}}", str(settings[k]))
|
yaml = yaml.replace(f"${{{k}}}", str(settings[k]))
|
||||||
|
|
||||||
with open(f'./training/{settings["name"]}.yaml', 'w', encoding="utf-8") as f:
|
with open(f'./training/{settings["name"]}.yaml', 'w', encoding="utf-8") as f:
|
||||||
|
|
69
train.ipynb
Executable file
69
train.ipynb
Executable file
|
@ -0,0 +1,69 @@
|
||||||
|
{
|
||||||
|
"nbformat":4,
|
||||||
|
"nbformat_minor":0,
|
||||||
|
"metadata":{
|
||||||
|
"colab":{
|
||||||
|
"private_outputs":true,
|
||||||
|
"provenance":[
|
||||||
|
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"kernelspec":{
|
||||||
|
"name":"python3",
|
||||||
|
"display_name":"Python 3"
|
||||||
|
},
|
||||||
|
"language_info":{
|
||||||
|
"name":"python"
|
||||||
|
},
|
||||||
|
"accelerator":"GPU",
|
||||||
|
"gpuClass":"standard"
|
||||||
|
},
|
||||||
|
"cells":[
|
||||||
|
{
|
||||||
|
"cell_type":"code",
|
||||||
|
"execution_count":null,
|
||||||
|
"metadata":{
|
||||||
|
"id":"AaKpV3rCI3Eo"
|
||||||
|
},
|
||||||
|
"outputs":[
|
||||||
|
|
||||||
|
],
|
||||||
|
"source":[
|
||||||
|
"!git clone https://git.ecker.tech/mrq/DL-Art-School\n",
|
||||||
|
"%cd DL-Art-School\n",
|
||||||
|
"!pip install -r requirements.txt"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type":"code",
|
||||||
|
"source":[
|
||||||
|
"from google.colab import drive\n",
|
||||||
|
"drive.mount('/content/drive')",
|
||||||
|
"%cd /content/DL-Art-School/\n",
|
||||||
|
"#!rm -r experiments\n",
|
||||||
|
"!ln -s /content/drive/MyDrive/experiments/\n",
|
||||||
|
],
|
||||||
|
"metadata":{
|
||||||
|
"id":"8eV92cjGI4XL"
|
||||||
|
},
|
||||||
|
"execution_count":null,
|
||||||
|
"outputs":[
|
||||||
|
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type":"code",
|
||||||
|
"source":[
|
||||||
|
"%cd /content/DL-Art-School/\n",
|
||||||
|
"!python ./codes/train.py -opt ./experiments/ar.yml"
|
||||||
|
],
|
||||||
|
"metadata":{
|
||||||
|
"id":"7lcRGqglX2FC"
|
||||||
|
},
|
||||||
|
"execution_count":null,
|
||||||
|
"outputs":[
|
||||||
|
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user