From e1052a5e3276559227a94fe7fdf5a02472ebf9bf Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 4 Mar 2022 13:41:32 -0700
Subject: [PATCH] Move log consensus to train for efficiency

---
 codes/train.py                     |  5 ++++-
 codes/trainer/ExtensibleTrainer.py | 11 +++++++++++
 codes/utils/loss_accumulator.py    |  4 ----
 3 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/codes/train.py b/codes/train.py
index 08129348..17de6580 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -195,11 +195,14 @@ class Trainer:
         #### log
         if self.dataset_debugger is not None:
             self.dataset_debugger.update(train_data)
+        if will_log:
+            # Must be run by all instances to gather consensus.
+            current_model_logs = self.model.get_current_log(self.current_step)
         if will_log and self.rank <= 0:
             logs = {'step': self.current_step,
                     'samples': self.total_training_data_encountered,
                     'megasamples': self.total_training_data_encountered / 1000000}
-            logs.update(self.model.get_current_log(self.current_step))
+            logs.update(current_model_logs)
             if self.dataset_debugger is not None:
                 logs.update(self.dataset_debugger.get_debugging_map())
             logs.update(gradient_norms_dict)
diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index 7e861a06..1b940f94 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -447,6 +447,17 @@ class ExtensibleTrainer(BaseModel):
 
         # The batch size optimizer also outputs loggable data.
         log.update(self.batch_size_optimizer.get_statistics())
+
+        # In distributed mode, get agreement on all single tensors.
+        if distributed.is_available() and distributed.is_initialized():
+            for k, v in log.items():
+                if not isinstance(v, torch.Tensor):
+                    continue
+                if len(v.shape) != 1 or v.dtype != torch.float:
+                    continue
+                distributed.all_reduce(v, op=distributed.ReduceOp.SUM)
+                log[k] = v / distributed.get_world_size()
+
         return log
 
     def get_current_visuals(self, need_GT=True):
diff --git a/codes/utils/loss_accumulator.py b/codes/utils/loss_accumulator.py
index b99bcfc7..08cc6ccc 100644
--- a/codes/utils/loss_accumulator.py
+++ b/codes/utils/loss_accumulator.py
@@ -22,10 +22,6 @@ class LossAccumulator:
             if '_histogram' in name:
                 buf[i] = torch.flatten(tensor.detach().cpu())
             elif isinstance(tensor, torch.Tensor):
-                if distributed.is_available() and distributed.is_initialized():
-                    # Gather the metric from all devices before storing it locally.
-                    distributed.all_reduce(tensor, op=distributed.ReduceOp.SUM)
-                    tensor /= distributed.get_world_size()
                 buf[i] = tensor.detach().cpu()
             else:
                 buf[i] = tensor
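
Note on the consensus logic moved into ExtensibleTrainer.get_current_log: all_reduce is a collective operation, so every rank must call get_current_log whenever logging will happen (hence the new will_log branch in train.py), otherwise the non-zero ranks would never enter the collective and rank 0 would hang. Each rank sums its 1-D float log tensors across the group and divides by the world size, so rank 0 logs the cross-rank mean; non-tensor entries and histograms pass through untouched. Below is a minimal, self-contained sketch of that pattern, not code from the repository: the helper name average_log_across_ranks and the example log dict are illustrative, and the function degrades to a no-op when torch.distributed is not initialized.

import torch
import torch.distributed as distributed

def average_log_across_ranks(log):
    # Only reduce when running under torch.distributed; otherwise return the log unchanged.
    if distributed.is_available() and distributed.is_initialized():
        for k, v in log.items():
            if not isinstance(v, torch.Tensor):
                continue
            # Only 1-D float tensors are reduced; ints, plain numbers and histograms pass through.
            if len(v.shape) != 1 or v.dtype != torch.float:
                continue
            # Sum the value across all ranks, then divide by world size to get the mean.
            distributed.all_reduce(v, op=distributed.ReduceOp.SUM)
            log[k] = v / distributed.get_world_size()
    return log

if __name__ == '__main__':
    # Single-process example: values come back unchanged.
    example = {'loss_gen': torch.tensor([0.42]), 'step': 100}
    print(average_log_across_ranks(example))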