From b6f7aa62647c42a7df0aa2e242531ab862cfa9a7 Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 21 Feb 2023 04:22:11 +0000 Subject: [PATCH] fixes --- src/utils.py | 39 +++++++++++++++++---------------------- 1 file changed, 17 insertions(+), 22 deletions(-) diff --git a/src/utils.py b/src/utils.py index bfaae25..2d0de77 100755 --- a/src/utils.py +++ b/src/utils.py @@ -415,7 +415,7 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress its = config['train']['niter'] checkpoint = 0 - checkpoints = config['logger']['save_checkpoint_freq'] / its + checkpoints = its / config['logger']['save_checkpoint_freq'] buffer_size = 8 open_state = False @@ -443,40 +443,35 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress elif progress is not None: if line.find(' 0%|') == 0: open_state = True - it_time_start = time.time() elif line.find('100%|') == 0 and open_state: - it_time_end = time.time() open_state = False it = it + 1 + it_time_end = time.time() it_time_delta = it_time_end-it_time_start - it_rate = f'[{"{:.3f}".format(it_time_delta)}s/it]' if it_time_delta >= 1 and it_time_delta != 0 else f'[{"{:.3f}".format(1/it_time_delta)}it/s]' # I doubt anyone will have it/s rates, but its here + it_time_start = time.time() + it_rate = f'[{"{:.3f}".format(it_time_delta)}s/it]' if it_time_delta >= 1 else f'[{"{:.3f}".format(1/it_time_delta)}it/s]' # I doubt anyone will have it/s rates, but its here progress(it / float(its), f'[{it}/{its}] {it_rate} Training... {status}') - - # try because I haven't tested this yet - try: - if line.find('INFO: [epoch:') >= 0: - # easily rip out our stats... - match = re.findall(r'\b([a-z_0-9]+?)\b: ([0-9]\.[0-9]+?e[+-]\d+)\b', line) - if match and len(match) > 0: - for k, v in match: - info[k] = float(v) - - # ...and returns our loss rate - # it would be nice for losses to be shown at every step - if 'loss_gpt_total' in info: - status = f"Total loss at step {int(info['step'])}: {info['loss_gpt_total']}" - except Exception as e: - pass - if line.find('Saving models and training states') >= 0: + if line.find('INFO: [epoch:') >= 0: + # easily rip out our stats... + match = re.findall(r'\b([a-z_0-9]+?)\b: ([0-9]\.[0-9]+?e[+-]\d+)\b', line) + if match and len(match) > 0: + for k, v in match: + info[k] = float(v) + + # ...and returns our loss rate + # it would be nice for losses to be shown at every step + if 'loss_gpt_total' in info: + status = f"Total loss at step {int(info['step'])}: {info['loss_gpt_total']}" + elif line.find('Saving models and training states') >= 0: checkpoint = checkpoint + 1 progress(checkpoint / float(checkpoints), f'[{checkpoint}/{checkpoints}] Saving checkpoint...') print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}") - if verbose: + if verbose or not training_started: yield "".join(buffer[-buffer_size:]) training_process.stdout.close()