diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index e23dfa48..c6b44f7d 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -4,6 +4,7 @@ import os
 from time import time
 
 import torch
+from torch import distributed
 from torch.nn.parallel import DataParallel
 import torch.nn as nn
 
@@ -320,6 +321,11 @@ class ExtensibleTrainer(BaseModel):
                     pgroups = {f'{name}_all_parameters': list(model.parameters())}
                 for name in pgroups.keys():
                     grad_norms[name] = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in pgroups[name]]), 2)
+                    if distributed.is_available() and distributed.is_initialized():
+                        # Gather the metric from all devices if in a distributed setting.
+                        distributed.all_reduce(grad_norms[name], op=distributed.ReduceOp.SUM)
+                        grad_norms[name] /= distributed.get_world_size()
+                    grad_norms[name] = grad_norms[name].cpu()
 
         self.consume_gradients(state, step, it)
 
diff --git a/codes/utils/loss_accumulator.py b/codes/utils/loss_accumulator.py
index 1e78d07d..b99bcfc7 100644
--- a/codes/utils/loss_accumulator.py
+++ b/codes/utils/loss_accumulator.py
@@ -1,6 +1,9 @@
 import torch
 
 # Utility class that stores detached, named losses in a rotating buffer for smooth metric outputting.
+from torch import distributed
+
+
 class LossAccumulator:
     def __init__(self, buffer_sz=50):
         self.buffer_sz = buffer_sz
@@ -19,6 +22,10 @@ class LossAccumulator:
                 if '_histogram' in name:
                     buf[i] = torch.flatten(tensor.detach().cpu())
                 elif isinstance(tensor, torch.Tensor):
+                    if distributed.is_available() and distributed.is_initialized():
+                        # Gather the metric from all devices before storing it locally.
+                        distributed.all_reduce(tensor, op=distributed.ReduceOp.SUM)
+                        tensor /= distributed.get_world_size()
                     buf[i] = tensor.detach().cpu()
                 else:
                     buf[i] = tensor
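
Below is a minimal standalone sketch (not part of the patch) of the reduce-and-average pattern both hunks add: SUM-all-reduce a metric tensor across ranks, divide by the world size, and fall back to a no-op when no process group is initialized. The helper name average_across_ranks is hypothetical; it assumes the process group, if any, was set up elsewhere via torch.distributed.init_process_group.

import torch
from torch import distributed


def average_across_ranks(metric: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper illustrating the same guard used in the patch.
    if distributed.is_available() and distributed.is_initialized():
        # SUM across all ranks, then divide by world size to get the mean.
        distributed.all_reduce(metric, op=distributed.ReduceOp.SUM)
        metric /= distributed.get_world_size()
    # Store/report a detached CPU copy, as the trainer and accumulator do.
    return metric.detach().cpu()


if __name__ == '__main__':
    # Single-process usage: the guard is false, so this is just a .detach().cpu() copy.
    print(average_across_ranks(torch.tensor(0.123)))

Note that, as in the patch, the all-reduce modifies the metric tensor in place, so it should only be applied to detached reporting values, not to tensors still needed for backpropagation.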