oops
This commit is contained in:
parent
b989123bd4
commit
5a41db978e
12
src/utils.py
12
src/utils.py
|
@ -546,9 +546,9 @@ class TrainingState():
|
|||
|
||||
for k in infos:
|
||||
if 'loss_gpt_total' in infos[k]:
|
||||
self.losses.append({ "step": int(k), "value": infos[k]['loss_text_ce'], "type": "text_ce" })
|
||||
self.losses.append({ "step": int(k), "value": infos[k]['loss_mel_ce'], "type": "mel_ce" })
|
||||
self.losses.append({ "step": int(k), "value": infos[k]['loss_gpt_total'], "type": "gpt_total" })
|
||||
self.losses.append({ "step": int(k), "value": infos[k]['loss_text_ce'], "type": "loss_text_ce" })
|
||||
self.losses.append({ "step": int(k), "value": infos[k]['loss_mel_ce'], "type": "loss_mel_ce" })
|
||||
self.losses.append({ "step": int(k), "value": infos[k]['loss_gpt_total'], "type": "loss_gpt_total" })
|
||||
|
||||
def cleanup_old(self, keep=2):
|
||||
if keep <= 0:
|
||||
|
@ -663,9 +663,9 @@ class TrainingState():
|
|||
if 'loss_gpt_total' in self.info:
|
||||
self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}"
|
||||
|
||||
self.losses.append({ "step": self.it, "value": self.info['loss_text_ce'], "type": "text_ce" })
|
||||
self.losses.append({ "step": self.it, "value": self.info['loss_mel_ce'], "type": "mel_ce" })
|
||||
self.losses.append({ "step": self.it, "value": self.info['loss_gpt_total'], "type": "gpt_total" })
|
||||
self.losses.append({ "step": self.it, "value": self.info['loss_text_ce'], "type": "loss_text_ce" })
|
||||
self.losses.append({ "step": self.it, "value": self.info['loss_mel_ce'], "type": "loss_mel_ce" })
|
||||
self.losses.append({ "step": self.it, "value": self.info['loss_gpt_total'], "type": "loss_gpt_total" })
|
||||
|
||||
should_return = True
|
||||
elif line.find('Saving models and training states') >= 0:
|
||||
|
|
Loading…
Reference in New Issue
Block a user