diff --git a/modules/dlas b/modules/dlas index 3fdf2a6..b253da6 160000 --- a/modules/dlas +++ b/modules/dlas @@ -1 +1 @@ -Subproject commit 3fdf2a63aaf901f16763fa632269b823915199f4 +Subproject commit b253da6e353f0170c3eb60fe299c41d2fa21db50 diff --git a/src/utils.py b/src/utils.py index a3a928d..3c375a1 100755 --- a/src/utils.py +++ b/src/utils.py @@ -937,15 +937,19 @@ class TrainingState(): should_return = True else: - # INFO: Training Metrics: {"loss_text_ce": 4.308311939239502, "loss_mel_ce": 2.1610655784606934, "loss_gpt_total": 2.204148769378662, "lr": 0.0001, "it": 2, "step": 1, "steps": 1, "epoch": 1, "iteration_rate": 0.10700102965037028} - if line.find('INFO: Training Metrics:') >= 0: + data = None + if line.find('INFO: Saving models and training states.') >= 0: + self.checkpoint += 1 + message = f"[{self.checkpoint}/{self.checkpoints}] Saving checkpoint..." + percent = self.checkpoint / self.checkpoints + + self.cleanup_old(keep=keep_x_past_checkpoints) + elif line.find('INFO: Training Metrics:') >= 0: data = json.loads(line.split("INFO: Training Metrics:")[-1]) data['mode'] = "training" elif line.find('INFO: Validation Metrics:') >= 0: data = json.loads(line.split("INFO: Validation Metrics:")[-1]) data['mode'] = "validation" - else: - data = None if data is not None: if ': nan' in line and not self.nan_detected: