forked from mrq/ai-voice-cloning
sloppy fix to actually kill child processes when using multi-GPU distributed training; set the GPU training count automatically from whatever CUDA exposes so I don't have to keep setting it to 2
parent 1a9d159b2a
commit 5026d93ecd
src/utils.py (22 changed lines)
@@ -17,6 +17,7 @@ import urllib.request
 import signal
 import gc
 import subprocess
+import psutil
 import yaml

 import tqdm
@@ -556,7 +557,7 @@ class TrainingState():
         self.spawn_process(config_path=config_path, gpus=gpus)

     def spawn_process(self, config_path, gpus=1):
-        self.cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', str(int(gpus)), config_path]
+        self.cmd = ['train.bat', config_path] if os.name == "nt" else ['./train.sh', str(int(gpus)), config_path]

         print("Spawning process: ", " ".join(self.cmd))
         self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
@@ -815,6 +816,9 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, pro
     training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets, gpus=gpus)

     for line in iter(training_state.process.stdout.readline, ""):
+        if training_state.killed:
+            return
+
         result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_datasets=keep_x_past_datasets, progress=progress )
         print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
         if result:
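For context, the early exit added above can be sketched on its own: a minimal, hypothetical example (not the project's TrainingState class) that streams a child's merged stdout line by line and bails out as soon as a kill flag is set. The command below is just a stand-in for ['./train.sh', str(int(gpus)), config_path].

import subprocess
import sys

killed = False  # stands in for training_state.killed

# stand-in command; the real code launches ['./train.sh', str(int(gpus)), config_path]
cmd = [sys.executable, "-u", "-c", "print('iteration 1'); print('iteration 2')"]

proc = subprocess.Popen(
    cmd,
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,    # merge stderr into stdout, as the diff does
    universal_newlines=True,     # text mode: readline() yields str and "" at EOF
)

# readline() returns "" at EOF, which ends the iter() loop
for line in iter(proc.stdout.readline, ""):
    if killed:   # set by the stop path; avoids parsing a stream that is being torn down
        break
    print(f"[Training] {line}", end="")

proc.wait()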
@@ -868,10 +872,22 @@ def stop_training():
         return "No training in progress"
     print("Killing training process...")
     training_state.killed = True
+
+    children = []
+    # wrapped in a try/catch in case for some reason this fails outside of Linux
+    try:
+        children = [p.info for p in psutil.process_iter(attrs=['pid', 'name', 'cmdline']) if './src/train.py' in p.info['cmdline']]
+    except Exception as e:
+        pass
+
     training_state.process.stdout.close()
-    #training_state.process.terminate()
-    training_state.process.send_signal(signal.SIGINT)
+    training_state.process.terminate()
+    training_state.process.kill()
     return_code = training_state.process.wait()
+
+    for p in children:
+        os.kill( p['pid'], signal.SIGKILL )
+
     training_state = None
     print("Killed training process.")
     return f"Training cancelled: {return_code}"
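The child-reaping half of the fix also reads well as a standalone helper. The sketch below is hedged: it assumes a POSIX host (signal.SIGKILL does not exist on Windows) and that the stray workers were started as ./src/train.py, which is the string the hunk looks for in each process's command line; kill_stray_trainers is a name invented for illustration, not a function in the repository.

import os
import signal

import psutil


def kill_stray_trainers(needle="./src/train.py"):
    """Force-kill any process whose command line contains `needle`.

    Mirrors the psutil scan in stop_training(); POSIX-only because of SIGKILL.
    """
    victims = []
    try:
        for p in psutil.process_iter(attrs=["pid", "name", "cmdline"]):
            cmdline = p.info.get("cmdline") or []   # cmdline can be None for some processes
            if needle in cmdline:
                victims.append(p.info["pid"])
    except Exception:
        pass  # as in the diff: shrug off platforms where the scan fails

    for pid in victims:
        try:
            os.kill(pid, signal.SIGKILL)
        except ProcessLookupError:
            pass  # already gone
    return victims

A different way to get the same effect is to launch the trainer with subprocess.Popen(..., start_new_session=True) and later kill the whole process group with os.killpg; that avoids scanning the process table, but it is not what this commit does.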
@@ -16,6 +16,7 @@ from datetime import datetime

 import tortoise.api
 from tortoise.utils.audio import get_voice_dir, get_voices
+from tortoise.utils.device import get_device_count

 from utils import *

@@ -536,7 +537,7 @@ def setup_gradio():

     with gr.Row():
         training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
-        training_gpu_count = gr.Number(label="GPUs", value=1)
+        training_gpu_count = gr.Number(label="GPUs", value=get_device_count())
     with gr.Row():
         start_training_button = gr.Button(value="Train")
         stop_training_button = gr.Button(value="Stop")
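The new default comes from get_device_count() in tortoise.utils.device, which is not shown in this diff; presumably it reports however many devices CUDA exposes. A rough, hypothetical equivalent in plain PyTorch (the name and fallback behaviour are guesses, not the actual helper):

import torch


def get_device_count_guess() -> int:
    """Hypothetical stand-in for tortoise.utils.device.get_device_count().

    Returns the number of CUDA devices PyTorch can see, falling back to 1
    so the Gradio field never defaults to 0 GPUs.
    """
    if torch.cuda.is_available():
        return torch.cuda.device_count()
    return 1

With that default, a 2-GPU machine pre-fills the "GPUs" field with 2, which is exactly the manual step the commit message is trying to get rid of.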