Accumulate loss & grad_norm metrics from all entities within a distributed graph

James Betker 2022-03-04 12:01:16 -07:00
parent 79e5692388
commit 3ff878ae85
2 changed files with 13 additions and 0 deletions


@@ -4,6 +4,7 @@ import os
 from time import time
 import torch
+from torch import distributed
 from torch.nn.parallel import DataParallel
 import torch.nn as nn
@@ -320,6 +321,11 @@ class ExtensibleTrainer(BaseModel):
                 pgroups = {f'{name}_all_parameters': list(model.parameters())}
             for name in pgroups.keys():
                 grad_norms[name] = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in pgroups[name]]), 2)
+                if distributed.is_available() and distributed.is_initialized():
+                    # Gather the metric from all devices if in a distributed setting.
+                    distributed.all_reduce(grad_norms[name], op=distributed.ReduceOp.SUM)
+                    grad_norms[name] /= distributed.get_world_size()
+                grad_norms[name] = grad_norms[name].cpu()
         self.consume_gradients(state, step, it)
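Taken together, the added lines compute a cross-device mean: a SUM all-reduce followed by division by the world size. Below is a minimal, self-contained sketch of that pattern; the helper name all_reduce_mean is illustrative and not part of this commit.

import torch
from torch import distributed


def all_reduce_mean(metric: torch.Tensor) -> torch.Tensor:
    # No-op outside a distributed run, mirroring the guard used above.
    if distributed.is_available() and distributed.is_initialized():
        # all_reduce works in place: afterwards every rank holds the
        # sum of this tensor across all ranks.
        distributed.all_reduce(metric, op=distributed.ReduceOp.SUM)
        # Dividing the summed value by the rank count yields the mean.
        metric /= distributed.get_world_size()
    return metric

Note that all_reduce mutates its argument in place, which is why the trainer reduces grad_norms[name] first and only then moves it to the CPU; with the NCCL backend the tensor must still be on the GPU when the collective runs.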


@@ -1,6 +1,9 @@
 import torch
+from torch import distributed

 # Utility class that stores detached, named losses in a rotating buffer for smooth metric outputting.
 class LossAccumulator:
     def __init__(self, buffer_sz=50):
         self.buffer_sz = buffer_sz
@@ -19,6 +22,10 @@ class LossAccumulator:
             if '_histogram' in name:
                 buf[i] = torch.flatten(tensor.detach().cpu())
             elif isinstance(tensor, torch.Tensor):
+                if distributed.is_available() and distributed.is_initialized():
+                    # Gather the metric from all devices before storing it locally.
+                    distributed.all_reduce(tensor, op=distributed.ReduceOp.SUM)
+                    tensor /= distributed.get_world_size()
                 buf[i] = tensor.detach().cpu()
             else:
                 buf[i] = tensor
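Both guards only fire once a process group has been initialized. A sketch of the standard PyTorch setup under which these reductions become active follows; the torchrun launch and env:// initialization are common practice and are assumed here, not shown in the commit.

import os
import torch
from torch import distributed

# Launched via e.g. `torchrun --nproc_per_node=2 train.py`, which sets
# RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR and MASTER_PORT per process.
distributed.init_process_group(backend='nccl', init_method='env://')
torch.cuda.set_device(int(os.environ['LOCAL_RANK']))

# From here on distributed.is_initialized() is True, so both the grad_norm
# reduction in ExtensibleTrainer and the loss reduction in LossAccumulator
# average their metrics across all ranks instead of reporting rank-local values.

One caveat worth noting: tensor /= distributed.get_world_size() divides in place, so the caller's loss tensor is modified as a side effect of being logged; cloning before accumulation would avoid that if the same tensor is reused elsewhere.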