track iteration rate

This commit is contained in:
James Betker 2022-04-04 12:33:25 -06:00
parent 4cdb0169d0
commit 572d137589

View File

@ -186,9 +186,10 @@ class Trainer:
#### training
if self._profile:
print("Update LR: %f" % (time() - _t))
_t = time()
_t = time()
self.model.feed_data(train_data, self.current_step)
gradient_norms_dict = self.model.optimize_parameters(self.current_step, return_grad_norms=will_log)
iteration_rate = (time() - _t) / batch_size
if self._profile:
print("Model feed + step: %f" % (time() - _t))
_t = time()
@ -202,7 +203,8 @@ class Trainer:
if will_log and self.rank <= 0:
logs = {'step': self.current_step,
'samples': self.total_training_data_encountered,
'megasamples': self.total_training_data_encountered / 1000000}
'megasamples': self.total_training_data_encountered / 1000000,
'iteration_rate': iteration_rate}
logs.update(current_model_logs)
if self.dataset_debugger is not None:
logs.update(self.dataset_debugger.get_debugging_map())