Accumulate loss & grad_norm metrics from all entities within a distributed graph

James Betker 2022-03-04 12:01:16 -07:00
parent 79e5692388
commit 3ff878ae85
2 changed files with 13 additions and 0 deletions

@@ -4,6 +4,7 @@ import os
 from time import time
 import torch
+from torch import distributed
 from torch.nn.parallel import DataParallel
 import torch.nn as nn
@@ -320,6 +321,11 @@ class ExtensibleTrainer(BaseModel):
     pgroups = {f'{name}_all_parameters': list(model.parameters())}
     for name in pgroups.keys():
         grad_norms[name] = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in pgroups[name]]), 2)
+        if distributed.is_available() and distributed.is_initialized():
+            # Gather the metric from all devices if in a distributed setting.
+            distributed.all_reduce(grad_norms[name], op=distributed.ReduceOp.SUM)
+            grad_norms[name] /= distributed.get_world_size()
+        grad_norms[name] = grad_norms[name].cpu()
     self.consume_gradients(state, step, it)
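For reference, the averaging pattern this hunk introduces can be exercised in isolation. The sketch below is not part of the commit: the average_metric helper, the gloo backend, and the two-process spawn are illustrative assumptions, used only to show how all_reduce with a SUM op followed by a world-size division yields the mean of a per-rank metric.

# Standalone sketch of the averaging pattern introduced above; the helper name,
# backend, and process spawn are illustrative assumptions, not part of the commit.
import os

import torch
import torch.multiprocessing as mp
from torch import distributed


def average_metric(metric: torch.Tensor) -> torch.Tensor:
    # Sum the per-rank values, then divide by the world size, as the hunk
    # does for grad_norms[name].
    if distributed.is_available() and distributed.is_initialized():
        distributed.all_reduce(metric, op=distributed.ReduceOp.SUM)
        metric /= distributed.get_world_size()
    return metric.cpu()


def _worker(rank, world_size):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    distributed.init_process_group('gloo', rank=rank, world_size=world_size)
    # Each rank reports a different local "grad norm"; the reduced value is their mean.
    local_norm = torch.tensor(float(rank + 1))
    print(rank, average_metric(local_norm))  # both ranks print tensor(1.5000)
    distributed.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(_worker, args=(2,), nprocs=2)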

@@ -1,6 +1,9 @@
 import torch
 # Utility class that stores detached, named losses in a rotating buffer for smooth metric outputting.
+from torch import distributed
 class LossAccumulator:
     def __init__(self, buffer_sz=50):
         self.buffer_sz = buffer_sz
@@ -19,6 +22,10 @@ class LossAccumulator:
         if '_histogram' in name:
             buf[i] = torch.flatten(tensor.detach().cpu())
         elif isinstance(tensor, torch.Tensor):
+            if distributed.is_available() and distributed.is_initialized():
+                # Gather the metric from all devices before storing it locally.
+                distributed.all_reduce(tensor, op=distributed.ReduceOp.SUM)
+                tensor /= distributed.get_world_size()
             buf[i] = tensor.detach().cpu()
         else:
             buf[i] = tensor
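The LossAccumulator change feeds that same rank-averaged value into a rotating buffer. As a rough illustration only, the toy class below mimics that behavior; its add_loss/as_dict names and buffer bookkeeping are assumptions simplified relative to the real class, and only the distributed averaging branch mirrors the diff above.

# Toy rotating-buffer accumulator fed with rank-averaged losses.
# Simplified sketch; names and bookkeeping are illustrative, only the
# distributed averaging branch mirrors the patch above.
import torch
from torch import distributed


class ToyLossAccumulator:
    def __init__(self, buffer_sz=50):
        self.buffer_sz = buffer_sz
        self.buffers = {}  # name -> (buffer, next index, number of filled slots)

    def add_loss(self, name, tensor):
        if name not in self.buffers:
            self.buffers[name] = (torch.zeros(self.buffer_sz), 0, 0)
        buf, i, filled = self.buffers[name]
        if isinstance(tensor, torch.Tensor):
            if distributed.is_available() and distributed.is_initialized():
                # Average across ranks so every process records the same value.
                distributed.all_reduce(tensor, op=distributed.ReduceOp.SUM)
                tensor /= distributed.get_world_size()
            buf[i] = tensor.detach().cpu()
        else:
            buf[i] = tensor
        self.buffers[name] = (buf, (i + 1) % self.buffer_sz, min(filled + 1, self.buffer_sz))

    def as_dict(self):
        # Smoothed metric: mean over however much of the buffer is filled.
        return {name: buf[:filled].mean().item()
                for name, (buf, _, filled) in self.buffers.items()}


acc = ToyLossAccumulator(buffer_sz=4)
for step_loss in (0.9, 0.8, 0.7):
    acc.add_loss('train_loss', torch.tensor(step_loss))
print(acc.as_dict())  # roughly {'train_loss': 0.8}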