diff --git a/src/utils.py b/src/utils.py index 5e49a8e..1f6a707 100755 --- a/src/utils.py +++ b/src/utils.py @@ -17,6 +17,7 @@ import urllib.request import signal import gc import subprocess +import psutil import yaml import tqdm @@ -556,7 +557,7 @@ class TrainingState(): self.spawn_process(config_path=config_path, gpus=gpus) def spawn_process(self, config_path, gpus=1): - self.cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', str(int(gpus)), config_path] + self.cmd = ['train.bat', config_path] if os.name == "nt" else ['./train.sh', str(int(gpus)), config_path] print("Spawning process: ", " ".join(self.cmd)) self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True) @@ -815,6 +816,9 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, pro training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets, gpus=gpus) for line in iter(training_state.process.stdout.readline, ""): + if training_state.killed: + return + result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_datasets=keep_x_past_datasets, progress=progress ) print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}") if result: @@ -868,10 +872,22 @@ def stop_training(): return "No training in progress" print("Killing training process...") training_state.killed = True + + children = [] + # wrapped in a try/catch in case for some reason this fails outside of Linux + try: + children = [p.info for p in psutil.process_iter(attrs=['pid', 'name', 'cmdline']) if './src/train.py' in p.info['cmdline']] + except Exception as e: + pass + training_state.process.stdout.close() - #training_state.process.terminate() - training_state.process.send_signal(signal.SIGINT) + training_state.process.terminate() + training_state.process.kill() return_code = training_state.process.wait() + + for p in children: + os.kill( p['pid'], signal.SIGKILL ) + training_state = None print("Killed training process.") return f"Training cancelled: {return_code}" diff --git a/src/webui.py b/src/webui.py index 6bf8224..89d9462 100755 --- a/src/webui.py +++ b/src/webui.py @@ -16,6 +16,7 @@ from datetime import datetime import tortoise.api from tortoise.utils.audio import get_voice_dir, get_voices +from tortoise.utils.device import get_device_count from utils import * @@ -536,7 +537,7 @@ def setup_gradio(): with gr.Row(): training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1) - training_gpu_count = gr.Number(label="GPUs", value=1) + training_gpu_count = gr.Number(label="GPUs", value=get_device_count()) with gr.Row(): start_training_button = gr.Button(value="Train") stop_training_button = gr.Button(value="Stop")