This commit is contained in:
mrq 2023-03-05 19:58:15 +00:00
parent 8094401a6d
commit b2e89d8da3

View File

@ -817,13 +817,15 @@ class TrainingState():
# """riemann sum""" but not really as this is for derivatives and not integrals # """riemann sum""" but not really as this is for derivatives and not integrals
deriv = 0 deriv = 0
accum_length = len(self.losses)//2 # i *guess* this is fine when you think about it accum_length = len(self.losses)//2 # i *guess* this is fine when you think about it
loss_value = self.losses[-1]["value"]
for i in range(accum_length): for i in range(accum_length):
d1_loss = self.losses[-i-1]["value"] d1_loss = self.losses[accum_length-i-1]["value"]
d2_loss = self.losses[-i-2]["value"] d2_loss = self.losses[accum_length-i-2]["value"]
dloss = (d2_loss - d1_loss) dloss = (d2_loss - d1_loss)
d1_step = self.losses[-i-1]["step"] d1_step = self.losses[accum_length-i-1]["step"]
d2_step = self.losses[-i-2]["step"] d2_step = self.losses[accum_length-i-2]["step"]
dstep = (d2_step - d1_step) dstep = (d2_step - d1_step)
if dstep == 0: if dstep == 0:
@ -837,17 +839,17 @@ class TrainingState():
if deriv != 0: # dloss < 0: if deriv != 0: # dloss < 0:
next_milestone = None next_milestone = None
for milestone in self.loss_milestones: for milestone in self.loss_milestones:
if d1_loss > milestone: if loss_value < milestone:
next_milestone = milestone next_milestone = milestone
break break
if next_milestone: if next_milestone:
# tfw can do simple calculus but not basic algebra in my head # tfw can do simple calculus but not basic algebra in my head
est_its = (next_milestone - d1_loss) / deriv est_its = (next_milestone - loss_value) / deriv
if est_its >= 0: if est_its >= 0:
self.metrics['loss'].append(f'Est. milestone {next_milestone} in: {int(est_its)}its') self.metrics['loss'].append(f'Est. milestone {next_milestone} in: {int(est_its)}its')
else: else:
est_loss = inst_deriv * (self.its - self.it) + d1_loss est_loss = inst_deriv * (self.its - self.it) + loss_value
if est_loss >= 0: if est_loss >= 0:
self.metrics['loss'].append(f'Est. final loss: {"{:.3f}".format(est_loss)}') self.metrics['loss'].append(f'Est. final loss: {"{:.3f}".format(est_loss)}')