From 1a9d159b2a7ac84615b37b4b98d289e280510f02 Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 4 Mar 2023 17:37:08 +0000 Subject: [PATCH] forgot to add 'bs / gradient accum < 2 clamp validation logic --- src/utils.py | 89 +++++++++++++++++++++++++++++++++++++--------------- src/webui.py | 5 +-- 2 files changed, 65 insertions(+), 29 deletions(-) diff --git a/src/utils.py b/src/utils.py index 46abc92..5e49a8e 100755 --- a/src/utils.py +++ b/src/utils.py @@ -506,6 +506,8 @@ class TrainingState(): with open(config_path, 'r') as file: self.config = yaml.safe_load(file) + self.killed = False + self.dataset_dir = f"./training/{self.config['name']}/" self.batch_size = self.config['datasets']['train']['batch_size'] self.dataset_path = self.config['datasets']['train']['path'] @@ -527,7 +529,6 @@ class TrainingState(): self.training_started = False self.info = {} - self.status = "..." self.epoch_rate = "" self.epoch_time_start = 0 @@ -651,10 +652,12 @@ class TrainingState(): print("Removing", path) os.remove(path) - def parse(self, line, verbose=False, buffer_size=8, keep_x_past_datasets=0, progress=None ): + def parse(self, line, verbose=False, keep_x_past_datasets=0, buffer_size=8, progress=None ): self.buffer.append(f'{line}') should_return = False + percent = 0 + message = None # rip out iteration info if not self.training_started: @@ -679,7 +682,7 @@ class TrainingState(): match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line) if match and len(match) > 0: match = match[0] - percent = int(match[0])/100.0 + per_cent = int(match[0])/100.0 progressbar = match[1] step = int(match[2]) steps = int(match[3]) @@ -698,15 +701,40 @@ class TrainingState(): self.it_time_end = time.time() self.it_time_delta = self.it_time_end-self.it_time_start self.it_time_start = time.time() + self.it_taken = self.it_taken + 1 try: rate = f'{"{:.3f}".format(self.it_time_delta)}s/it' if self.it_time_delta >= 1 else f'{"{:.3f}".format(1/self.it_time_delta)}it/s' self.it_rate = rate except Exception as e: pass - last_loss = "" + + metric_step = [f"{self.epoch}/{self.epochs}", f"{self.it}/{self.its}", f"{step}/{steps}"] + metric_step = ", ".join(metric_step) + + metric_rate = [] + if self.epoch_rate: + metric_rate.append(self.epoch_rate) + if self.it_rate: + metric_rate.append(self.it_rate) + metric_rate = ", ".join(metric_rate) + + eta_hhmmss = "?" + if self.eta_hhmmss: + eta_hhmmss = self.eta_hhmmss + else: + try: + eta = (self.its - self.it) * (self.it_time_deltas / self.it_taken) + eta = str(timedelta(seconds=int(eta))) + eta_hhmmss = eta + except Exception as e: + pass + + metric_loss = [] if len(self.losses) > 0: - last_loss = f'[Loss @ it. {self.losses[-1]["step"]}: {self.losses[-1]["value"]}]' - message = f'[{self.epoch}/{self.epochs}, {self.it}/{self.its}, {step}/{steps}] [{self.epoch_rate}, {self.it_rate}] {last_loss} [ETA: {self.eta_hhmmss}]' + metric_loss.append(f'Loss: {"{:3f}".format(self.losses[-1]["value"])}') + metric_loss = ", ".join(metric_loss) + + message = f'[{metric_step}] [{metric_rate}] [{metric_loss}] [ETA: {eta_hhmmss}]' if lapsed: self.epoch = self.epoch + 1 @@ -740,17 +768,9 @@ class TrainingState(): if match and len(match) > 0: for k, v in match: self.info[k] = float(v.replace(",", "")) - - if 'loss_gpt_total' in self.info: - self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}" - """ - self.losses.append({ "step": self.it, "value": self.info['loss_text_ce'], "type": "loss_text_ce" }) - self.losses.append({ "step": self.it, "value": self.info['loss_mel_ce'], "type": "loss_mel_ce" }) - self.losses.append({ "step": self.it, "value": self.info['loss_gpt_total'], "type": "loss_gpt_total" }) - """ - should_return = True self.load_losses(update=True) + should_return = True elif line.find('Saving models and training states') >= 0: self.checkpoint = self.checkpoint + 1 @@ -769,10 +789,18 @@ class TrainingState(): should_return = True self.buffer = self.buffer[-buffer_size:] + + result = None if should_return: - return "".join(self.buffer) + result = "".join(self.buffer) if not self.training_started else message -def run_training(config_path, verbose=False, gpus=1, buffer_size=8, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)): + return ( + result, + percent, + message, + ) + +def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)): global training_state if training_state and training_state.process: return "Training already in progress" @@ -787,11 +815,10 @@ def run_training(config_path, verbose=False, gpus=1, buffer_size=8, keep_x_past_ training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets, gpus=gpus) for line in iter(training_state.process.stdout.readline, ""): - - res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, keep_x_past_datasets=keep_x_past_datasets, progress=progress ) + result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_datasets=keep_x_past_datasets, progress=progress ) print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}") - if res: - yield res + if result: + yield result if training_state: training_state.process.stdout.close() @@ -824,15 +851,16 @@ def update_training_dataplot(config_path=None): return update -def reconnect_training(verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)): +def reconnect_training(verbose=False, progress=gr.Progress(track_tqdm=True)): global training_state if not training_state or not training_state.process: return "Training not in progress" for line in iter(training_state.process.stdout.readline, ""): - res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress ) - if res: - yield res + result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_datasets=keep_x_past_datasets, progress=progress ) + print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}") + if result: + yield result def stop_training(): global training_state @@ -845,6 +873,7 @@ def stop_training(): training_state.process.send_signal(signal.SIGINT) return_code = training_state.process.wait() training_state = None + print("Killed training process.") return f"Training cancelled: {return_code}" def get_halfp_model_path(): @@ -966,8 +995,18 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni if gradient_accumulation_size == 0: gradient_accumulation_size = 1 + + if batch_size / gradient_accumulation_size < 2: + gradient_accumulation_size = int(batch_size / 2) + if gradient_accumulation_size == 0: + gradient_accumulation_size = 1 + + messages.append(f"Gradient accumulation size is too large for a given batch size, clamping gradient accumulation size to: {gradient_accumulation_size}") elif batch_size % gradient_accumulation_size != 0: gradient_accumulation_size = int(batch_size / gradient_accumulation_size) + if gradient_accumulation_size == 0: + gradient_accumulation_size = 1 + messages.append(f"Batch size is not evenly divisible by the gradient accumulation size, adjusting gradient accumulation size to: {gradient_accumulation_size}") iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size) diff --git a/src/webui.py b/src/webui.py index db90051..6bf8224 100755 --- a/src/webui.py +++ b/src/webui.py @@ -535,9 +535,8 @@ def setup_gradio(): verbose_training = gr.Checkbox(label="Verbose Console Output", value=True) with gr.Row(): - training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8) training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1) - training_gpu_count = gr.Number(label="GPUs", value=1) + training_gpu_count = gr.Number(label="GPUs", value=1) with gr.Row(): start_training_button = gr.Button(value="Train") stop_training_button = gr.Button(value="Stop") @@ -746,7 +745,6 @@ def setup_gradio(): training_configs, verbose_training, training_gpu_count, - training_buffer_size, training_keep_x_past_datasets, ], outputs=[ @@ -779,7 +777,6 @@ def setup_gradio(): reconnect_training_button.click(reconnect_training, inputs=[ verbose_training, - training_buffer_size, ], outputs=training_output #console_output )