Added '''estimating''' iterations until milestones (lr=[1, 0.5, 0.1] and final lr). Very, very inaccurate because it uses the instantaneous delta lr; I'll need to do a Riemann sum later.
This commit is contained in:
parent
1316331be3
commit
ce3866d0cd
57
src/utils.py
57
src/utils.py
|
@ -550,8 +550,11 @@ class TrainingState():
|
||||||
self.eta_hhmmss = "?"
|
self.eta_hhmmss = "?"
|
||||||
|
|
||||||
self.last_info_check_at = 0
|
self.last_info_check_at = 0
|
||||||
|
self.statistics = []
|
||||||
self.losses = []
|
self.losses = []
|
||||||
|
|
||||||
|
self.loss_milestones = [ 1.0, 0.15, 0.05 ]
|
||||||
|
|
||||||
self.load_losses()
|
self.load_losses()
|
||||||
if keep_x_past_datasets > 0:
|
if keep_x_past_datasets > 0:
|
||||||
self.cleanup_old(keep=keep_x_past_datasets)
|
self.cleanup_old(keep=keep_x_past_datasets)
|
||||||
|
@ -578,7 +581,7 @@ class TrainingState():
|
||||||
highest_step = self.last_info_check_at
|
highest_step = self.last_info_check_at
|
||||||
|
|
||||||
if not update:
|
if not update:
|
||||||
self.losses = []
|
self.statistics = []
|
||||||
|
|
||||||
if use_tensorboard:
|
if use_tensorboard:
|
||||||
logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if d[:6] == "events" ])
|
logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if d[:6] == "events" ])
|
||||||
|
@ -596,9 +599,12 @@ class TrainingState():
|
||||||
if update and s.step <= self.last_info_check_at:
|
if update and s.step <= self.last_info_check_at:
|
||||||
continue
|
continue
|
||||||
highest_step = max( highest_step, s.step )
|
highest_step = max( highest_step, s.step )
|
||||||
|
self.statistics.append( { "step": s.step, "value": s.value, "type": key } )
|
||||||
|
|
||||||
|
if key == 'loss_gpt_total':
|
||||||
self.losses.append( { "step": s.step, "value": s.value, "type": key } )
|
self.losses.append( { "step": s.step, "value": s.value, "type": key } )
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Failed to parse event log:", log)
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
@ -630,6 +636,9 @@ class TrainingState():
|
||||||
if update and int(k) <= self.last_info_check_at:
|
if update and int(k) <= self.last_info_check_at:
|
||||||
continue
|
continue
|
||||||
highest_step = max( highest_step, s.step )
|
highest_step = max( highest_step, s.step )
|
||||||
|
self.statistics.append({ "step": int(k), "value": infos[k][key], "type": key })
|
||||||
|
|
||||||
|
if key == "loss_gpt_total":
|
||||||
self.losses.append({ "step": int(k), "value": infos[k][key], "type": key })
|
self.losses.append({ "step": int(k), "value": infos[k][key], "type": key })
|
||||||
|
|
||||||
self.last_info_check_at = highest_step
|
self.last_info_check_at = highest_step
|
||||||
|
@ -739,20 +748,36 @@ class TrainingState():
|
||||||
metric_loss.append(f'Loss: {"{:3f}".format(self.losses[-1]["value"])}')
|
metric_loss.append(f'Loss: {"{:3f}".format(self.losses[-1]["value"])}')
|
||||||
|
|
||||||
if len(self.losses) >= 2:
|
if len(self.losses) >= 2:
|
||||||
delta_loss = self.losses[-2]["value"] - self.losses[-1]["value"]
|
# i can probably do a """riemann sum""" to get a better derivative, but the instantaneous one works fine
|
||||||
delta_step = self.losses[-2]["step"] - self.losses[-1]["step"]
|
d1_loss = self.losses[-1]["value"]
|
||||||
|
d2_loss = self.losses[-2]["value"]
|
||||||
|
dloss = d2_loss - d1_loss
|
||||||
|
|
||||||
inst_deriv = delta_loss / delta_step
|
d1_step = self.losses[-1]["step"]
|
||||||
est_loss = delta_loss + (self.its - self.it) * inst_deriv
|
d2_step = self.losses[-2]["step"]
|
||||||
metric_loss.append(f'Est. Final Loss: {"{:3f}".format(est_loss)}')
|
dstep = d2_step - d1_step
|
||||||
|
|
||||||
print(delta_loss, delta_step, inst_deriv, est_loss)
|
# don't bother if the loss went up
|
||||||
|
if dloss < 0:
|
||||||
|
its_remain = self.its - self.it
|
||||||
|
inst_deriv = dloss / dstep
|
||||||
|
|
||||||
|
next_milestone = None
|
||||||
|
for milestone in self.loss_milestones:
|
||||||
|
if d1_loss > milestone:
|
||||||
|
next_milestone = milestone
|
||||||
|
break
|
||||||
|
|
||||||
|
if next_milestone:
|
||||||
|
# tfw can do simple calculus but not basic algebra in my head
|
||||||
|
est_its = (next_milestone - d1_loss) * (dstep / dloss)
|
||||||
|
metric_loss.append(f'Est. milestone {next_milestone} in: {int(est_its)}its')
|
||||||
|
else:
|
||||||
|
est_loss = inst_deriv * its_remain + d1_loss
|
||||||
|
metric_loss.append(f'Est. final loss: {"{:3f}".format(est_loss)}')
|
||||||
|
|
||||||
metric_loss = ", ".join(metric_loss)
|
metric_loss = ", ".join(metric_loss)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
message = f'[{metric_step}] [{metric_rate}] [ETA: {eta_hhmmss}] [{metric_loss}]'
|
message = f'[{metric_step}] [{metric_rate}] [ETA: {eta_hhmmss}] [{metric_loss}]'
|
||||||
|
|
||||||
if lapsed:
|
if lapsed:
|
||||||
|
@ -859,9 +884,9 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, pro
|
||||||
|
|
||||||
def get_training_losses():
|
def get_training_losses():
|
||||||
global training_state
|
global training_state
|
||||||
if not training_state or not training_state.losses:
|
if not training_state or not training_state.statistics:
|
||||||
return
|
return
|
||||||
return pd.DataFrame(training_state.losses)
|
return pd.DataFrame(training_state.statistics)
|
||||||
|
|
||||||
def update_training_dataplot(config_path=None):
|
def update_training_dataplot(config_path=None):
|
||||||
global training_state
|
global training_state
|
||||||
|
@ -870,13 +895,13 @@ def update_training_dataplot(config_path=None):
|
||||||
if not training_state:
|
if not training_state:
|
||||||
if config_path:
|
if config_path:
|
||||||
training_state = TrainingState(config_path=config_path, start=False)
|
training_state = TrainingState(config_path=config_path, start=False)
|
||||||
if training_state.losses:
|
if training_state.statistics:
|
||||||
update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
|
update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics))
|
||||||
del training_state
|
del training_state
|
||||||
training_state = None
|
training_state = None
|
||||||
elif training_state.losses:
|
elif training_state.statistics:
|
||||||
training_state.load_losses()
|
training_state.load_losses()
|
||||||
update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
|
update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics))
|
||||||
|
|
||||||
return update
|
return update
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user