1
0
Fork 0

added some output parsing during training (print current iteration step, and checkpoint save), added option for verbose output (for debugging), added buffer size for output, full console output gets dumped on terminating training

master
mrq 2023-02-19 05:05:30 +07:00
parent 5fcdb19f8b
commit e7d0cfaa82
2 changed files with 57 additions and 7 deletions

@ -435,34 +435,78 @@ def generate(
import subprocess import subprocess
training_process = None training_process = None
def run_training(config_path): def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)):
try: try:
print("Unloading TTS to save VRAM.") print("Unloading TTS to save VRAM.")
global tts global tts
del tts del tts
tts = None tts = None
trytorch.cuda.empty_cache()
except Exception as e: except Exception as e:
pass pass
global training_process global training_process
torch.multiprocessing.freeze_support() torch.multiprocessing.freeze_support()
do_gc()
cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', config_path] cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', config_path]
print("Spawning process: ", " ".join(cmd)) print("Spawning process: ", " ".join(cmd))
training_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True) training_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
buffer=[]
# parse config to get its iteration
import yaml
with open(config_path, 'r') as file:
config = yaml.safe_load(file)
it = 0
its = config['train']['niter']
checkpoint = 0
checkpoints = config['logger']['save_checkpoint_freq']
buffer_size = 8
open_state = False
training_started = False
yield " ".join(cmd)
buffer = []
infos = []
yields = True
for line in iter(training_process.stdout.readline, ""): for line in iter(training_process.stdout.readline, ""):
buffer.append(f'[{datetime.now().isoformat()}] {line}') buffer.append(f'{line}')
print(f"[Training] {line[:-1]}")
yield "".join(buffer[-8:]) # rip out iteration info
if not training_started:
if line.find('Start training from epoch') >= 0:
training_started = True
elif progress is not None:
if line.find(' 0%|') == 0:
open_state = True
elif line.find('100%|') == 0 and open_state:
open_state = False
it = it + 1
progress(it / float(its), f'[{it}/{its}] Training...')
elif line.find('INFO: [epoch:') >= 0:
infos.append(f'{line}')
elif line.find('Saving models and training states') >= 0:
checkpoint = checkpoint + 1
progress(checkpoint / float(checkpoints), f'[{checkpoint}/{checkpoints}] Saving checkpoint...')
if verbose:
print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
yield "".join(buffer[-buffer_size:])
training_process.stdout.close() training_process.stdout.close()
return_code = training_process.wait() return_code = training_process.wait()
training_process = None training_process = None
#if return_code: #if return_code:
# raise subprocess.CalledProcessError(return_code, cmd) # raise subprocess.CalledProcessError(return_code, cmd)
return "".join(buffer[-buffer_size:])
def stop_training(): def stop_training():
if training_process is None: if training_process is None:

@ -350,6 +350,8 @@ def setup_gradio():
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_list()) training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_list())
verbose_training = gr.Checkbox(label="Verbose Training")
training_buffer_size = gr.Slider(label="Buffer Size", minimum=4, maximum=32, value=8)
refresh_configs = gr.Button(value="Refresh Configurations") refresh_configs = gr.Button(value="Refresh Configurations")
start_training_button = gr.Button(value="Train") start_training_button = gr.Button(value="Train")
stop_training_button = gr.Button(value="Stop") stop_training_button = gr.Button(value="Stop")
@ -533,7 +535,11 @@ def setup_gradio():
outputs=training_configs outputs=training_configs
) )
start_training_button.click(run_training, start_training_button.click(run_training,
inputs=training_configs, inputs=[
training_configs,
verbose_training,
training_buffer_size,
],
outputs=training_output #console_output outputs=training_output #console_output
) )
stop_training_button.click(stop_training, stop_training_button.click(stop_training,