@ -17,6 +17,7 @@ import urllib.request
import signal
import signal
import gc
import gc
import subprocess
import subprocess
import psutil
import yaml
import yaml
import tqdm
import tqdm
@ -556,7 +557,7 @@ class TrainingState():
self . spawn_process ( config_path = config_path , gpus = gpus )
self . spawn_process ( config_path = config_path , gpus = gpus )
def spawn_process ( self , config_path , gpus = 1 ) :
def spawn_process ( self , config_path , gpus = 1 ) :
self . cmd = [ ' train.bat ' , config_path ] if os . name == " nt " else [ ' bash' , ' ./train.sh' , str ( int ( gpus ) ) , config_path ]
self . cmd = [ ' train.bat ' , config_path ] if os . name == " nt " else [ ' ./train.sh' , str ( int ( gpus ) ) , config_path ]
print ( " Spawning process: " , " " . join ( self . cmd ) )
print ( " Spawning process: " , " " . join ( self . cmd ) )
self . process = subprocess . Popen ( self . cmd , stdout = subprocess . PIPE , stderr = subprocess . STDOUT , universal_newlines = True )
self . process = subprocess . Popen ( self . cmd , stdout = subprocess . PIPE , stderr = subprocess . STDOUT , universal_newlines = True )
@ -815,6 +816,9 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, pro
training_state = TrainingState ( config_path = config_path , keep_x_past_datasets = keep_x_past_datasets , gpus = gpus )
training_state = TrainingState ( config_path = config_path , keep_x_past_datasets = keep_x_past_datasets , gpus = gpus )
for line in iter ( training_state . process . stdout . readline , " " ) :
for line in iter ( training_state . process . stdout . readline , " " ) :
if training_state . killed :
return
result , percent , message = training_state . parse ( line = line , verbose = verbose , keep_x_past_datasets = keep_x_past_datasets , progress = progress )
result , percent , message = training_state . parse ( line = line , verbose = verbose , keep_x_past_datasets = keep_x_past_datasets , progress = progress )
print ( f " [Training] [ { datetime . now ( ) . isoformat ( ) } ] { line [ : - 1 ] } " )
print ( f " [Training] [ { datetime . now ( ) . isoformat ( ) } ] { line [ : - 1 ] } " )
if result :
if result :
@ -868,10 +872,22 @@ def stop_training():
return " No training in progress "
return " No training in progress "
print ( " Killing training process... " )
print ( " Killing training process... " )
training_state . killed = True
training_state . killed = True
children = [ ]
# wrapped in a try/catch in case for some reason this fails outside of Linux
try :
children = [ p . info for p in psutil . process_iter ( attrs = [ ' pid ' , ' name ' , ' cmdline ' ] ) if ' ./src/train.py ' in p . info [ ' cmdline ' ] ]
except Exception as e :
pass
training_state . process . stdout . close ( )
training_state . process . stdout . close ( )
#training_state.process.terminate()
training_state . process . terminate ( )
training_state . process . send_signal ( signal . SIGINT )
training_state . process . kill( )
return_code = training_state . process . wait ( )
return_code = training_state . process . wait ( )
for p in children :
os . kill ( p [ ' pid ' ] , signal . SIGKILL )
training_state = None
training_state = None
print ( " Killed training process. " )
print ( " Killed training process. " )
return f " Training cancelled: { return_code } "
return f " Training cancelled: { return_code } "