Move log consensus to train for efficiency

James Betker 2022-03-04 13:41:32 -07:00
parent ce6dfdf255
commit e1052a5e32
3 changed files with 15 additions and 5 deletions


@@ -195,11 +195,14 @@ class Trainer:
         #### log
         if self.dataset_debugger is not None:
             self.dataset_debugger.update(train_data)
+        if will_log:
+            # Must be run by all instances to gather consensus.
+            current_model_logs = self.model.get_current_log(self.current_step)
         if will_log and self.rank <= 0:
             logs = {'step': self.current_step,
                     'samples': self.total_training_data_encountered,
                     'megasamples': self.total_training_data_encountered / 1000000}
-            logs.update(self.model.get_current_log(self.current_step))
+            logs.update(current_model_logs)
             if self.dataset_debugger is not None:
                 logs.update(self.dataset_debugger.get_debugging_map())
             logs.update(gradient_norms_dict)
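
Taken together with the next hunk, the pattern is: every rank calls get_current_log so the collective all_reduce inside it can complete, then the summed tensors are divided by the world size and only rank 0 writes the result out. A minimal runnable sketch of that flow, assuming a two-process gloo group on CPU (the Model stub, the TCP port, and print-as-logging are illustrative, not this repository's code):

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp

    class Model:
        def get_current_log(self, step):
            # Stand-in for the trainer's get_current_log: each rank holds a
            # different local loss. all_reduce is collective, so every rank
            # must make this call or the others would block forever.
            loss = torch.tensor([float(dist.get_rank() + 1)])
            dist.all_reduce(loss, op=dist.ReduceOp.SUM)
            return {'loss': loss / dist.get_world_size()}

    def worker(rank, world_size):
        dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                                rank=rank, world_size=world_size)
        will_log, current_step, model = True, 0, Model()
        if will_log:
            # Must be run by all instances to gather consensus.
            current_model_logs = model.get_current_log(current_step)
        if will_log and rank <= 0:
            # Only rank 0 assembles and emits the log.
            logs = {'step': current_step}
            logs.update(current_model_logs)
            print(logs)  # local losses 1.0 and 2.0 both log as 1.5
        dist.destroy_process_group()

    if __name__ == "__main__":
        mp.spawn(worker, args=(2,), nprocs=2)

Because every rank performs the same sum-and-divide, the consensus value is identical everywhere; the rank <= 0 guard only decides which process writes it out.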


@@ -447,6 +447,17 @@ class ExtensibleTrainer(BaseModel):
         # The batch size optimizer also outputs loggable data.
         log.update(self.batch_size_optimizer.get_statistics())
+
+        # In distributed mode, get agreement on all single tensors.
+        if distributed.is_available() and distributed.is_initialized():
+            for k, v in log.items():
+                if not isinstance(v, torch.Tensor):
+                    continue
+                if len(v.shape) != 1 or v.dtype != torch.float:
+                    continue
+                distributed.all_reduce(v, op=distributed.ReduceOp.SUM)
+                log[k] = v / distributed.get_world_size()
+
         return log
 
     def get_current_visuals(self, need_GT=True):
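
The shape/dtype filter above means only 1-D float tensors take part in the consensus; plain Python numbers, 0-D tensors, and integer tensors keep their local, per-rank values. A small illustrative check of which entries of a mixed log dict would qualify (the keys and values are invented for the example, not from the repository):

    import torch

    log = {
        'loss': torch.tensor([0.42]),      # 1-D float tensor -> reduced
        'step_time': 0.031,                # plain float      -> skipped (not a Tensor)
        'lr': torch.tensor(1e-4),          # 0-D float tensor -> skipped (len(shape) != 1)
        'ids_histogram': torch.arange(4),  # 1-D int tensor   -> skipped (dtype != torch.float)
    }

    for k, v in log.items():
        eligible = (isinstance(v, torch.Tensor)
                    and len(v.shape) == 1
                    and v.dtype == torch.float)
        print(k, 'reduced across ranks' if eligible else 'kept as local value')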


@@ -22,10 +22,6 @@ class LossAccumulator:
         if '_histogram' in name:
             buf[i] = torch.flatten(tensor.detach().cpu())
         elif isinstance(tensor, torch.Tensor):
-            if distributed.is_available() and distributed.is_initialized():
-                # Gather the metric from all devices before storing it locally.
-                distributed.all_reduce(tensor, op=distributed.ReduceOp.SUM)
-                tensor /= distributed.get_world_size()
             buf[i] = tensor.detach().cpu()
         else:
             buf[i] = tensor