sloppy fix to actually kill children when using multi-GPU distributed training, set GPU training count based on what CUDA exposes automatically so I don't have to keep setting it to 2

This commit is contained in:
mrq 2023-03-04 20:42:54 +00:00
parent 1a9d159b2a
commit 5026d93ecd
2 changed files with 21 additions and 4 deletions

View File

@ -17,6 +17,7 @@ import urllib.request
import signal import signal
import gc import gc
import subprocess import subprocess
import psutil
import yaml import yaml
import tqdm import tqdm
@ -556,7 +557,7 @@ class TrainingState():
self.spawn_process(config_path=config_path, gpus=gpus) self.spawn_process(config_path=config_path, gpus=gpus)
def spawn_process(self, config_path, gpus=1): def spawn_process(self, config_path, gpus=1):
self.cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', str(int(gpus)), config_path] self.cmd = ['train.bat', config_path] if os.name == "nt" else ['./train.sh', str(int(gpus)), config_path]
print("Spawning process: ", " ".join(self.cmd)) print("Spawning process: ", " ".join(self.cmd))
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True) self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
@ -815,6 +816,9 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, pro
training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets, gpus=gpus) training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets, gpus=gpus)
for line in iter(training_state.process.stdout.readline, ""): for line in iter(training_state.process.stdout.readline, ""):
if training_state.killed:
return
result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_datasets=keep_x_past_datasets, progress=progress ) result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_datasets=keep_x_past_datasets, progress=progress )
print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}") print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
if result: if result:
@ -868,10 +872,22 @@ def stop_training():
return "No training in progress" return "No training in progress"
print("Killing training process...") print("Killing training process...")
training_state.killed = True training_state.killed = True
children = []
# wrapped in a try/catch in case for some reason this fails outside of Linux
try:
children = [p.info for p in psutil.process_iter(attrs=['pid', 'name', 'cmdline']) if './src/train.py' in p.info['cmdline']]
except Exception as e:
pass
training_state.process.stdout.close() training_state.process.stdout.close()
#training_state.process.terminate() training_state.process.terminate()
training_state.process.send_signal(signal.SIGINT) training_state.process.kill()
return_code = training_state.process.wait() return_code = training_state.process.wait()
for p in children:
os.kill( p['pid'], signal.SIGKILL )
training_state = None training_state = None
print("Killed training process.") print("Killed training process.")
return f"Training cancelled: {return_code}" return f"Training cancelled: {return_code}"

View File

@ -16,6 +16,7 @@ from datetime import datetime
import tortoise.api import tortoise.api
from tortoise.utils.audio import get_voice_dir, get_voices from tortoise.utils.audio import get_voice_dir, get_voices
from tortoise.utils.device import get_device_count
from utils import * from utils import *
@ -536,7 +537,7 @@ def setup_gradio():
with gr.Row(): with gr.Row():
training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1) training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
training_gpu_count = gr.Number(label="GPUs", value=1) training_gpu_count = gr.Number(label="GPUs", value=get_device_count())
with gr.Row(): with gr.Row():
start_training_button = gr.Button(value="Train") start_training_button = gr.Button(value="Train")
stop_training_button = gr.Button(value="Stop") stop_training_button = gr.Button(value="Stop")