diff --git a/src/utils.py b/src/utils.py index 7e1fe9b..0b8a282 100755 --- a/src/utils.py +++ b/src/utils.py @@ -1048,15 +1048,19 @@ def stop_training(): training_state.killed = True children = [] - # wrapped in a try/catch in case for some reason this fails outside of Linux - try: - children = [p.info for p in psutil.process_iter(attrs=['pid', 'name', 'cmdline']) if './src/train.py' in p.info['cmdline']] - except Exception as e: - pass + if args.tts_backend == "tortoise": + # wrapped in a try/catch in case for some reason this fails outside of Linux + try: + children = [p.info for p in psutil.process_iter(attrs=['pid', 'name', 'cmdline']) if './src/train.py' in p.info['cmdline']] + except Exception as e: + pass + + training_state.process.stdout.close() + training_state.process.terminate() + training_state.process.kill() + elif args.tts_backend == "vall-e": + print(training_state.process.communicate(input='quit')[0]) - training_state.process.stdout.close() - training_state.process.terminate() - training_state.process.kill() return_code = training_state.process.wait() for p in children: