From 7f2da0f5fbf0e591137af5cb85a952ce4e89d1c9 Mon Sep 17 00:00:00 2001 From: mrq Date: Fri, 10 Mar 2023 22:35:32 +0000 Subject: [PATCH] rewrote how AIVC gets training metrics (need to clean up later) --- src/train.py | 3 +- src/utils.py | 195 ++++++++++++++++++++++----------------------------- src/webui.py | 12 ++-- 3 files changed, 91 insertions(+), 119 deletions(-) diff --git a/src/train.py b/src/train.py index 72b3f38..4c5a6bb 100755 --- a/src/train.py +++ b/src/train.py @@ -18,6 +18,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml', nargs='+') # ugh parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') + parser.add_argument('--mode', type=str, default='none', help='mode') args = parser.parse_args() args.opt = " ".join(args.opt) # absolutely disgusting @@ -77,7 +78,7 @@ def train(yaml, launcher='none'): trainer.rank = torch.distributed.get_rank() torch.cuda.set_device(torch.distributed.get_rank()) - trainer.init(yaml, opt, launcher) + trainer.init(yaml, opt, launcher, '') trainer.do_training() if __name__ == "__main__": diff --git a/src/utils.py b/src/utils.py index 747c3ef..35fa6eb 100755 --- a/src/utils.py +++ b/src/utils.py @@ -594,6 +594,9 @@ class TrainingState(): self.it = 0 self.its = self.config['train']['niter'] + + self.step = 0 + self.steps = 1 self.epoch = 0 self.epochs = int(self.its*self.batch_size/self.dataset_size) @@ -653,13 +656,8 @@ class TrainingState(): self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True) def load_statistics(self, update=False): - if not os.path.isdir(f'{self.dataset_dir}/tb_logger/'): + if not os.path.isdir(f'{self.dataset_dir}/'): return - try: - from tensorboard.backend.event_processing import event_accumulator - use_tensorboard = True - except Exception as e: - use_tensorboard = False keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total', 'val_loss_text_ce', 'val_loss_mel_ce', 'learning_rate_gpt_0'] infos = {} @@ -669,32 +667,44 @@ class TrainingState(): self.statistics['loss'] = [] self.statistics['lr'] = [] - logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if d[:6] == "events" ]) + logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ]) if update: logs = [logs[-1]] for log in logs: - ea = event_accumulator.EventAccumulator(log, size_guidance={event_accumulator.SCALARS: 0}) - ea.Reload() + with open(log, 'r', encoding="utf-8") as f: + lines = f.readlines() - scalars = ea.Tags()['scalars'] + for line in lines: + if line.find('INFO: Training Metrics:') >= 0: + data = line.split("INFO: Training Metrics:")[-1] + info = json.loads(data) - for key in keys: - if key not in scalars: + step = info['it'] + if update and step <= self.last_info_check_at: continue - try: - scalar = ea.Scalars(key) - for s in scalar: - if update and s.step <= self.last_info_check_at: - continue - highest_step = max( highest_step, s.step ) - target = 'lr' if key == "learning_rate_gpt_0" else 'loss' - self.statistics[target].append( { "step": s.step, "value": s.value, "type": key } ) - if key == 'loss_gpt_total': - self.losses.append( { "step": s.step, "value": s.value, "type": key } ) - except Exception as e: - pass + if 'lr' in info: + self.statistics['lr'].append({'step': step, 'value': info['lr'], 'type': 'learning_rate_gpt_0'}) + if 'loss_text_ce' in info: + self.statistics['loss'].append({'step': step, 'value': info['loss_text_ce'], 'type': 'loss_text_ce'}) + if 'loss_mel_ce' in info: + self.statistics['loss'].append({'step': step, 'value': info['loss_mel_ce'], 'type': 'loss_mel_ce'}) + if 'loss_gpt_total' in info: + self.statistics['loss'].append({'step': step, 'value': info['loss_gpt_total'], 'type': 'loss_gpt_total'}) + self.losses.append( self.statistics['loss'][-1] ) + + elif line.find('INFO: Validation Metrics:') >= 0: + data = line.split("INFO: Validation Metrics:")[-1] + + step = info['it'] + if update and step <= self.last_info_check_at: + continue + + if 'loss_text_ce' in info: + self.statistics['loss'].append({'step': step, 'value': info['loss_gpt_total'], 'type': 'val_loss_text_ce'}) + if 'loss_mel_ce' in info: + self.statistics['loss'].append({'step': step, 'value': info['loss_gpt_total'], 'type': 'val_loss_mel_ce'}) self.last_info_check_at = highest_step @@ -707,9 +717,8 @@ class TrainingState(): models = sorted([ int(d[:-8]) for d in os.listdir(f'{self.dataset_dir}/models/') if d[-8:] == "_gpt.pth" ]) states = sorted([ int(d[:-6]) for d in os.listdir(f'{self.dataset_dir}/training_state/') if d[-6:] == ".state" ]) - - remove_models = models[:-keep] - remove_states = states[:-keep] + remove_models = models[:-2] + remove_states = states[:-2] for d in remove_models: path = f'{self.dataset_dir}/models/{d}_gpt.pth' @@ -727,8 +736,10 @@ class TrainingState(): percent = 0 message = None + if line.find('Finished training') >= 0: + self.killed = True # rip out iteration info - if not self.training_started: + elif not self.training_started: if line.find('Start training from epoch') >= 0: self.it_time_start = time.time() self.epoch_time_start = time.time() @@ -745,83 +756,57 @@ class TrainingState(): self.checkpoints = int((self.its - self.it) / self.config['logger']['save_checkpoint_freq']) else: lapsed = False - message = None - if line.find('INFO: [epoch:') >= 0: - info_line = line.split("INFO:")[-1] - # to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point - if ': nan' in info_line and not self.nan_detected: - self.nan_detected = self.it - # easily rip out our stats... - match = re.findall(r'\b([a-z_0-9]+?)\b: *?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', info_line) - if match and len(match) > 0: - for k, v in match: - self.info[k] = float(v.replace(",", "")) - - self.load_statistics(update=True) - should_return = True + # INFO: Training Metrics: {"loss_text_ce": 4.308311939239502, "loss_mel_ce": 2.1610655784606934, "loss_gpt_total": 2.204148769378662, "lr": 0.0001, "it": 2, "step": 1, "steps": 1, "epoch": 1, "iteration_rate": 0.10700102965037028} + if line.find('INFO: Training Metrics:') >= 0: + data = line.split("INFO: Training Metrics:")[-1] + self.info = json.loads(data) if 'epoch' in self.info: self.epoch = int(self.info['epoch']) - if 'iter' in self.info: - self.it = int(self.info['iter']) + if 'it' in self.info: + self.it = int(self.info['it']) + if 'step' in self.info: + self.step = int(self.info['step']) + if 'steps' in self.info: + self.steps = int(self.info['steps']) - elif line.find('Saving models and training states') >= 0: - self.checkpoint = self.checkpoint + 1 + if self.step == self.steps: + lapsed = True - percent = self.checkpoint / float(self.checkpoints) - message = f'[{self.checkpoint}/{self.checkpoints}] Saving checkpoint...' - if progress is not None: - progress(percent, message) + if 'lr' in self.info: + self.statistics['lr'].append({'step': self.it, 'value': self.info['lr'], 'type': 'learning_rate_gpt_0'}) + if 'loss_text_ce' in self.info: + self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_text_ce'], 'type': 'loss_text_ce'}) + if 'loss_mel_ce' in self.info: + self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_mel_ce'], 'type': 'loss_mel_ce'}) + if 'loss_gpt_total' in self.info: + self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_gpt_total'], 'type': 'loss_gpt_total'}) + self.losses.append( self.statistics['loss'][-1] ) + + if 'iteration_rate' in self.info: + it_rate = self.info['iteration_rate'] + self.it_rate = f'{"{:.3f}".format(it_rate)}s/it' if it_rate >= 1 or it_rate == 0 else f'{"{:.3f}".format(1/it_rate)}it/s' - print(f'{"{:.3f}".format(percent*100)}% {message}') - self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}') + self.metrics['step'] = [f"{self.epoch}/{self.epochs}"] + if self.epochs != self.its: + self.metrics['step'].append(f"{self.it}/{self.its}") + if self.steps > 1: + self.metrics['step'].append(f"{self.step}/{self.steps}") + self.metrics['step'] = ", ".join(self.metrics['step']) - self.cleanup_old(keep=keep_x_past_checkpoints) + should_return = True + elif line.find('INFO: Validation Metrics:') >= 0: + data = line.split("INFO: Validation Metrics:")[-1] - if line.find('%|') > 0: - match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line) - if match and len(match) > 0: - match = match[0] - per_cent = int(match[0])/100.0 - progressbar = match[1] - step = int(match[2]) - steps = int(match[3]) - elapsed = match[4] - until = match[5] - rate = match[6] - - last_step = self.last_step - self.last_step = step - if last_step < step: - self.it = self.it + (step - last_step) - - if last_step == step and step == steps: - lapsed = True - - self.it_time_end = time.time() - self.it_time_delta = self.it_time_end-self.it_time_start - self.it_time_start = time.time() - self.it_taken = self.it_taken + 1 - if self.it_time_delta: - try: - rate = f'{"{:.3f}".format(self.it_time_delta)}s/it' if self.it_time_delta >= 1 or self.it_time_delta == 0 else f'{"{:.3f}".format(1/self.it_time_delta)}it/s' - self.it_rate = rate - except Exception as e: - pass - - self.metrics['step'] = [f"{self.epoch}/{self.epochs}"] - if self.epochs != self.its: - self.metrics['step'].append(f"{self.it}/{self.its}") - if steps > 1: - self.metrics['step'].append(f"{step}/{steps}") - self.metrics['step'] = ", ".join(self.metrics['step']) + if 'loss_text_ce' in self.info: + self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_gpt_total'], 'type': 'val_loss_text_ce'}) + if 'loss_mel_ce' in self.info: + self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_gpt_total'], 'type': 'val_loss_mel_ce'}) + should_return = True if lapsed: - self.epoch = self.epoch + 1 - self.it = int(self.epoch * (self.dataset_size / self.batch_size)) - self.epoch_time_end = time.time() self.epoch_time_delta = self.epoch_time_end-self.epoch_time_start self.epoch_time_start = time.time() @@ -850,24 +835,16 @@ class TrainingState(): eta_hhmmss = "?" if self.eta_hhmmss: eta_hhmmss = self.eta_hhmmss - else: - try: - eta = (self.its - self.it) * (self.it_time_deltas / self.it_taken) - eta = str(timedelta(seconds=int(eta))) - eta_hhmmss = eta - except Exception as e: - pass self.metrics['loss'] = [] - if 'learning_rate_gpt_0' in self.info: - self.metrics['loss'].append(f'LR: {"{:.3e}".format(self.info["learning_rate_gpt_0"])}') + if 'lr' in self.info: + self.metrics['loss'].append(f'LR: {"{:.3e}".format(self.info["lr"])}') if len(self.losses) > 0: self.metrics['loss'].append(f'Loss: {"{:.3f}".format(self.losses[-1]["value"])}') if len(self.losses) >= 2: - # """riemann sum""" but not really as this is for derivatives and not integrals deriv = 0 accum_length = len(self.losses)//2 # i *guess* this is fine when you think about it loss_value = self.losses[-1]["value"] @@ -1296,10 +1273,6 @@ def optimize_training_settings( **kwargs ): iterations = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size']) - if settings['epochs'] < settings['print_rate']: - settings['print_rate'] = settings['epochs'] - messages.append(f"Print rate is too small for the given iteration step, clamping print rate to: {settings['print_rate']}") - if settings['epochs'] < settings['save_rate']: settings['save_rate'] = settings['epochs'] messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {settings['save_rate']}") @@ -1355,14 +1328,11 @@ def save_training_settings( **kwargs ): iterations_per_epoch = settings['iterations'] / settings['epochs'] - settings['print_rate'] = int(settings['print_rate'] * iterations_per_epoch) settings['save_rate'] = int(settings['save_rate'] * iterations_per_epoch) settings['validation_rate'] = int(settings['validation_rate'] * iterations_per_epoch) iterations_per_epoch = int(iterations_per_epoch) - if settings['print_rate'] < 1: - settings['print_rate'] = 1 if settings['save_rate'] < 1: settings['save_rate'] = 1 if settings['validation_rate'] < 1: @@ -1858,6 +1828,11 @@ def import_generate_settings(file="./config/generate.json"): res.update(settings) return res +def reset_generation_settings(): + with open(f'./config/generate.json', 'w', encoding="utf-8") as f: + f.write(json.dumps({}, indent='\t') ) + return import_generate_settings() + def read_generate_settings(file, read_latents=True): j = None latents = None diff --git a/src/webui.py b/src/webui.py index 8272362..50128b7 100755 --- a/src/webui.py +++ b/src/webui.py @@ -152,14 +152,11 @@ def import_generate_settings_proxy( file=None ): res = [] for k in GENERATE_SETTINGS_ARGS: res.append(settings[k] if k in settings else None) - + print(GENERATE_SETTINGS_ARGS) + print(settings) + print(res) return tuple(res) -def reset_generation_settings_proxy(): - with open(f'./config/generate.json', 'w', encoding="utf-8") as f: - f.write(json.dumps({}, indent='\t') ) - return import_generate_settings_proxy() - def compute_latents_proxy(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)): compute_latents( voice=voice, voice_latents_chunks=voice_latents_chunks, progress=progress ) return voice @@ -442,7 +439,6 @@ def setup_gradio(): TRAINING_SETTINGS["batch_size"] = gr.Number(label="Batch Size", value=128, precision=0) TRAINING_SETTINGS["gradient_accumulation_size"] = gr.Number(label="Gradient Accumulation Size", value=4, precision=0) with gr.Row(): - TRAINING_SETTINGS["print_rate"] = gr.Number(label="Print Frequency (in epochs)", value=5, precision=0) TRAINING_SETTINGS["save_rate"] = gr.Number(label="Save Frequency (in epochs)", value=5, precision=0) TRAINING_SETTINGS["validation_rate"] = gr.Number(label="Validation Frequency (in epochs)", value=5, precision=0) @@ -665,7 +661,7 @@ def setup_gradio(): ) reset_generation_settings_button.click( - fn=reset_generation_settings_proxy, + fn=reset_generation_settings, inputs=None, outputs=generate_settings )