Accumulate loss & grad_norm metrics from all entities within a distributed graph
This commit is contained in:
parent
79e5692388
commit
3ff878ae85
|
@ -4,6 +4,7 @@ import os
|
|||
from time import time
|
||||
|
||||
import torch
|
||||
from torch import distributed
|
||||
from torch.nn.parallel import DataParallel
|
||||
import torch.nn as nn
|
||||
|
||||
|
@ -320,6 +321,11 @@ class ExtensibleTrainer(BaseModel):
|
|||
pgroups = {f'{name}_all_parameters': list(model.parameters())}
|
||||
for name in pgroups.keys():
|
||||
grad_norms[name] = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in pgroups[name]]), 2)
|
||||
if distributed.is_available() and distributed.is_initialized():
|
||||
# Gather the metric from all devices if in a distributed setting.
|
||||
distributed.all_reduce(grad_norms[name], op=distributed.ReduceOp.SUM)
|
||||
grad_norms[name] /= distributed.get_world_size()
|
||||
grad_norms[name] = grad_norms[name].cpu()
|
||||
|
||||
self.consume_gradients(state, step, it)
|
||||
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
import torch
|
||||
|
||||
# Utility class that stores detached, named losses in a rotating buffer for smooth metric outputting.
|
||||
from torch import distributed
|
||||
|
||||
|
||||
class LossAccumulator:
|
||||
def __init__(self, buffer_sz=50):
|
||||
self.buffer_sz = buffer_sz
|
||||
|
@ -19,6 +22,10 @@ class LossAccumulator:
|
|||
if '_histogram' in name:
|
||||
buf[i] = torch.flatten(tensor.detach().cpu())
|
||||
elif isinstance(tensor, torch.Tensor):
|
||||
if distributed.is_available() and distributed.is_initialized():
|
||||
# Gather the metric from all devices before storing it locally.
|
||||
distributed.all_reduce(tensor, op=distributed.ReduceOp.SUM)
|
||||
tensor /= distributed.get_world_size()
|
||||
buf[i] = tensor.detach().cpu()
|
||||
else:
|
||||
buf[i] = tensor
|
||||
|
|
Loading…
Reference in New Issue
Block a user