forked from mrq/ai-voice-cloning
oops
parent 8094401a6d
commit b2e89d8da3

src/utils.py: 16 changed lines
@@ -817,13 +817,15 @@ class TrainingState():
 	# """riemann sum""" but not really as this is for derivatives and not integrals
 	deriv = 0
 	accum_length = len(self.losses)//2 # i *guess* this is fine when you think about it
+	loss_value = self.losses[-1]["value"]
+
 	for i in range(accum_length):
-		d1_loss = self.losses[-i-1]["value"]
-		d2_loss = self.losses[-i-2]["value"]
+		d1_loss = self.losses[accum_length-i-1]["value"]
+		d2_loss = self.losses[accum_length-i-2]["value"]
 		dloss = (d2_loss - d1_loss)
 
-		d1_step = self.losses[-i-1]["step"]
-		d2_step = self.losses[-i-2]["step"]
+		d1_step = self.losses[accum_length-i-1]["step"]
+		d2_step = self.losses[accum_length-i-2]["step"]
 		dstep = (d2_step - d1_step)
 
 		if dstep == 0:
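
The loop in the hunk above accumulates pairwise loss-versus-step slopes over roughly half of the recorded history, and the commit caches the newest loss in loss_value for the hunk below. A minimal standalone sketch of that kind of slope averaging, assuming a hypothetical average_loss_slope helper and a self.losses-style list of dicts with "step" and "value" keys (conventional newer-minus-older direction rather than the commit's exact indexing):

def average_loss_slope(history, window=None):
    """Average per-step change in loss over the newest `window` pairs.

    `history` is a list of {"step": int, "value": float} dicts, oldest
    first. The result is negative while the loss is still falling.
    """
    if window is None:
        window = len(history) // 2  # same default as accum_length above
    slope_sum = 0.0
    pairs = 0
    for i in range(window):
        newer = history[-i - 1]
        older = history[-i - 2]
        dstep = newer["step"] - older["step"]
        if dstep == 0:
            continue  # duplicate step records give no usable slope
        slope_sum += (newer["value"] - older["value"]) / dstep
        pairs += 1
    return slope_sum / pairs if pairs else 0.0

# Example: a loss curve that drops by 0.05 every step
history = [{"step": s, "value": 2.0 - 0.05 * s} for s in range(20)]
print(average_loss_slope(history))  # about -0.05
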
@@ -837,17 +839,17 @@ class TrainingState():
 	if deriv != 0: # dloss < 0:
 		next_milestone = None
 		for milestone in self.loss_milestones:
-			if d1_loss > milestone:
+			if loss_value < milestone:
 				next_milestone = milestone
 				break
 
 		if next_milestone:
 			# tfw can do simple calculus but not basic algebra in my head
-			est_its = (next_milestone - d1_loss) / deriv
+			est_its = (next_milestone - loss_value) / deriv
 			if est_its >= 0:
 				self.metrics['loss'].append(f'Est. milestone {next_milestone} in: {int(est_its)}its')
 		else:
-			est_loss = inst_deriv * (self.its - self.it) + d1_loss
+			est_loss = inst_deriv * (self.its - self.it) + loss_value
 			if est_loss >= 0:
 				self.metrics['loss'].append(f'Est. final loss: {"{:.3f}".format(est_loss)}')
 
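
The hunk above now reads both estimates off the newest loss_value rather than the d1_loss left over from the loop: the gap to the next milestone divided by the average slope gives an iteration count, and a straight-line extrapolation over the remaining iterations gives a projected final loss. A loose sketch of that arithmetic, using stand-in names (its_to_target, projected_final_loss) rather than the repository's own code:

def its_to_target(loss_value, target, slope):
    # Iterations until the loss reaches `target` if the current average
    # slope holds; only meaningful when the result comes out >= 0.
    return (target - loss_value) / slope if slope else float("inf")

def projected_final_loss(loss_value, slope, its_remaining):
    # Linear extrapolation over the remaining iterations, mirroring
    # inst_deriv * (self.its - self.it) + loss_value above.
    return slope * its_remaining + loss_value

# Example: loss at 0.42, falling roughly 0.002 per iteration, 150 iterations left
print(its_to_target(0.42, 0.25, -0.002))                  # about 85 iterations to reach 0.25
print(f"{projected_final_loss(0.42, -0.002, 150):.3f}")   # 0.120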