From 1e0fec4358a11bac2cf20cef05d3e890638ee7c5 Mon Sep 17 00:00:00 2001
From: mrq
Date: Thu, 23 Feb 2023 23:22:23 +0000
Subject: [PATCH] god i finally found some time and focus: reworded print/save
 freq per epoch => print/save freq (in epochs), added import config button to
 reread the last used settings (will check for the output folder's configs
 first, then the generated ones) and auto-grab the last resume state (if
 available), some other cleanups i genuinely don't remember what I did when I
 spaced out for 20 minutes

---
 src/utils.py |  44 ++++++++++++---------
 src/webui.py | 108 ++++++++++++++++++++++++++++++++++++++++++++-------
 2 files changed, 118 insertions(+), 34 deletions(-)

diff --git a/src/utils.py b/src/utils.py
index 3f7cac9..f80e231 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -438,7 +438,7 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
 
 # superfluous, but it cleans up some things
 class TrainingState():
-    def __init__(self, config_path, buffer_size=8):
+    def __init__(self, config_path):
         self.cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', config_path]
 
         # parse config to get its iteration
@@ -465,7 +465,7 @@ class TrainingState():
         self.training_started = False
 
         self.info = {}
-        self.status = ""
+        self.status = "..."
 
         self.epoch_rate = ""
         self.epoch_time_start = 0
@@ -491,7 +491,7 @@ class TrainingState():
             match = re.findall(r'iter: ([\d,]+)', line)
             if match and len(match) > 0:
                 self.it = int(match[0].replace(",", ""))
-        elif progress is not None:
+        else:
             if line.find('%|') > 0 and not self.open_state:
                 self.open_state = True
             elif line.find('100%|') == 0 and self.open_state:
@@ -505,7 +505,12 @@ class TrainingState():
                 self.eta = (self.epochs - self.epoch) * self.epoch_time_delta
                 self.eta_hhmmss = str(timedelta(seconds=int(self.eta)))
 
-                progress(self.epoch / float(self.epochs), f'[{self.epoch}/{self.epochs}] [ETA: {self.eta_hhmmss}] {self.epoch_rate} Training... {self.status}')
+                percent = self.epoch / float(self.epochs)
+                message = f'[{self.epoch}/{self.epochs}] [ETA: {self.eta_hhmmss}] {self.epoch_rate} {self.status}'
+                print(f'{percent*100:.3f}% {message}')
+                if progress is not None:
+                    progress(percent, message)
+                self.buffer.append(f'{percent*100:.3f}% {message}')
 
         if line.find('INFO: [epoch:') >= 0:
             # easily rip out our stats...
@@ -516,12 +521,20 @@ class TrainingState():
 
             if 'loss_gpt_total' in self.info:
                 self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}"
+                print(self.status)
+                self.buffer.append(self.status)
 
         elif line.find('Saving models and training states') >= 0:
             self.checkpoint = self.checkpoint + 1
-            progress(self.checkpoint / float(self.checkpoints), f'[{self.checkpoint}/{self.checkpoints}] Saving checkpoint...')
-
+            percent = self.checkpoint / float(self.checkpoints)
+            message = f'[{self.checkpoint}/{self.checkpoints}] Saving checkpoint...'
+ print(f'{"{:.3f}".format(percent*100)}% {message}') + if progress is not None: + progress(percent, message) + self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}') + + self.buffer = self.buffer[-buffer_size:] if verbose or not self.training_started: - return "".join(self.buffer[-buffer_size:]) + return "".join(self.buffer) def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)): global training_state @@ -535,25 +548,22 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress unload_whisper() unload_voicefixer() - training_state = TrainingState(config_path=config_path, buffer_size=buffer_size) + training_state = TrainingState(config_path=config_path) for line in iter(training_state.process.stdout.readline, ""): - print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}") res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress ) + print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}") if res: yield res training_state.process.stdout.close() return_code = training_state.process.wait() - output = "".join(training_state.buffer[-buffer_size:]) training_state = None #if return_code: # raise subprocess.CalledProcessError(return_code, cmd) - return output - def reconnect_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)): global training_state if not training_state or not training_state.process: @@ -563,10 +573,6 @@ def reconnect_training(config_path, verbose=False, buffer_size=8, progress=gr.Pr res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress ) if res: yield res - - output = "".join(training_state.buffer[-buffer_size:]) - - return output def stop_training(): global training_process @@ -644,7 +650,7 @@ EPOCH_SCHEDULE = [ 9, 18, 25, 33 ] def schedule_learning_rate( iterations ): return [int(iterations * d) for d in EPOCH_SCHEDULE] -def optimize_training_settings( epochs, batch_size, learning_rate, learning_rate_schedule, mega_batch_factor, print_rate, save_rate, resume_path, half_p, voice ): +def optimize_training_settings( epochs, learning_rate, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, voice ): name = f"{voice}-finetune" dataset_name = f"{voice}-train" dataset_path = f"./training/{voice}/train.txt" @@ -694,9 +700,9 @@ def optimize_training_settings( epochs, batch_size, learning_rate, learning_rate messages.append(f"For {epochs} epochs with {lines} lines in batches of {batch_size}, iterating for {iterations} steps ({int(iterations / epochs)} steps per epoch)") return ( - batch_size, learning_rate, learning_rate_schedule, + batch_size, mega_batch_factor, print_rate, save_rate, @@ -704,7 +710,7 @@ def optimize_training_settings( epochs, batch_size, learning_rate, learning_rate messages ) -def save_training_settings( iterations=None, batch_size=None, learning_rate=None, learning_rate_schedule=None, mega_batch_factor=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None ): +def save_training_settings( iterations=None, learning_rate=None, learning_rate_schedule=None, batch_size=None, mega_batch_factor=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None ): settings = { 
"iterations": iterations if iterations else 500, "batch_size": batch_size if batch_size else 64, diff --git a/src/webui.py b/src/webui.py index 3ebae01..a9aaf2a 100755 --- a/src/webui.py +++ b/src/webui.py @@ -200,7 +200,65 @@ def optimize_training_settings_proxy( *args, **kwargs ): "\n".join(tup[7]) ) -def save_training_settings_proxy( epochs, batch_size, learning_rate, learning_rate_schedule, mega_batch_factor, print_rate, save_rate, resume_path, half_p, voice ): +def import_training_settings_proxy( epochs, learning_rate, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, voice ): + indir = f'./training/{voice}/' + outdir = f'./training/{voice}-finetune/' + + in_config_path = f"{indir}/train.yaml" + out_configs = sorted([d[:-5] for d in os.listdir(outdir) if d[-5:] == ".yaml" ]) + if len(out_configs) > 0: + out_config_path = f'{outdir}/{out_configs[-1]}.yaml' + + config_path = out_config_path if out_config_path else in_config_path + + messages = [] + with open(config_path, 'r') as file: + config = yaml.safe_load(file) + messages.append(f"Importing from: {config_path}") + + dataset_path = f"./training/{voice}/train.txt" + with open(dataset_path, 'r', encoding="utf-8") as f: + lines = len(f.readlines()) + messages.append(f"Basing epoch size to {lines} lines") + + batch_size = config['datasets']['train']['batch_size'] + mega_batch_factor = config['train']['mega_batch_factor'] + + iterations = config['train']['niter'] + steps_per_iteration = int(lines / batch_size) + epochs = int(iterations / steps_per_iteration) + + + learning_rate = config['steps']['gpt_train']['optimizer_params']['lr'] + learning_rate_schedule = [ int(x / steps_per_iteration) for x in config['train']['gen_lr_steps'] ] + + + print_rate = int(config['logger']['print_freq'] / steps_per_iteration) + save_rate = int(config['logger']['save_checkpoint_freq'] / steps_per_iteration) + + statedir = f'{outdir}/training_state/' # NOOO STOP MIXING YOUR CASES + resumes = sorted([int(d[:-6]) for d in os.listdir(statedir) if d[-6:] == ".state" ]) + + if len(resumes) > 0: + resume_path = f'{statedir}/{resumes[-1]}.state' + messages.append(f"Latest resume found: {resume_path}") + + messages = "\n".join(messages) + + return ( + epochs, + learning_rate, + learning_rate_schedule, + batch_size, + mega_batch_factor, + print_rate, + save_rate, + resume_path, + messages + ) + + +def save_training_settings_proxy( epochs, learning_rate, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, voice ): name = f"{voice}-finetune" dataset_name = f"{voice}-train" dataset_path = f"./training/{voice}/train.txt" @@ -333,8 +391,9 @@ def setup_gradio(): repetition_penalty = gr.Slider(value=2.0, minimum=0, maximum=8, label="Repetition Penalty") cond_free_k = gr.Slider(value=2.0, minimum=0, maximum=4, label="Conditioning-Free K") with gr.Column(): - submit = gr.Button(value="Generate") - stop = gr.Button(value="Stop") + with gr.Row(): + submit = gr.Button(value="Generate") + stop = gr.Button(value="Stop") generation_results = gr.Dataframe(label="Results", headers=["Seed", "Time"], visible=False) source_sample = gr.Audio(label="Source Sample", visible=False) @@ -392,30 +451,45 @@ def setup_gradio(): with gr.Column(): training_settings = [ gr.Number(label="Epochs", value=500, precision=0), - gr.Number(label="Batch Size", value=128, precision=0), - gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6), - gr.Textbox(label="Learning Rate Schedule", 
-                    gr.Number(label="Mega Batch Factor", value=4, precision=0),
-                    gr.Number(label="Print Frequency per Epoch", value=5, precision=0),
-                    gr.Number(label="Save Frequency per Epoch", value=5, precision=0),
+                ]
+                with gr.Row():
+                    training_settings = training_settings + [
+                        gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6),
+                        gr.Textbox(label="Learning Rate Schedule", placeholder=str(EPOCH_SCHEDULE)),
+                    ]
+                with gr.Row():
+                    training_settings = training_settings + [
+                        gr.Number(label="Batch Size", value=128, precision=0),
+                        gr.Number(label="Mega Batch Factor", value=4, precision=0),
+                    ]
+                with gr.Row():
+                    training_settings = training_settings + [
+                        gr.Number(label="Print Frequency (in epochs)", value=5, precision=0),
+                        gr.Number(label="Save Frequency (in epochs)", value=5, precision=0),
+                    ]
+                training_settings = training_settings + [
                     gr.Textbox(label="Resume State Path", placeholder="./training/${voice}-finetune/training_state/${last_state}.state"),
                     gr.Checkbox(label="Half Precision", value=False),
                 ]
                 dataset_list = gr.Dropdown( get_dataset_list(), label="Dataset", type="value" )
                 training_settings = training_settings + [ dataset_list ]
-                refresh_dataset_list = gr.Button(value="Refresh Dataset List")
+                with gr.Row():
+                    refresh_dataset_list = gr.Button(value="Refresh Dataset List")
+                    import_dataset_button = gr.Button(value="Import Dataset")
             with gr.Column():
                 save_yaml_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
-                optimize_yaml_button = gr.Button(value="Validate Training Configuration")
-                save_yaml_button = gr.Button(value="Save Training Configuration")
+                with gr.Row():
+                    optimize_yaml_button = gr.Button(value="Validate Training Configuration")
+                    save_yaml_button = gr.Button(value="Save Training Configuration")
         with gr.Tab("Run Training"):
             with gr.Row():
                 with gr.Column():
                     training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_list())
                     refresh_configs = gr.Button(value="Refresh Configurations")
-                    start_training_button = gr.Button(value="Train")
-                    stop_training_button = gr.Button(value="Stop")
-                    reconnect_training_button = gr.Button(value="Reconnect")
+                    with gr.Row():
+                        start_training_button = gr.Button(value="Train")
+                        stop_training_button = gr.Button(value="Stop")
+                        reconnect_training_button = gr.Button(value="Reconnect")
                 with gr.Column():
                     training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
                     verbose_training = gr.Checkbox(label="Verbose Console Output")
@@ -641,6 +715,10 @@ def setup_gradio():
             inputs=training_settings,
             outputs=training_settings[1:8] + [save_yaml_output] #console_output
         )
+        import_dataset_button.click(import_training_settings_proxy,
+            inputs=training_settings,
+            outputs=training_settings[:8] + [save_yaml_output] #console_output
+        )
         save_yaml_button.click(save_training_settings_proxy,
             inputs=training_settings,
             outputs=save_yaml_output #console_output
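
Note on the conversion this patch relies on: every epoch-based UI value maps onto an
iteration-based field in train.yaml through a single steps-per-epoch factor (dataset
lines divided by batch size), which is why the new import button can invert what "Save
Training Configuration" writes out. A minimal sketch of that round trip, using
hypothetical helper names (the real logic lives inline in optimize_training_settings
and import_training_settings_proxy):

    # Sketch only: mirrors the epoch <-> iteration math in this patch; like the
    # real code, it assumes the dataset has at least `batch_size` lines.
    def to_yaml_units(epochs, print_rate, save_rate, lines, batch_size):
        steps_per_epoch = int(lines / batch_size)
        return {
            "niter": epochs * steps_per_epoch,                    # train.niter
            "print_freq": print_rate * steps_per_epoch,           # logger.print_freq
            "save_checkpoint_freq": save_rate * steps_per_epoch,  # logger.save_checkpoint_freq
        }

    def from_yaml_units(config, lines):
        batch_size = config['datasets']['train']['batch_size']
        steps_per_epoch = int(lines / batch_size)
        return {
            "epochs": int(config['train']['niter'] / steps_per_epoch),
            "print_rate": int(config['logger']['print_freq'] / steps_per_epoch),
            "save_rate": int(config['logger']['save_checkpoint_freq'] / steps_per_epoch),
        }

Keeping the UI in epochs means the same settings survive a change in dataset size or
batch size; the conversion to raw iterations happens only when the YAML is written out
or read back in.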