simple DDP wrapper (for my NVlink test)

mrq 2024-05-04 11:48:26 -05:00
parent 783db3d2c5
commit c494894261
6 changed files with 32 additions and 17 deletions

View File

@@ -469,6 +469,7 @@ class Trainer:
 	weight_dtype: str = "float16"
 	amp: bool = False
+	ddp: bool = False
 	load_webui: bool = False
 	no_logger: bool = False
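
The new flag sits alongside the other `Trainer` dataclass fields, so enabling it from the YAML config should just be the following (the exact layout is assumed from the sibling flags, not shown in this diff):

trainer:
  ddp: True   # wrap each model in DistributedDataParallel at load time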

View File

@@ -1,6 +1,6 @@
 from ..config import cfg
-from ..utils.distributed import fix_unset_envs
+from ..utils.distributed import fix_unset_envs, ddp_model
 
 fix_unset_envs()
 
 if cfg.trainer.backend == "deepspeed":
@@ -38,6 +38,7 @@ def load_engines(training=True):
 	dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
 	amp = cfg.inference.amp if inferencing else cfg.trainer.amp
 	loads_state_dict = cfg.trainer.load_state_dict or inferencing
+	ddp = cfg.trainer.ddp
 
 	engine_class = _Engine if backend == "local" or inferencing else Engine
@@ -117,10 +118,14 @@ def load_engines(training=True):
 			model.load_state_dict(state, strict=cfg.trainer.strict_loading)
 
+		_cfg = model._cfg
+
+		# wrap if DDP is requested
+		if ddp:
+			model = ddp_model(model)
 		# deepspeed inferencing
-		if backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
+		elif backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
 			engine_class = _Engine
 			model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module
@@ -130,9 +135,10 @@ def load_engines(training=True):
 			optimizer=optimizer,
 			lr_scheduler=lr_scheduler,
-			_cfg=model._cfg,
+			_cfg=_cfg,
 			stats=stats
 		)
 
 	engines = Engines(engines)
 	engines.setup()
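
One subtlety above: `_cfg` is read off the model before the DDP wrap, because `DistributedDataParallel` does not proxy arbitrary attributes of the module it wraps; anything custom is only reachable through `.module` afterwards. A minimal illustration (the `Toy` class here is a hypothetical stand-in):

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

class Toy(torch.nn.Module):
	def __init__(self):
		super().__init__()
		self.proj = torch.nn.Linear(4, 4)
		self._cfg = {"n_layers": 12}  # plain attribute, not a parameter/buffer/submodule

# with a process group initialized (e.g. under torchrun):
# model = DDP(Toy().cuda(), [0])
# model._cfg         # AttributeError: nn.Module.__getattr__ only finds params/buffers/modules
# model.module._cfg  # works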

View File

@@ -372,15 +372,6 @@ def example_usage():
 		'config': cfg.model
 	}
-	"""
-	kwargs = {
-		'n_tokens': 1024,
-		'd_model': 256,
-		'n_heads': 4,
-		'n_layers': 12,
-		'n_experts': 8,
-	}
-	"""
 	"""
 
 	try:
@@ -390,7 +381,7 @@ def example_usage():
 	"""
 
 	model = AR_NAR(**kwargs).to(device)
-	steps = 500
+	steps = 100
 
 	optimizer = ml.Prodigy(model.parameters(), lr=1.0)
 	#optimizer = ml.Adagrad(model.parameters(), lr=1.0e-2)
 	#optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)

View File

@@ -183,4 +183,5 @@ def train():
 	)
 
 if __name__ == "__main__":
+	# to-do: for DDP, spawn multiprocess instead of requiring `torchrun --nnodes=1 --nproc-per-node=4 -m vall_e.train yaml="./data/config.yaml"`
 	train()
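
A sketch of what that to-do might look like with `torch.multiprocessing` — hypothetical, not part of this commit; it assumes `fix_unset_envs` (or the worker itself) fills in `MASTER_ADDR`/`MASTER_PORT`:

import os
import torch.multiprocessing as mp

def _worker(rank: int, world_size: int):
	# replicate the env that torchrun would otherwise provide
	os.environ["RANK"] = str(rank)
	os.environ["LOCAL_RANK"] = str(rank)
	os.environ["WORLD_SIZE"] = str(world_size)
	train()

def spawn_train(world_size: int = 4):
	# fn is invoked as fn(process_index, *args) in each spawned process
	mp.spawn(_worker, args=(world_size,), nprocs=world_size)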

View File

@@ -8,6 +8,10 @@ import socket
 from functools import cache, wraps
 from typing import Callable
 
+import torch
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+
 def get_free_port():
 	sock = socket.socket()
 	sock.bind(("", 0))
@@ -17,6 +21,7 @@ def get_free_port():
 _distributed_initialized = False
 def init_distributed( fn, *args, **kwargs ):
 	#print("Initializing distributed...")
+	torch.cuda.set_device(local_rank())
 	fn(*args, **kwargs)
 	_distributed_initialized = True
@@ -45,7 +50,6 @@ def fix_unset_envs():
 def local_rank():
 	return int(os.getenv("LOCAL_RANK", 0))
-
 def global_rank():
 	return int(os.getenv("RANK", 0))
@@ -90,4 +94,7 @@ def global_leader_only(fn: Callable | None = None, *, default=None) -> Callable:
 	if fn is None:
 		return wrapper
 	return wrapper(fn)
+
+def ddp_model(model):
+	return DDP(model.to(device='cuda'), [local_rank()])
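
`ddp_model` passes `device_ids` positionally: each rank pins one GPU, hosts a full replica, and lets NCCL all-reduce gradients during `backward()` — over NVLink when the hardware has it, hence the commit title. Roughly, per rank (a sketch, assuming a torchrun-style launch):

import torch
import torch.distributed as dist

# one process per GPU, launched by torchrun (or a future mp.spawn path)
dist.init_process_group("nccl")           # NCCL rides NVLink/P2P when available
torch.cuda.set_device(local_rank())       # same pinning init_distributed() now does
model = ddp_model(torch.nn.Linear(8, 8))  # replica lands on this rank's GPU
x = torch.randn(4, 8, device='cuda')      # each rank feeds its own slice of data
model(x).sum().backward()                 # gradients are all-reduced across ranks here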

View File

@@ -1,7 +1,9 @@
 from contextlib import contextmanager
+
+import math
 
 import torch
 import torch.nn.functional as F
 
 from ..config import cfg
 
 Embedding = torch.nn.Embedding
@@ -99,12 +101,13 @@ if cfg.optimizations.injects and cfg.optimizations.bitsandbytes:
 	torch.optim.SGD = SGD
 
 # disgusting kludge, but it works (just realized BitNet has its own replacement routine)
-def replace_linear( model ):
+def replace_linear( model, verbose=False ):
 	bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet
-	klass = Linear
 	device = next(model.parameters()).device
 	linears = [k.split('.') for k, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
+	klass = Linear
+
 	for *parent, k in linears:
 		name = '.'.join(parent)
@@ -112,6 +115,9 @@ def replace_linear( model ):
 		# copy parameters
 		m = getattr( model.get_submodule(name), k )
+		if isinstance(m, klass):
+			continue
+
 		in_features = m.in_features
 		out_features = m.out_features
 		bias = m.bias is not None
@@ -123,6 +129,9 @@ def replace_linear( model ):
 			model.get_submodule(name), k,
 			klass( **kwargs ).to(device=device, dtype=cfg.trainer.dtype)
 		)
 
+		if verbose:
+			print(f"Replacing {name}.{k} to", klass)
+
 	return model
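
The two additions make the kludge friendlier: the `isinstance(m, klass)` guard skips layers that were already converted, so the pass is now idempotent, and `verbose` lets you audit what got swapped:

model = replace_linear(model, verbose=True)  # logs each nn.Linear it converts
model = replace_linear(model)                # re-running is now a no-op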