added '''estimating''' iterations until milestones (lr=[1, 0.5, 0.1] and final lr, very, very inaccurate because it uses instantaneous delta lr, I'll need to do a riemann sum later

This commit is contained in:
mrq 2023-03-05 06:45:07 +00:00
parent 1316331be3
commit ce3866d0cd

View File

@ -550,8 +550,11 @@ class TrainingState():
self.eta_hhmmss = "?"
self.last_info_check_at = 0
self.statistics = []
self.losses = []
self.loss_milestones = [ 1.0, 0.15, 0.05 ]
self.load_losses()
if keep_x_past_datasets > 0:
self.cleanup_old(keep=keep_x_past_datasets)
@ -578,7 +581,7 @@ class TrainingState():
highest_step = self.last_info_check_at
if not update:
self.losses = []
self.statistics = []
if use_tensorboard:
logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if d[:6] == "events" ])
@ -596,9 +599,12 @@ class TrainingState():
if update and s.step <= self.last_info_check_at:
continue
highest_step = max( highest_step, s.step )
self.losses.append( { "step": s.step, "value": s.value, "type": key } )
self.statistics.append( { "step": s.step, "value": s.value, "type": key } )
if key == 'loss_gpt_total':
self.losses.append( { "step": s.step, "value": s.value, "type": key } )
except Exception as e:
print("Failed to parse event log:", log)
pass
else:
@ -630,7 +636,10 @@ class TrainingState():
if update and int(k) <= self.last_info_check_at:
continue
highest_step = max( highest_step, s.step )
self.losses.append({ "step": int(k), "value": infos[k][key], "type": key })
self.statistics.append({ "step": int(k), "value": infos[k][key], "type": key })
if key == "loss_gpt_total":
self.losses.append({ "step": int(k), "value": infos[k][key], "type": key })
self.last_info_check_at = highest_step
@ -739,20 +748,36 @@ class TrainingState():
metric_loss.append(f'Loss: {"{:3f}".format(self.losses[-1]["value"])}')
if len(self.losses) >= 2:
delta_loss = self.losses[-2]["value"] - self.losses[-1]["value"]
delta_step = self.losses[-2]["step"] - self.losses[-1]["step"]
# i can probably do a """riemann sum""" to get a better derivative, but the instantaneous one works fine
d1_loss = self.losses[-1]["value"]
d2_loss = self.losses[-2]["value"]
dloss = d2_loss - d1_loss
d1_step = self.losses[-1]["step"]
d2_step = self.losses[-2]["step"]
dstep = d2_step - d1_step
inst_deriv = delta_loss / delta_step
est_loss = delta_loss + (self.its - self.it) * inst_deriv
metric_loss.append(f'Est. Final Loss: {"{:3f}".format(est_loss)}')
# don't bother if the loss went up
if dloss < 0:
its_remain = self.its - self.it
inst_deriv = dloss / dstep
print(delta_loss, delta_step, inst_deriv, est_loss)
next_milestone = None
for milestone in self.loss_milestones:
if d1_loss > milestone:
next_milestone = milestone
break
if next_milestone:
# tfw can do simple calculus but not basic algebra in my head
est_its = (next_milestone - d1_loss) * (dstep / dloss)
metric_loss.append(f'Est. milestone {next_milestone} in: {int(est_its)}its')
else:
est_loss = inst_deriv * its_remain + d1_loss
metric_loss.append(f'Est. final loss: {"{:3f}".format(est_loss)}')
metric_loss = ", ".join(metric_loss)
message = f'[{metric_step}] [{metric_rate}] [ETA: {eta_hhmmss}] [{metric_loss}]'
if lapsed:
@ -859,9 +884,9 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, pro
def get_training_losses():
global training_state
if not training_state or not training_state.losses:
if not training_state or not training_state.statistics:
return
return pd.DataFrame(training_state.losses)
return pd.DataFrame(training_state.statistics)
def update_training_dataplot(config_path=None):
global training_state
@ -870,13 +895,13 @@ def update_training_dataplot(config_path=None):
if not training_state:
if config_path:
training_state = TrainingState(config_path=config_path, start=False)
if training_state.losses:
update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
if training_state.statistics:
update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics))
del training_state
training_state = None
elif training_state.losses:
elif training_state.statistics:
training_state.load_losses()
update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics))
return update
@ -943,7 +968,7 @@ def whisper_transcribe( file, language=None ):
if not args.whisper_cpp:
if not language:
language = None
return whisper_model.transcribe(file, language=language)
res = whisper_model.transcribe(file)