forked from mrq/ai-voice-cloning
Added output parsing during training (prints the current iteration step and checkpoint saves), added an option for verbose output (for debugging), and added a configurable buffer size for the output; the full console output is dumped when training terminates.
This commit is contained in:
parent
5fcdb19f8b
commit
e7d0cfaa82
56
src/utils.py
56
src/utils.py
|
@ -435,34 +435,78 @@ def generate(
|
||||||
import subprocess

# Handle to the currently-running training subprocess; None when idle.
# stop_training() (defined below) uses this to terminate a run.
training_process = None

def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)):
	"""Spawn the training script for the given config and stream its output.

	Generator: yields the spawned command line first, then a rolling window of
	the last `buffer_size` lines of the trainer's console output after each
	line read, so the UI textbox updates live.

	Args:
		config_path: path to the training YAML config; also passed to the
			train.bat / train.sh launcher.
		verbose: when True, echo every trainer output line to our own stdout
			(timestamped) for debugging.
		buffer_size: number of trailing output lines included in each yield.
		progress: gradio progress tracker used to report iteration and
			checkpoint progress parsed from the trainer's output.

	Returns:
		The final rolling output window (as the generator's return value).
	"""
	global tts
	try:
		# Free VRAM before training by dropping the loaded TTS model.
		print("Unloading TTS to save VRAM.")
		del tts
		tts = None
		# Fixed: source had the garbled statement "trytorch.cuda.empty_cache()".
		torch.cuda.empty_cache()
	except Exception:
		# Best-effort: the TTS model may simply not be loaded yet.
		pass

	global training_process
	torch.multiprocessing.freeze_support()

	do_gc()

	cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', config_path]
	print("Spawning process: ", " ".join(cmd))
	training_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)

	# Parse the config to learn the total iteration and checkpoint counts,
	# so trainer output can be mapped onto a progress fraction.
	import yaml
	with open(config_path, 'r') as file:
		config = yaml.safe_load(file)

	it = 0
	its = config['train']['niter']

	checkpoint = 0
	checkpoints = config['logger']['save_checkpoint_freq']

	# NOTE(review): the committed code reassigned `buffer_size = 8` here,
	# which silently clobbered the new parameter; the reassignment is removed
	# so the caller-supplied value is honored.
	open_state = False        # True while inside a tqdm 0%..100% progress span
	training_started = False  # flips once the trainer prints its start banner

	yield " ".join(cmd)

	buffer = []
	infos = []
	for line in iter(training_process.stdout.readline, ""):
		buffer.append(f'{line}')

		# Rip iteration / checkpoint info out of the trainer's console output.
		if not training_started:
			if line.find('Start training from epoch') >= 0:
				training_started = True
		elif progress is not None:
			if line.find(' 0%|') == 0:
				open_state = True
			elif line.find('100%|') == 0 and open_state:
				# A tqdm bar just completed: count it as one finished iteration.
				open_state = False
				it = it + 1
				progress(it / float(its), f'[{it}/{its}] Training...')
			elif line.find('INFO: [epoch:') >= 0:
				infos.append(f'{line}')
			elif line.find('Saving models and training states') >= 0:
				checkpoint = checkpoint + 1
				progress(checkpoint / float(checkpoints), f'[{checkpoint}/{checkpoints}] Saving checkpoint...')

		if verbose:
			print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
		yield "".join(buffer[-buffer_size:])

	training_process.stdout.close()
	return_code = training_process.wait()
	training_process = None

	#if return_code:
	#	raise subprocess.CalledProcessError(return_code, cmd)

	return "".join(buffer[-buffer_size:])
||||||
def stop_training():
|
def stop_training():
|
||||||
if training_process is None:
|
if training_process is None:
|
||||||
|
|
|
@ -350,6 +350,8 @@ def setup_gradio():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_list())
|
training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_list())
|
||||||
|
verbose_training = gr.Checkbox(label="Verbose Training")
|
||||||
|
training_buffer_size = gr.Slider(label="Buffer Size", minimum=4, maximum=32, value=8)
|
||||||
refresh_configs = gr.Button(value="Refresh Configurations")
|
refresh_configs = gr.Button(value="Refresh Configurations")
|
||||||
start_training_button = gr.Button(value="Train")
|
start_training_button = gr.Button(value="Train")
|
||||||
stop_training_button = gr.Button(value="Stop")
|
stop_training_button = gr.Button(value="Stop")
|
||||||
|
@ -533,7 +535,11 @@ def setup_gradio():
|
||||||
outputs=training_configs
|
outputs=training_configs
|
||||||
)
|
)
|
||||||
start_training_button.click(run_training,
|
start_training_button.click(run_training,
|
||||||
inputs=training_configs,
|
inputs=[
|
||||||
|
training_configs,
|
||||||
|
verbose_training,
|
||||||
|
training_buffer_size,
|
||||||
|
],
|
||||||
outputs=training_output #console_output
|
outputs=training_output #console_output
|
||||||
)
|
)
|
||||||
stop_training_button.click(stop_training,
|
stop_training_button.click(stop_training,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user