Sloppy fix to actually kill child processes when using multi-GPU distributed training; also set the GPU training count automatically from what CUDA exposes, so I don't have to keep setting it to 2
This commit is contained in:
parent
1a9d159b2a
commit
5026d93ecd
22
src/utils.py
22
src/utils.py
|
@ -17,6 +17,7 @@ import urllib.request
|
|||
import signal
|
||||
import gc
|
||||
import subprocess
|
||||
import psutil
|
||||
import yaml
|
||||
|
||||
import tqdm
|
||||
|
@ -556,7 +557,7 @@ class TrainingState():
|
|||
self.spawn_process(config_path=config_path, gpus=gpus)
|
||||
|
||||
def spawn_process(self, config_path, gpus=1):
|
||||
self.cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', str(int(gpus)), config_path]
|
||||
self.cmd = ['train.bat', config_path] if os.name == "nt" else ['./train.sh', str(int(gpus)), config_path]
|
||||
|
||||
print("Spawning process: ", " ".join(self.cmd))
|
||||
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
|
||||
|
@ -815,6 +816,9 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, pro
|
|||
training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets, gpus=gpus)
|
||||
|
||||
for line in iter(training_state.process.stdout.readline, ""):
|
||||
if training_state.killed:
|
||||
return
|
||||
|
||||
result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_datasets=keep_x_past_datasets, progress=progress )
|
||||
print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
|
||||
if result:
|
||||
|
@ -868,10 +872,22 @@ def stop_training():
|
|||
return "No training in progress"
|
||||
print("Killing training process...")
|
||||
training_state.killed = True
|
||||
|
||||
children = []
|
||||
# wrapped in a try/except in case this fails for some reason outside of Linux
|
||||
try:
|
||||
children = [p.info for p in psutil.process_iter(attrs=['pid', 'name', 'cmdline']) if './src/train.py' in p.info['cmdline']]
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
training_state.process.stdout.close()
|
||||
#training_state.process.terminate()
|
||||
training_state.process.send_signal(signal.SIGINT)
|
||||
training_state.process.terminate()
|
||||
training_state.process.kill()
|
||||
return_code = training_state.process.wait()
|
||||
|
||||
for p in children:
|
||||
os.kill( p['pid'], signal.SIGKILL )
|
||||
|
||||
training_state = None
|
||||
print("Killed training process.")
|
||||
return f"Training cancelled: {return_code}"
|
||||
|
|
|
@ -16,6 +16,7 @@ from datetime import datetime
|
|||
|
||||
import tortoise.api
|
||||
from tortoise.utils.audio import get_voice_dir, get_voices
|
||||
from tortoise.utils.device import get_device_count
|
||||
|
||||
from utils import *
|
||||
|
||||
|
@ -536,7 +537,7 @@ def setup_gradio():
|
|||
|
||||
with gr.Row():
|
||||
training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
|
||||
training_gpu_count = gr.Number(label="GPUs", value=1)
|
||||
training_gpu_count = gr.Number(label="GPUs", value=get_device_count())
|
||||
with gr.Row():
|
||||
start_training_button = gr.Button(value="Train")
|
||||
stop_training_button = gr.Button(value="Stop")
|
||||
|
|
Loading…
Reference in New Issue
Block a user