Disable loss ETA for now until I fix it
This commit is contained in:
parent
51f6c347fe
commit
9594a960b0
28
src/utils.py
28
src/utils.py
|
@ -696,7 +696,7 @@ class TrainingState():
|
||||||
|
|
||||||
epoch = self.epoch + (self.step / self.steps)
|
epoch = self.epoch + (self.step / self.steps)
|
||||||
if 'lr' in self.info:
|
if 'lr' in self.info:
|
||||||
self.statistics['lr'].append({'epoch': epoch, 'value': self.info['lr'], 'type': 'learning_rate'})
|
self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info['lr'], 'type': 'learning_rate'})
|
||||||
|
|
||||||
for k in ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']:
|
for k in ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']:
|
||||||
if k not in self.info:
|
if k not in self.info:
|
||||||
|
@ -705,7 +705,7 @@ class TrainingState():
|
||||||
if k == "loss_gpt_total":
|
if k == "loss_gpt_total":
|
||||||
self.losses.append( self.statistics['loss'][-1] )
|
self.losses.append( self.statistics['loss'][-1] )
|
||||||
else:
|
else:
|
||||||
self.statistics['loss'].append({'epoch': epoch, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
|
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@ -728,7 +728,7 @@ class TrainingState():
|
||||||
if len(self.losses) > 0:
|
if len(self.losses) > 0:
|
||||||
self.metrics['loss'].append(f'Loss: {"{:.3f}".format(self.losses[-1]["value"])}')
|
self.metrics['loss'].append(f'Loss: {"{:.3f}".format(self.losses[-1]["value"])}')
|
||||||
|
|
||||||
if len(self.losses) >= 2:
|
if False and len(self.losses) >= 2:
|
||||||
deriv = 0
|
deriv = 0
|
||||||
accum_length = len(self.losses)//2 # i *guess* this is fine when you think about it
|
accum_length = len(self.losses)//2 # i *guess* this is fine when you think about it
|
||||||
loss_value = self.losses[-1]["value"]
|
loss_value = self.losses[-1]["value"]
|
||||||
|
@ -738,8 +738,8 @@ class TrainingState():
|
||||||
d2_loss = self.losses[accum_length-i-2]["value"]
|
d2_loss = self.losses[accum_length-i-2]["value"]
|
||||||
dloss = (d2_loss - d1_loss)
|
dloss = (d2_loss - d1_loss)
|
||||||
|
|
||||||
d1_step = self.losses[accum_length-i-1]["epoch"]
|
d1_step = self.losses[accum_length-i-1]["it"]
|
||||||
d2_step = self.losses[accum_length-i-2]["epoch"]
|
d2_step = self.losses[accum_length-i-2]["it"]
|
||||||
dstep = (d2_step - d1_step)
|
dstep = (d2_step - d1_step)
|
||||||
|
|
||||||
if dstep == 0:
|
if dstep == 0:
|
||||||
|
@ -750,16 +750,21 @@ class TrainingState():
|
||||||
|
|
||||||
deriv = deriv / accum_length
|
deriv = deriv / accum_length
|
||||||
|
|
||||||
|
print("Deriv: ", deriv)
|
||||||
|
|
||||||
if deriv != 0: # dloss < 0:
|
if deriv != 0: # dloss < 0:
|
||||||
next_milestone = None
|
next_milestone = None
|
||||||
for milestone in self.loss_milestones:
|
for milestone in self.loss_milestones:
|
||||||
if loss_value > milestone:
|
if loss_value > milestone:
|
||||||
next_milestone = milestone
|
next_milestone = milestone
|
||||||
break
|
break
|
||||||
|
|
||||||
|
print(f"Loss value: {loss_value} | Next milestone: {next_milestone} | Distance: {loss_value - next_milestone}")
|
||||||
|
|
||||||
if next_milestone:
|
if next_milestone:
|
||||||
# tfw can do simple calculus but not basic algebra in my head
|
# tfw can do simple calculus but not basic algebra in my head
|
||||||
est_its = (next_milestone - loss_value) / deriv
|
est_its = (next_milestone - loss_value) / deriv * 100
|
||||||
|
print(f"Estimated: {est_its}")
|
||||||
if est_its >= 0:
|
if est_its >= 0:
|
||||||
self.metrics['loss'].append(f'Est. milestone {next_milestone} in: {int(est_its)}its')
|
self.metrics['loss'].append(f'Est. milestone {next_milestone} in: {int(est_its)}its')
|
||||||
else:
|
else:
|
||||||
|
@ -769,7 +774,7 @@ class TrainingState():
|
||||||
|
|
||||||
self.metrics['loss'] = ", ".join(self.metrics['loss'])
|
self.metrics['loss'] = ", ".join(self.metrics['loss'])
|
||||||
|
|
||||||
message = f"[{self.metrics['step']}] [{self.metrics['rate']}] [ETA: {eta_hhmmss}]\n[{self.metrics['loss']}]"
|
message = f"[{self.metrics['step']}] [{self.metrics['rate']}] [ETA: {eta_hhmmss}] [{self.metrics['loss']}]"
|
||||||
if self.nan_detected:
|
if self.nan_detected:
|
||||||
message = f"[!NaN DETECTED! {self.nan_detected}] {message}"
|
message = f"[!NaN DETECTED! {self.nan_detected}] {message}"
|
||||||
|
|
||||||
|
@ -814,6 +819,7 @@ class TrainingState():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
self.parse_metrics(data)
|
self.parse_metrics(data)
|
||||||
|
print(self.get_status())
|
||||||
# print(f"Iterations Left: {self.its - self.it} | Elapsed Time: {self.it_rates} | Time Remaining: {self.eta} | Message: {self.get_status()}")
|
# print(f"Iterations Left: {self.its - self.it} | Elapsed Time: {self.it_rates} | Time Remaining: {self.eta} | Message: {self.get_status()}")
|
||||||
|
|
||||||
self.last_info_check_at = highest_step
|
self.last_info_check_at = highest_step
|
||||||
|
@ -959,17 +965,17 @@ def update_training_dataplot(config_path=None):
|
||||||
print(message)
|
print(message)
|
||||||
|
|
||||||
if len(training_state.statistics['loss']) > 0:
|
if len(training_state.statistics['loss']) > 0:
|
||||||
losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Loss Metrics", color="type", tooltip=['epoch', 'value', 'type'], width=500, height=350,)
|
losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
|
||||||
if len(training_state.statistics['lr']) > 0:
|
if len(training_state.statistics['lr']) > 0:
|
||||||
lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Learning Rate", color="type", tooltip=['epoch', 'value', 'type'], width=500, height=350,)
|
lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
|
||||||
del training_state
|
del training_state
|
||||||
training_state = None
|
training_state = None
|
||||||
else:
|
else:
|
||||||
training_state.load_statistics()
|
training_state.load_statistics()
|
||||||
if len(training_state.statistics['loss']) > 0:
|
if len(training_state.statistics['loss']) > 0:
|
||||||
losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Loss Metrics", color="type", tooltip=['epoch', 'value', 'type'], width=500, height=350,)
|
losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
|
||||||
if len(training_state.statistics['lr']) > 0:
|
if len(training_state.statistics['lr']) > 0:
|
||||||
lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Learning Rate", color="type", tooltip=['epoch', 'value', 'type'], width=500, height=350,)
|
lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'], width=500, height=350,)
|
||||||
|
|
||||||
return (losses, lrs)
|
return (losses, lrs)
|
||||||
|
|
||||||
|
|
|
@ -510,7 +510,7 @@ def setup_gradio():
|
||||||
y="value",
|
y="value",
|
||||||
title="Loss Metrics",
|
title="Loss Metrics",
|
||||||
color="type",
|
color="type",
|
||||||
tooltip=['epoch', 'value', 'type'],
|
tooltip=['epoch', 'it', 'value', 'type'],
|
||||||
width=500,
|
width=500,
|
||||||
height=350,
|
height=350,
|
||||||
)
|
)
|
||||||
|
@ -519,7 +519,7 @@ def setup_gradio():
|
||||||
y="value",
|
y="value",
|
||||||
title="Learning Rate",
|
title="Learning Rate",
|
||||||
color="type",
|
color="type",
|
||||||
tooltip=['epoch', 'value', 'type'],
|
tooltip=['epoch', 'it', 'value', 'type'],
|
||||||
width=500,
|
width=500,
|
||||||
height=350,
|
height=350,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user