vall-e/vall_e/utils/distributed.py

103 lines
2.0 KiB
Python
Executable File

"""
# https://github.com/enhuiz/pytorch-training-utilities
"""
import os
import socket
from functools import cache, wraps
from typing import Callable
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def get_free_port():
    """Return a TCP port number that the OS reported as free.

    A temporary socket is bound to port 0 so the OS picks an unused port.
    The socket is closed before returning so the port is actually available
    to the caller (and the file descriptor is not leaked).

    Returns:
        int: the port number the OS assigned to the temporary socket.
    """
    with socket.socket() as sock:
        # Port 0 asks the OS to choose any currently unused port.
        sock.bind(("", 0))
        return sock.getsockname()[1]
# Module-level flag recording whether init_distributed() has completed.
_distributed_initialized = False


def init_distributed(fn, *args, **kwargs):
    """Bind this process to its CUDA device, run ``fn``, and mark init done.

    Args:
        fn: callable invoked as ``fn(*args, **kwargs)`` after the CUDA device
            is selected (presumably the process-group initializer; confirm
            against callers).
        *args: positional arguments forwarded to ``fn``.
        **kwargs: keyword arguments forwarded to ``fn``.
    """
    # Without this declaration the assignment below would create a function
    # local and the module-level flag would stay False forever.
    global _distributed_initialized

    torch.cuda.set_device(local_rank())
    fn(*args, **kwargs)
    _distributed_initialized = True
def distributed_initialized():
    """Return the current value of the module-level ``_distributed_initialized`` flag."""
    initialized = _distributed_initialized
    return initialized
def cleanup_distributed():
    """Synchronize all ranks and tear down the default process group.

    Does nothing when ``torch.distributed`` is unavailable or no process
    group has been initialized, so it is safe to call unconditionally at
    shutdown (for example in single-process runs).
    """
    # is_available() must be checked first: is_initialized() only exists on
    # builds that ship distributed support.
    if not (dist.is_available() and dist.is_initialized()):
        return

    # Wait for every rank before destroying the group they communicate over.
    dist.barrier()
    dist.destroy_process_group()
@cache
def fix_unset_envs():
    """Populate the distributed-launch environment variables with defaults.

    The defaults describe a single local process (rank 0, world size 1,
    master on localhost with a freshly picked port). They are applied only
    when none of the variables is present; if any one is already set, the
    environment is left completely untouched.

    Decorated with ``functools.cache``, so the body runs at most once per
    process.
    """
    defaults = {
        "RANK": "0",
        "WORLD_SIZE": "1",
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": str(get_free_port()),
        "LOCAL_RANK": "0",
    }

    # All-or-nothing: a single pre-existing variable means the environment
    # is left exactly as it is.
    if any(os.getenv(name) is not None for name in defaults):
        return

    os.environ.update(defaults)
def local_rank():
    """Return ``LOCAL_RANK`` from the environment as an int (0 when unset)."""
    raw = os.environ.get("LOCAL_RANK", 0)
    return int(raw)
def global_rank():
    """Return ``RANK`` from the environment as an int (0 when unset)."""
    raw = os.environ.get("RANK", 0)
    return int(raw)
def world_size():
    """Return ``WORLD_SIZE`` from the environment as an int (1 when unset)."""
    raw = os.environ.get("WORLD_SIZE", 1)
    return int(raw)
def is_local_leader():
    """Return True when this process's local rank is 0."""
    rank = local_rank()
    return rank == 0
def is_global_leader():
    """Return True when this process's global rank is 0."""
    rank = global_rank()
    return rank == 0
def local_leader_only(fn=None, *, default=None) -> Callable:
    """Decorator: run the function only on the local leader process.

    Usable bare (``@local_leader_only``) or with arguments
    (``@local_leader_only(default=...)``).

    Args:
        fn: the function to decorate; ``None`` when the decorator is used
            with arguments, in which case the decorator itself is returned.
        default: value returned by the decorated function on processes for
            which ``is_local_leader()`` is false.

    Returns:
        The decorated function, or a decorator when ``fn`` is ``None``.
    """

    def decorate(target):
        @wraps(target)
        def guarded(*args, **kwargs):
            # Non-leader processes skip the call and hand back `default`.
            return target(*args, **kwargs) if is_local_leader() else default

        return guarded

    return decorate if fn is None else decorate(fn)
def global_leader_only(fn: Callable | None = None, *, default=None) -> Callable:
def wrapper(fn):
@wraps(fn)
def wrapped(*args, **kwargs):
if is_global_leader():
return fn(*args, **kwargs)
return default
return wrapped
if fn is None:
return wrapper
return wrapper(fn)
def ddp_model(model):
    """Move ``model`` to CUDA and wrap it in ``DistributedDataParallel``.

    Args:
        model: object exposing ``.to(device=...)`` (presumably a
            ``torch.nn.Module``; confirm against callers).

    Returns:
        The DDP wrapper, pinned to this process's local-rank device, with
        ``find_unused_parameters=True``.
    """
    cuda_model = model.to(device='cuda')
    device_ids = [local_rank()]
    return DDP(cuda_model, device_ids=device_ids, find_unused_parameters=True)