forked from mrq/ai-voice-cloning

experimental multi-gpu training (Linux only, because I can't into batch files)

parent: e205322c8d
commit: e859a7c01d
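
The gist of the change: the web UI gains a GPU-count field, TrainingState passes it through to train.sh, and train.sh decides between a plain run and torch.distributed.launch. A minimal sketch of the command the Linux path now assembles (the YAML path below is only an example, not taken from this commit):

```python
# Sketch of the command utils.py now builds on Linux (illustrative values only).
gpus = 2                                         # from the new "GPUs" field in the web UI
config_path = "./training/finetune/train.yaml"   # example path, not from the commit
cmd = ['bash', './train.sh', str(int(gpus)), config_path]
# -> bash ./train.sh 2 ./training/finetune/train.yaml
```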

.gitignore (vendored) | 7 lines changed

@@ -1,7 +1,8 @@
 # ignores user files
-/tortoise-venv/
-/tortoise/voices/
-/models/
+/venv/
+/voices/*
+/models/*
+/training/*
 /config/*
 
 # Byte-compiled / optimized / DLL files

src/train.py

@@ -18,9 +18,12 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml', nargs='+') # ugh
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
+    parser.add_argument('--local_rank', type=int, help='Rank Number')
     args = parser.parse_args()
     args.opt = " ".join(args.opt) # absolutely disgusting
 
+    os.environ['LOCAL_RANK'] = str(args.local_rank)
+
     with open(args.opt, 'r') as file:
         opt_config = yaml.safe_load(file)
 
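torch.distributed.launch starts one copy of train.py per GPU and hands each a distinct --local_rank, which the lines above mirror into the LOCAL_RANK environment variable. A hedged sketch of how a worker can pick its device from that variable (typical usage, not code from this commit):

```python
# Sketch: a worker reads its local rank and binds to the matching GPU.
import os
import torch

local_rank = int(os.environ.get("LOCAL_RANK", "0"))
if torch.cuda.is_available():
    torch.cuda.set_device(local_rank)
```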

@@ -71,7 +74,7 @@ def train(yaml, launcher='none'):
         print('Disabled distributed training.')
     else:
         opt['dist'] = True
-        init_dist('nccl')
+        tr.init_dist('nccl')
         trainer.world_size = torch.distributed.get_world_size()
         trainer.rank = torch.distributed.get_rank()
         torch.cuda.set_device(torch.distributed.get_rank())
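
The only change here is that init_dist is now called through the imported trainer module (tr) instead of as a bare name. As a rough sketch, such a helper usually boils down to binding the local GPU and creating the NCCL process group; the actual implementation lives in the trainer framework this repo wraps and is not part of this diff:

```python
# Rough sketch of what an init_dist('nccl') helper typically does (assumption, not the repo's code).
import os
import torch
import torch.distributed as dist

def init_dist(backend="nccl"):
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend=backend)  # reads MASTER_ADDR/PORT, RANK, WORLD_SIZE from the launcher
```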

src/utils.py | 29 lines changed

@@ -477,9 +477,7 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
 
 # superfluous, but it cleans up some things
 class TrainingState():
-    def __init__(self, config_path, keep_x_past_datasets=0, start=True):
-        self.cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', config_path]
-
+    def __init__(self, config_path, keep_x_past_datasets=0, start=True, gpus=1):
         # parse config to get its iteration
         with open(config_path, 'r') as file:
             self.config = yaml.safe_load(file)

@@ -530,9 +528,11 @@ class TrainingState():
         if keep_x_past_datasets > 0:
             self.cleanup_old(keep=keep_x_past_datasets)
         if start:
-            self.spawn_process()
+            self.spawn_process(config_path=config_path, gpus=gpus)
 
-    def spawn_process(self):
+    def spawn_process(self, config_path, gpus=1):
+        self.cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', str(int(gpus)), config_path]
+
         print("Spawning process: ", " ".join(self.cmd))
         self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
 
@@ -745,7 +745,7 @@ class TrainingState():
         if should_return:
             return "".join(self.buffer)
 
-def run_training(config_path, verbose=False, buffer_size=8, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
+def run_training(config_path, verbose=False, gpus=1, buffer_size=8, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
     global training_state
     if training_state and training_state.process:
         return "Training already in progress"

@@ -757,7 +757,7 @@ def run_training(config_path, verbose=False, buffer_size=8, keep_x_past_datasets
     unload_whisper()
     unload_voicefixer()
 
-    training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets)
+    training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets, gpus=gpus)
 
     for line in iter(training_state.process.stdout.readline, ""):
 
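The gpus value originates from a Gradio Number field (added in the webui hunk below), which typically arrives as a float; the str(int(gpus)) cast in spawn_process normalises it before it reaches train.sh. A usage sketch with illustrative values (the config path is not from this commit):

```python
# Usage sketch; path and counts are illustrative, not from the commit.
state = TrainingState(
    config_path="./training/finetune/train.yaml",
    keep_x_past_datasets=2,
    gpus=2,        # a Gradio Number usually arrives as 2.0; str(int(...)) gives "2"
    start=True,    # start=True immediately spawns the training subprocess
)
# On Linux, state.cmd becomes:
# ['bash', './train.sh', '2', './training/finetune/train.yaml']
```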

@@ -785,11 +785,13 @@ def update_training_dataplot(config_path=None):
     update = None
 
     if not training_state:
-        training_state = TrainingState(config_path=config_path, start=False)
-        update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
-        del training_state
-        training_state = None
-    else:
+        if config_path:
+            training_state = TrainingState(config_path=config_path, start=False)
+            if training_state.losses:
+                update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
+            del training_state
+            training_state = None
+    elif training_state.losses:
         update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
 
     return update

@@ -1285,12 +1287,13 @@ def setup_args():
     if not args.device_override:
         set_device_name(args.device_override)
 
+
     args.listen_host = None
     args.listen_port = None
     args.listen_path = None
     if args.listen:
         try:
-            match = re.findall(r"^(?:(.+?):(\d+))?(\/.+?)?$", args.listen)[0]
+            match = re.findall(r"^(?:(.+?):(\d+))?(\/.*?)?$", args.listen)[0]
 
             args.listen_host = match[0] if match[0] != "" else "127.0.0.1"
             args.listen_port = match[1] if match[1] != "" else None
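
The functional change in this hunk is ".+?" becoming ".*?" in the path group of the --listen parser, so a value ending in a bare slash still matches instead of producing an empty findall result (which would make the [0] indexing blow up inside the try block). A quick illustration:

```python
import re

OLD = r"^(?:(.+?):(\d+))?(\/.+?)?$"
NEW = r"^(?:(.+?):(\d+))?(\/.*?)?$"

print(re.findall(OLD, "0.0.0.0:8080/"))  # []  (old pattern wants a character after the slash)
print(re.findall(NEW, "0.0.0.0:8080/"))  # [('0.0.0.0', '8080', '/')]
print(re.findall(NEW, "0.0.0.0:8080"))   # [('0.0.0.0', '8080', '')]
```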

src/webui.py

@@ -546,6 +546,7 @@ def setup_gradio():
             verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
             training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8)
             training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
+            training_gpu_count = gr.Number(label="GPUs", value=1)
             with gr.Row():
                 start_training_button = gr.Button(value="Train")
                 stop_training_button = gr.Button(value="Stop")

@@ -751,6 +752,7 @@ def setup_gradio():
             inputs=[
                 training_configs,
                 verbose_training,
+                training_gpu_count,
                 training_buffer_size,
                 training_keep_x_past_datasets,
             ],

train.sh | 11 lines changed

@@ -1,4 +1,13 @@
 #!/bin/bash
 source ./venv/bin/activate
-python3 ./src/train.py -opt "$1"
+
+GPUS=$1
+CONFIG=$2
+PORT=1234
+
+if (( $GPUS > 1 )); then
+    python3 -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT ./src/train.py -opt "$CONFIG" --launcher=pytorch
+else
+    python3 ./src/train.py -opt "$CONFIG"
+fi
 deactivate
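
For example, with GPUS=2 the script runs: python3 -m torch.distributed.launch --nproc_per_node=2 --master_port=1234 ./src/train.py -opt "$CONFIG" --launcher=pytorch. torch.distributed.launch then spawns one train.py process per GPU and passes each a distinct --local_rank, which the argument-parsing change above stores in LOCAL_RANK; with a single GPU the old direct invocation is kept.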