forked from mrq/DL-Art-School
Accumulate loss & grad_norm metrics from all entities within a distributed graph
This commit is contained in:
parent 79e5692388
commit 3ff878ae85
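
The guards added below only take effect when a torch.distributed process group already exists. For context, a minimal sketch of how such a group is typically created (assuming a torchrun-style launch; this setup code is not part of the commit):

import os

import torch
from torch import distributed

# torchrun populates RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT, which the
# default env:// init method reads. Single-process runs skip initialization,
# so distributed.is_initialized() stays False and the new branches are no-ops.
if distributed.is_available() and int(os.environ.get("WORLD_SIZE", "1")) > 1:
    distributed.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")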
@@ -4,6 +4,7 @@ import os
 from time import time
 
 import torch
+from torch import distributed
 from torch.nn.parallel import DataParallel
 import torch.nn as nn
 
@@ -320,6 +321,11 @@ class ExtensibleTrainer(BaseModel):
                         pgroups = {f'{name}_all_parameters': list(model.parameters())}
                 for name in pgroups.keys():
                     grad_norms[name] = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in pgroups[name]]), 2)
+                    if distributed.is_available() and distributed.is_initialized():
+                        # Gather the metric from all devices if in a distributed setting.
+                        distributed.all_reduce(grad_norms[name], op=distributed.ReduceOp.SUM)
+                        grad_norms[name] /= distributed.get_world_size()
+                    grad_norms[name] = grad_norms[name].cpu()
 
             self.consume_gradients(state, step, it)
 
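Taken together, the hunk above computes one global gradient norm per parameter group and then averages that scalar across ranks. A self-contained sketch of the same reduce-then-average idiom, runnable with or without a process group (the model here is a placeholder, not trainer code):

import torch
from torch import distributed, nn

def global_grad_norm(params):
    # L2 norm over the per-parameter gradient L2 norms, matching the diff above.
    return torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in params]), 2)

model = nn.Linear(8, 1)
model(torch.randn(4, 8)).sum().backward()

norm = global_grad_norm(model.parameters())
if distributed.is_available() and distributed.is_initialized():
    # SUM across ranks, then divide by world size to report the mean metric.
    distributed.all_reduce(norm, op=distributed.ReduceOp.SUM)
    norm /= distributed.get_world_size()
print(norm.cpu().item())

One note on the design: summing with ReduceOp.SUM and dividing by get_world_size() is the portable way to average, since ReduceOp.AVG is only supported on some backends (e.g. NCCL).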
@@ -1,6 +1,9 @@
 import torch
 
 # Utility class that stores detached, named losses in a rotating buffer for smooth metric outputting.
+from torch import distributed
+
+
 class LossAccumulator:
     def __init__(self, buffer_sz=50):
         self.buffer_sz = buffer_sz
@@ -19,6 +22,10 @@ class LossAccumulator:
         if '_histogram' in name:
             buf[i] = torch.flatten(tensor.detach().cpu())
         elif isinstance(tensor, torch.Tensor):
+            if distributed.is_available() and distributed.is_initialized():
+                # Gather the metric from all devices before storing it locally.
+                distributed.all_reduce(tensor, op=distributed.ReduceOp.SUM)
+                tensor /= distributed.get_world_size()
             buf[i] = tensor.detach().cpu()
         else:
             buf[i] = tensor
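For illustration, a simplified stand-in for the rotating-buffer behavior the comment above describes; this is not the actual LossAccumulator API, just the same smoothing idea with hypothetical names:

import torch

class RollingMetric:
    def __init__(self, buffer_sz=50):
        self.buffer_sz = buffer_sz
        self.buf = torch.zeros(buffer_sz)
        self.i = 0
        self.filled = 0

    def add(self, value):
        # Store detached CPU scalars so the buffer never keeps autograd graphs alive.
        self.buf[self.i] = value.detach().cpu() if isinstance(value, torch.Tensor) else value
        self.i = (self.i + 1) % self.buffer_sz
        self.filled = min(self.filled + 1, self.buffer_sz)

    def mean(self):
        # Smoothed metric: the mean of the most recent buffer_sz values.
        return self.buf[:self.filled].mean().item()

m = RollingMetric(buffer_sz=4)
for loss in [1.0, 0.8, 0.9, 0.7, 0.6]:
    m.add(torch.tensor(loss))
print(m.mean())  # averages the last 4 entries: (0.8 + 0.9 + 0.7 + 0.6) / 4 = 0.75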