track iteration rate

This commit is contained in:
James Betker 2022-04-04 12:33:25 -06:00
parent 4cdb0169d0
commit 572d137589

View File

@ -186,9 +186,10 @@ class Trainer:
#### training
if self._profile:
    print("Update LR: %f" % (time() - _t))
    _t = time()
self.model.feed_data(train_data, self.current_step)
gradient_norms_dict = self.model.optimize_parameters(self.current_step, return_grad_norms=will_log)
iteration_rate = (time() - _t) / batch_size
if self._profile:
    print("Model feed + step: %f" % (time() - _t))
    _t = time()
@ -202,7 +203,8 @@ class Trainer:
if will_log and self.rank <= 0:
    logs = {'step': self.current_step,
            'samples': self.total_training_data_encountered,
            'megasamples': self.total_training_data_encountered / 1000000,
            'iteration_rate': iteration_rate}
    logs.update(current_model_logs)
    if self.dataset_debugger is not None:
        logs.update(self.dataset_debugger.get_debugging_map())