forked from mrq/DL-Art-School

Move log consensus to train for efficiency

commit e1052a5e32
parent ce6dfdf255

codes

@@ -195,11 +195,14 @@ class Trainer:
         #### log
         if self.dataset_debugger is not None:
             self.dataset_debugger.update(train_data)
+        if will_log:
+            # Must be run by all instances to gather consensus.
+            current_model_logs = self.model.get_current_log(self.current_step)
         if will_log and self.rank <= 0:
             logs = {'step': self.current_step,
                     'samples': self.total_training_data_encountered,
                     'megasamples': self.total_training_data_encountered / 1000000}
-            logs.update(self.model.get_current_log(self.current_step))
+            logs.update(current_model_logs)
             if self.dataset_debugger is not None:
                 logs.update(self.dataset_debugger.get_debugging_map())
             logs.update(gradient_norms_dict)
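
Context for the hunk above: torch.distributed collectives such as all_reduce block until every rank calls them, so get_current_log (which now performs the cross-rank consensus) must run on all ranks, while only rank 0 goes on to assemble and write the log dict. Below is a minimal sketch of that pattern, assuming an already-initialized process group and a single-element float loss tensor; the helper name mean_across_ranks and the print-based logging are illustrative, not taken from this repository.

    import torch
    import torch.distributed as distributed

    def mean_across_ranks(value: torch.Tensor) -> torch.Tensor:
        # Collective call: every rank must execute this, or the job deadlocks.
        distributed.all_reduce(value, op=distributed.ReduceOp.SUM)
        return value / distributed.get_world_size()

    def log_step(rank: int, step: int, loss: torch.Tensor) -> None:
        # Every rank participates in the reduction...
        mean_loss = mean_across_ranks(loss.detach().clone())
        # ...but only rank 0 reports it, mirroring `if will_log and self.rank <= 0`.
        if rank <= 0:
            print(f"step {step}: mean loss {mean_loss.item():.4f}")
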
@@ -447,6 +447,17 @@ class ExtensibleTrainer(BaseModel):
 
         # The batch size optimizer also outputs loggable data.
         log.update(self.batch_size_optimizer.get_statistics())
+
+        # In distributed mode, get agreement on all single tensors.
+        if distributed.is_available() and distributed.is_initialized():
+            for k, v in log.items():
+                if not isinstance(v, torch.Tensor):
+                    continue
+                if len(v.shape) != 1 or v.dtype != torch.float:
+                    continue
+                distributed.all_reduce(v, op=distributed.ReduceOp.SUM)
+                log[k] = v / distributed.get_world_size()
+
         return log
 
     def get_current_visuals(self, need_GT=True):
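
A note on the reduction added above: summing with ReduceOp.SUM and then dividing by get_world_size() yields the cross-rank mean of each metric. Only one-dimensional float tensors are averaged; plain Python numbers, non-float tensors, and higher-dimensional tensors are left as reported by the local rank.
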
@@ -22,10 +22,6 @@ class LossAccumulator:
         if '_histogram' in name:
             buf[i] = torch.flatten(tensor.detach().cpu())
         elif isinstance(tensor, torch.Tensor):
-            if distributed.is_available() and distributed.is_initialized():
-                # Gather the metric from all devices before storing it locally.
-                distributed.all_reduce(tensor, op=distributed.ReduceOp.SUM)
-                tensor /= distributed.get_world_size()
             buf[i] = tensor.detach().cpu()
         else:
             buf[i] = tensor
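
The removal in this last hunk is the efficiency half of the change: LossAccumulator previously issued an all_reduce every time a metric was buffered, i.e. a cross-rank synchronization on every training step. With the consensus moved into get_current_log, which the trainer only invokes when will_log is set, that communication now happens once per logging interval instead.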