diff --git a/src/utils.py b/src/utils.py index 241cb7f..f8d008f 100755 --- a/src/utils.py +++ b/src/utils.py @@ -627,7 +627,10 @@ class TrainingState(): self.nan_detected = False self.last_info_check_at = 0 - self.statistics = [] + self.statistics = { + 'loss': [], + 'lr': [], + } self.losses = [] self.metrics = { 'step': "", @@ -637,7 +640,7 @@ class TrainingState(): self.loss_milestones = [ 1.0, 0.15, 0.05 ] - self.load_losses() + self.load_statistics() if keep_x_past_checkpoints > 0: self.cleanup_old(keep=keep_x_past_checkpoints) if start: @@ -649,7 +652,7 @@ class TrainingState(): print("Spawning process: ", " ".join(self.cmd)) self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True) - def load_losses(self, update=False): + def load_statistics(self, update=False): if not os.path.isdir(f'{self.dataset_dir}/tb_logger/'): return try: @@ -658,69 +661,40 @@ class TrainingState(): except Exception as e: use_tensorboard = False - keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total', 'val_loss_text_ce', 'val_loss_mel_ce'] + keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total', 'val_loss_text_ce', 'val_loss_mel_ce', 'learning_rate_gpt_0'] infos = {} highest_step = self.last_info_check_at if not update: - self.statistics = [] + self.statistics['loss'] = [] + self.statistics['lr'] = [] - if use_tensorboard: - logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if d[:6] == "events" ]) - if update: - logs = [logs[-1]] + logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if d[:6] == "events" ]) + if update: + logs = [logs[-1]] - for log in logs: - ea = event_accumulator.EventAccumulator(log, size_guidance={event_accumulator.SCALARS: 0}) - ea.Reload() + for log in logs: + ea = event_accumulator.EventAccumulator(log, size_guidance={event_accumulator.SCALARS: 0}) + ea.Reload() - for key in keys: - try: - scalar = ea.Scalars(key) - for s in scalar: - if update and s.step <= self.last_info_check_at: - continue - highest_step = max( highest_step, s.step ) - self.statistics.append( { "step": s.step, "value": s.value, "type": key } ) - - if key == 'loss_gpt_total': - self.losses.append( { "step": s.step, "value": s.value, "type": key } ) - except Exception as e: - pass + scalars = ea.Tags()['scalars'] - else: - logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ]) - if update: - logs = [logs[-1]] + for key in keys: + if key not in scalars: + continue - for log in logs: - with open(log, 'r', encoding="utf-8") as f: - lines = f.readlines() - for line in lines: - if line.find('INFO: [epoch:') >= 0: - # easily rip out our stats... - match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line) - if not match or len(match) == 0: + try: + scalar = ea.Scalars(key) + for s in scalar: + if update and s.step <= self.last_info_check_at: continue - - info = {} - for k, v in match: - info[k] = float(v.replace(",", "")) - - if 'iter' in info: - it = info['iter'] - infos[it] = info - - for k in infos: - if 'loss_gpt_total' in infos[k]: - for key in keys: - if update and int(k) <= self.last_info_check_at: - continue - highest_step = max( highest_step, s.step ) - self.statistics.append({ "step": int(k), "value": infos[k][key], "type": key }) - - if key == "loss_gpt_total": - self.losses.append({ "step": int(k), "value": infos[k][key], "type": key }) + highest_step = max( highest_step, s.step ) + target = 'lr' if key == "learning_rate_gpt_0" else 'loss' + self.statistics[target].append( { "step": s.step, "value": s.value, "type": key } ) + if key == 'loss_gpt_total': + self.losses.append( { "step": s.step, "value": s.value, "type": key } ) + except Exception as e: + pass self.last_info_check_at = highest_step @@ -784,7 +758,7 @@ class TrainingState(): for k, v in match: self.info[k] = float(v.replace(",", "")) - self.load_losses(update=True) + self.load_statistics(update=True) should_return = True if 'epoch' in self.info: @@ -1003,20 +977,26 @@ def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress def update_training_dataplot(config_path=None): global training_state - update = None + losses = None + lrs = None if not training_state: if config_path: training_state = TrainingState(config_path=config_path, start=False) - if training_state.statistics: - update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics), x_lim=[0,training_state.its], x="step", y="value", title="Training Metrics", color="type", tooltip=['step', 'value', 'type'], width=600, height=350,) + if len(training_state.statistics['loss']) > 0: + losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.its], x="step", y="value", title="Training Metrics", color="type", tooltip=['step', 'value', 'type'], width=500, height=350,) + if len(training_state.statistics['lr']) > 0: + lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.its], x="step", y="value", title="Training Metrics", color="type", tooltip=['step', 'value', 'type'], width=500, height=350,) del training_state training_state = None - elif training_state.statistics: - training_state.load_losses() - update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics), x_lim=[0,training_state.its], x="step", y="value", title="Training Metrics", color="type", tooltip=['step', 'value', 'type'], width=600, height=350,) + else: + training_state.load_statistics() + if len(training_state.statistics['loss']) > 0: + losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.its], x="step", y="value", title="Training Metrics", color="type", tooltip=['step', 'value', 'type'], width=500, height=350,) + if len(training_state.statistics['lr']) > 0: + lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.its], x="step", y="value", title="Training Metrics", color="type", tooltip=['step', 'value', 'type'], width=500, height=350,) - return update + return (losses, lrs) def reconnect_training(verbose=False, progress=gr.Progress(track_tqdm=True)): global training_state @@ -1363,9 +1343,11 @@ def save_training_settings( **kwargs ): settings['iterations'] = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size']) messages.append(f"For {settings['epochs']} epochs with {lines} lines, iterating for {settings['iterations']} steps") - settings['print_rate'] = int(settings['print_rate'] * settings['iterations'] / settings['epochs']) - settings['save_rate'] = int(settings['save_rate'] * settings['iterations'] / settings['epochs']) - settings['validation_rate'] = int(settings['validation_rate'] * settings['iterations'] / settings['epochs']) + iterations_per_epoch = int(settings['iterations'] / settings['epochs']) + + settings['print_rate'] = int(settings['print_rate'] * iterations_per_epoch) + settings['save_rate'] = int(settings['save_rate'] * iterations_per_epoch) + settings['validation_rate'] = int(settings['validation_rate'] * iterations_per_epoch) settings['validation_batch_size'] = int(settings['batch_size'] / settings['gradient_accumulation_size']) @@ -1407,16 +1389,31 @@ def save_training_settings( **kwargs ): elif isinstance(settings['learning_rate_schedule'],str): settings['learning_rate_schedule'] = json.loads(settings['learning_rate_schedule']) - settings['learning_rate_schedule'] = schedule_learning_rate( settings['iterations'] / settings['epochs'], settings['learning_rate_schedule'] ) + settings['learning_rate_schedule'] = schedule_learning_rate( iterations_per_epoch, settings['learning_rate_schedule'] ) learning_rate_schema.append(f" gen_lr_steps: {settings['learning_rate_schedule']}") learning_rate_schema.append(f" lr_gamma: 0.5") elif settings['learning_rate_scheme'] == "CosineAnnealingLR_Restart": - learning_rate_schema.append(f" T_period: [120000, 120000, 120000]") - learning_rate_schema.append(f" warmup: 10000") - learning_rate_schema.append(f" eta_min: .01") - learning_rate_schema.append(f" restarts: [140000, 280000]") - learning_rate_schema.append(f" restart_weights: [.5, .25]") + epochs = settings['epochs'] + restarts = int(epochs / 2) + + if 'learning_rate_period' not in settings: + settings['learning_rate_period'] = [ iterations_per_epoch for x in range(epochs) ] + if 'learning_rate_warmup' not in settings: + settings['learning_rate_warmup'] = 0 + if 'learning_rate_min' not in settings: + settings['learning_rate_min'] = 1e-07 + if 'learning_rate_restarts' not in settings: + settings['learning_rate_restarts'] = [ iterations_per_epoch * (x+1) * 2 for x in range(restarts) ] # [52, 104, 156, 208] + if 'learning_rate_restart_weights' not in settings: + settings['learning_rate_restart_weights'] = [ ( restarts - x - 1 ) / restarts for x in range(restarts) ] # [.75, .5, .25, .125] + settings['learning_rate_restart_weights'][-1] = settings['learning_rate_restart_weights'][-2] * 0.5 + + learning_rate_schema.append(f" T_period: {settings['learning_rate_period']}") + learning_rate_schema.append(f" warmup: !!float {settings['learning_rate_warmup']}") + learning_rate_schema.append(f" eta_min: !!float {settings['learning_rate_min']}") + learning_rate_schema.append(f" restarts: {settings['learning_rate_restarts']}") + learning_rate_schema.append(f" restart_weights: {settings['learning_rate_restart_weights']}") settings['learning_rate_scheme'] = "\n".join(learning_rate_schema) """ diff --git a/src/webui.py b/src/webui.py index 6691759..8b5ddae 100755 --- a/src/webui.py +++ b/src/webui.py @@ -430,29 +430,36 @@ def setup_gradio(): with gr.Row(): with gr.Column(): training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_list()) + refresh_configs = gr.Button(value="Refresh Configurations") + training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8) + verbose_training = gr.Checkbox(label="Verbose Console Output", value=True) + + training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1) with gr.Row(): - refresh_configs = gr.Button(value="Refresh Configurations") + start_training_button = gr.Button(value="Train") + stop_training_button = gr.Button(value="Stop") + reconnect_training_button = gr.Button(value="Reconnect") + with gr.Column(): training_loss_graph = gr.LinePlot(label="Training Metrics", x="step", y="value", title="Training Metrics", color="type", tooltip=['step', 'value', 'type'], - width=600, + width=500, + height=350, + ) + training_lr_graph = gr.LinePlot(label="Training Metrics", + x="step", + y="value", + title="Training Metrics", + color="type", + tooltip=['step', 'value', 'type'], + width=500, height=350, ) view_losses = gr.Button(value="View Losses") - - with gr.Column(): - training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8) - verbose_training = gr.Checkbox(label="Verbose Console Output", value=True) - - training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1) - with gr.Row(): - start_training_button = gr.Button(value="Train") - stop_training_button = gr.Button(value="Stop") - reconnect_training_button = gr.Button(value="Reconnect") with gr.Tab("Settings"): with gr.Row(): exec_inputs = [] @@ -650,6 +657,7 @@ def setup_gradio(): inputs=None, outputs=[ training_loss_graph, + training_lr_graph, ], show_progress=False, ) @@ -661,6 +669,7 @@ def setup_gradio(): ], outputs=[ training_loss_graph, + training_lr_graph, ], )