simple DDP wrapper (for my NVLink test)

This commit is contained in:
mrq 2024-05-04 11:48:26 -05:00
parent 783db3d2c5
commit c494894261
6 changed files with 32 additions and 17 deletions

View File

@ -469,6 +469,7 @@ class Trainer:
weight_dtype: str = "float16"
amp: bool = False
ddp: bool = False
load_webui: bool = False
no_logger: bool = False
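
The new `ddp` flag sits alongside the other trainer toggles and defaults to off. A minimal sketch of flipping it (the YAML schema is assumed from the dataclass above; in practice it would live under the trainer section of the config file):

from vall_e.config import cfg  # package path inferred from the repo's `vall_e` layout

# hypothetical: equivalent to a `ddp: True` entry under `trainer:` in the YAML
cfg.trainer.ddp = True  # load_engines() below will now wrap the model via ddp_model()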

View File

@ -1,6 +1,6 @@
from ..config import cfg
from ..utils.distributed import fix_unset_envs
from ..utils.distributed import fix_unset_envs, ddp_model
fix_unset_envs()
if cfg.trainer.backend == "deepspeed":
@ -38,6 +38,7 @@ def load_engines(training=True):
dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
amp = cfg.inference.amp if inferencing else cfg.trainer.amp
loads_state_dict = cfg.trainer.load_state_dict or inferencing
ddp = cfg.trainer.ddp
engine_class = _Engine if backend == "local" or inferencing else Engine
@ -117,10 +118,14 @@ def load_engines(training=True):
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
_cfg = model._cfg
# wrap if DDP is requested
if ddp:
model = ddp_model(model)
# deepspeed inferencing
if backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
elif backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
engine_class = _Engine
model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module
@ -130,9 +135,10 @@ def load_engines(training=True):
optimizer=optimizer,
lr_scheduler=lr_scheduler,
_cfg=model._cfg,
_cfg=_cfg,
stats=stats
)
engines = Engines(engines)
engines.setup()
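
Two details in this file are worth spelling out: the `if` → `elif` keeps the new DDP wrap and deepspeed's kernel-injected inference mutually exclusive, and `_cfg` is captured before wrapping because `DistributedDataParallel` does not proxy arbitrary attributes of the module it wraps; after the wrap they are only reachable through `.module`. A standalone sketch of that gotcha (single process, gloo on CPU, purely illustrative):

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# one-process group just so DDP can construct; real runs use NCCL across GPUs
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(4, 4)
model._cfg = {"example": True}  # hypothetical attribute, standing in for the model config

_cfg = model._cfg               # capture first, as the commit does
model = DDP(model)              # CPU/gloo here; the commit wraps onto `cuda` with device_ids

assert not hasattr(model, "_cfg")  # the wrapper hides the attribute...
assert model.module._cfg is _cfg   # ...it now lives behind `.module`

dist.destroy_process_group()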

View File

@ -372,15 +372,6 @@ def example_usage():
'config': cfg.model
}
"""
kwargs = {
'n_tokens': 1024,
'd_model': 256,
'n_heads': 4,
'n_layers': 12,
'n_experts': 8,
}
"""
"""
try:
@ -390,7 +381,7 @@ def example_usage():
"""
model = AR_NAR(**kwargs).to(device)
steps = 500
steps = 100
optimizer = ml.Prodigy(model.parameters(), lr=1.0)
#optimizer = ml.Adagrad(model.parameters(), lr=1.0e-2)
#optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)

View File

@ -183,4 +183,5 @@ def train():
)
if __name__ == "__main__":
# to-do: for DDP, spawn processes via multiprocessing instead of requiring `torchrun --nnodes=1 --nproc-per-node=4 -m vall_e.train yaml="./data/config.yaml"`
train()
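
A hedged sketch of what that to-do could look like with `torch.multiprocessing` (not part of this commit; single-node only, and the environment variables mirror what `torchrun` would otherwise set):

import os
import torch
import torch.multiprocessing as mp

def _worker(local_rank, world_size):
    # torchrun normally exports these; set them ourselves when self-spawning
    os.environ["LOCAL_RANK"] = str(local_rank)
    os.environ["RANK"] = str(local_rank)  # single node: global rank == local rank
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    train()  # the existing entry point above

if __name__ == "__main__":
    world_size = torch.cuda.device_count()  # assumes at least one visible GPU
    mp.spawn(_worker, args=(world_size,), nprocs=world_size)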

View File

@ -8,6 +8,10 @@ import socket
from functools import cache, wraps
from typing import Callable
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def get_free_port():
sock = socket.socket()
sock.bind(("", 0))
@ -17,6 +21,7 @@ def get_free_port():
_distributed_initialized = False
def init_distributed( fn, *args, **kwargs ):
#print("Initializing distributed...")
torch.cuda.set_device(local_rank())
fn(*args, **kwargs)
_distributed_initialized = True
@ -45,7 +50,6 @@ def fix_unset_envs():
def local_rank():
return int(os.getenv("LOCAL_RANK", 0))
def global_rank():
return int(os.getenv("RANK", 0))
@ -90,4 +94,7 @@ def global_leader_only(fn: Callable | None = None, *, default=None) -> Callable:
if fn is None:
return wrapper
return wrapper(fn)
def ddp_model(model):
return DDP(model.to(device='cuda'), device_ids=[local_rank()])
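
`torch.cuda.set_device(local_rank())` pins each spawned process to its own GPU before `init_distributed` runs the wrapped function, so the bare `'cuda'` device string inside `ddp_model` resolves per process instead of piling every rank onto `cuda:0`. An illustration (assumes a multi-GPU box and `LOCAL_RANK` exported by the launcher):

import os
import torch

local = int(os.getenv("LOCAL_RANK", 0))  # torchrun sets this per process
torch.cuda.set_device(local)

x = torch.zeros(1, device="cuda")  # unqualified 'cuda' now means cuda:{local}
assert x.device.index == local

The helper then hands DDP the same index via `device_ids`, telling it which single device the replica lives on.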

View File

@ -1,7 +1,9 @@
from contextlib import contextmanager
import math
import torch
import torch.nn.functional as F
from ..config import cfg
Embedding = torch.nn.Embedding
@ -99,12 +101,13 @@ if cfg.optimizations.injects and cfg.optimizations.bitsandbytes:
torch.optim.SGD = SGD
# disgusting kludge, but it works (just realized BitNet has its own replacement routine)
def replace_linear( model ):
def replace_linear( model, verbose=False ):
bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet
klass = Linear
device = next(model.parameters()).device
linears = [k.split('.') for k, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
klass = Linear
for *parent, k in linears:
name = '.'.join(parent)
@ -112,6 +115,9 @@ def replace_linear( model ):
# copy parameters
m = getattr( model.get_submodule(name), k )
if isinstance(m, klass):
continue
in_features = m.in_features
out_features = m.out_features
bias = m.bias is not None
@ -123,6 +129,9 @@ def replace_linear( model ):
model.get_submodule(name), k,
klass( **kwargs ).to(device=device, dtype=cfg.trainer.dtype)
)
if verbose:
print(f"Replacing {name}.{k} to", klass)
return model
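
Hoisting `klass = Linear` below the device lookup and skipping modules that are already instances of it makes the routine safe to run twice. A toy usage sketch of the new `verbose` path (assumes the bitsandbytes/bitnet config has already selected a non-default `klass`; `replace_linear` is the function above):

import torch

toy = torch.nn.Sequential(
    torch.nn.Linear(8, 8),
    torch.nn.ReLU(),
    torch.nn.Linear(8, 2),
)
toy = replace_linear(toy, verbose=True)  # prints one line per swapped nn.Linear
toy = replace_linear(toy, verbose=True)  # second pass: `continue` skips them, no output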