csome adjustments to the training output parser, now updates per iteration for really large batches (like the one I'm doing for a dataset size of 19420)

This commit is contained in:
mrq 2023-02-25 13:55:25 +00:00
parent d5d8821a9d
commit 8b4da29d5f

View File

@ -470,14 +470,21 @@ class TrainingState():
self.epoch_rate = "" self.epoch_rate = ""
self.epoch_time_start = 0 self.epoch_time_start = 0
self.epoch_time_end = 0 self.epoch_time_end = 0
self.it_rate = ""
self.it_time_start = 0
self.it_time_end = 0
self.last_step = 0
self.eta = "?" self.eta = "?"
self.eta_hhmmss = "?" self.eta_hhmmss = "?"
print("Spawning process: ", " ".join(self.cmd)) print("Spawning process: ", " ".join(self.cmd))
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True) self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
def parse(self, line, verbose=False, buffer_size=8, progress=None): def parse(self, line, verbose=False, buffer_size=8, progress=None, owner=True):
self.buffer.append(f'{line}') if owner:
self.buffer.append(f'{line}')
# rip out iteration info # rip out iteration info
if not self.training_started: if not self.training_started:
@ -492,47 +499,97 @@ class TrainingState():
if match and len(match) > 0: if match and len(match) > 0:
self.it = int(match[0].replace(",", "")) self.it = int(match[0].replace(",", ""))
else: else:
if line.find('%|') > 0 and not self.open_state: lapsed = line.find('100%|') == 0 and self.open_state
self.open_state = True
elif line.find('100%|') == 0 and self.open_state:
self.open_state = False
self.epoch = self.epoch + 1
self.epoch_time_end = time.time() if line.find('%|') > 0:
self.epoch_time_delta = self.epoch_time_end-self.epoch_time_start match = re.findall(r' +?(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line)
self.epoch_time_start = time.time() if match and len(match) > 0:
self.epoch_rate = f'[{"{:.3f}".format(self.epoch_time_delta)}s/epoch]' if self.epoch_time_delta >= 1 else f'[{"{:.3f}".format(1/self.epoch_time_delta)}epoch/s]' # I doubt anyone will have it/s rates, but its here match = match[0]
self.eta = (self.epochs - self.epoch) * self.epoch_time_delta percent = int(match[0])/100.0
self.eta_hhmmss = str(timedelta(seconds=int(self.eta))) progressbar = match[1]
step = int(match[2])
steps = int(match[3])
elapsed = match[4]
until = match[5]
rate = match[6]
epoch_percent = self.epoch / float(self.epochs)
if owner:
last_step = self.last_step
self.last_step = step
if last_step < step:
self.it = self.it + (step - last_step)
if last_step > step and step == 0:
lapsed = True
self.it_time_end = time.time()
self.it_time_delta = self.it_time_end-self.it_time_start
self.it_time_start = time.time()
self.it_rate = f'[{"{:.3f}".format(self.it_time_delta)}s/it]' if self.it_time_delta >= 1 else f'[{"{:.3f}".format(1/self.it_time_delta)}it/s]'
self.eta = (self.its - self.it) * self.it_time_delta
self.eta_hhmmss = str(timedelta(seconds=int(self.eta)))
message = f'[{self.epoch}/{self.epochs}] [{self.it}/{self.its}] [ETA: {self.eta_hhmmss}] {self.epoch_rate} / {self.it_rate} {self.status}'
if progress is not None:
progress(epoch_percent, message)
if owner:
# print(f'{"{:.3f}".format(percent*100)}% {message}')
self.buffer.append(f'[{"{:.3f}".format(epoch_percent*100)}% / {"{:.3f}".format(percent*100)}%] {message}')
if line.find('%|') > 0 and not self.open_state:
if owner:
self.open_state = True
elif lapsed:
if owner:
self.open_state = False
self.epoch = self.epoch + 1
self.it = int(self.epoch * (self.dataset_size / self.batch_size))
self.epoch_time_end = time.time()
self.epoch_time_delta = self.epoch_time_end-self.epoch_time_start
self.epoch_time_start = time.time()
self.epoch_rate = f'[{"{:.3f}".format(self.epoch_time_delta)}s/epoch]' if self.epoch_time_delta >= 1 else f'[{"{:.3f}".format(1/self.epoch_time_delta)}epoch/s]' # I doubt anyone will have it/s rates, but its here
self.eta = (self.epochs - self.epoch) * self.epoch_time_delta
self.eta_hhmmss = str(timedelta(seconds=int(self.eta)))
percent = self.epoch / float(self.epochs) percent = self.epoch / float(self.epochs)
message = f'[{self.epoch}/{self.epochs}] [ETA: {self.eta_hhmmss}] {self.epoch_rate} {self.status}' message = f'[{self.epoch}/{self.epochs}] [{self.it}/{self.its}] [ETA: {self.eta_hhmmss}] {self.epoch_rate} / {self.it_rate} {self.status}'
print(f'{"{:.3f}".format(percent*100)}% {message}')
if progress is not None: if progress is not None:
progress(percent, message) progress(percent, message)
self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
if owner:
print(f'{"{:.3f}".format(percent*100)}% {message}')
self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
if line.find('INFO: [epoch:') >= 0: if line.find('INFO: [epoch:') >= 0:
# easily rip out our stats... if owner:
match = re.findall(r'\b([a-z_0-9]+?)\b: ([0-9]\.[0-9]+?e[+-]\d+)\b', line) # easily rip out our stats...
if match and len(match) > 0: match = re.findall(r'\b([a-z_0-9]+?)\b: ([0-9]\.[0-9]+?e[+-]\d+)\b', line)
for k, v in match: if match and len(match) > 0:
self.info[k] = float(v) for k, v in match:
self.info[k] = float(v)
if 'loss_gpt_total' in self.info:
self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}" if 'loss_gpt_total' in self.info:
print(self.status) self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}"
self.buffer.append(self.status) print(self.status)
self.buffer.append(self.status)
elif line.find('Saving models and training states') >= 0: elif line.find('Saving models and training states') >= 0:
self.checkpoint = self.checkpoint + 1 if owner:
self.checkpoint = self.checkpoint + 1
percent = self.checkpoint / float(self.checkpoints) percent = self.checkpoint / float(self.checkpoints)
message = f'[{self.checkpoint}/{self.checkpoints}] Saving checkpoint...' message = f'[{self.checkpoint}/{self.checkpoints}] Saving checkpoint...'
print(f'{"{:.3f}".format(percent*100)}% {message}')
if progress is not None: if progress is not None:
progress(percent, message) progress(percent, message)
self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}') if owner:
print(f'{"{:.3f}".format(percent*100)}% {message}')
self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
self.buffer = self.buffer[-buffer_size:] if owner:
self.buffer = self.buffer[-buffer_size:]
if verbose or not self.training_started: if verbose or not self.training_started:
return "".join(self.buffer) return "".join(self.buffer)
@ -552,7 +609,7 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress
for line in iter(training_state.process.stdout.readline, ""): for line in iter(training_state.process.stdout.readline, ""):
res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress ) res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress, owner=True )
print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}") print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
if res: if res:
yield res yield res
@ -565,13 +622,13 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress
#if return_code: #if return_code:
# raise subprocess.CalledProcessError(return_code, cmd) # raise subprocess.CalledProcessError(return_code, cmd)
def reconnect_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)): def reconnect_training(verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)):
global training_state global training_state
if not training_state or not training_state.process: if not training_state or not training_state.process:
return "Training not in progress" return "Training not in progress"
for line in iter(training_state.process.stdout.readline, ""): for line in iter(training_state.process.stdout.readline, ""):
res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress ) res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress, owner=True )
if res: if res:
yield res yield res