simple DDP wrapper (for my NVlink test)
This commit is contained in:
parent 783db3d2c5
commit c494894261
@@ -469,6 +469,7 @@ class Trainer:
     weight_dtype: str = "float16"
     amp: bool = False
+    ddp: bool = False
 
     load_webui: bool = False
     no_logger: bool = False
 
@@ -1,6 +1,6 @@
 from ..config import cfg
 
-from ..utils.distributed import fix_unset_envs
+from ..utils.distributed import fix_unset_envs, ddp_model
 fix_unset_envs()
 
 if cfg.trainer.backend == "deepspeed":
@@ -38,6 +38,7 @@ def load_engines(training=True):
     dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
     amp = cfg.inference.amp if inferencing else cfg.trainer.amp
     loads_state_dict = cfg.trainer.load_state_dict or inferencing
+    ddp = cfg.trainer.ddp
 
     engine_class = _Engine if backend == "local" or inferencing else Engine
 
@@ -117,10 +118,14 @@ def load_engines(training=True):
 
             model.load_state_dict(state, strict=cfg.trainer.strict_loading)
 
+        _cfg = model._cfg
 
+        # wrap if DDP is requested
+        if ddp:
+            model = ddp_model(model)
         # deepspeed inferencing
-        if backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
+        elif backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
             engine_class = _Engine
             model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module
 
 
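The `_cfg = model._cfg` capture above matters because DistributedDataParallel does not forward arbitrary attributes of the wrapped module; after wrapping they are only reachable through `.module`, which is why the next hunk passes the saved `_cfg` instead of reading `model._cfg` again. A minimal sketch of that behaviour, independent of this repo (the Toy class is purely illustrative):

# illustrative sketch, not repo code: grab attributes before DDP wrapping
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)
        self._cfg = {"name": "toy"}  # stand-in for the real model config

model = Toy()
_cfg = model._cfg  # read it while the bare module is still exposed
# after `model = DDP(model.cuda(), [0])` (requires an initialized process group),
# the same attribute would have to be reached as `model.module._cfg`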
@@ -130,9 +135,10 @@ def load_engines(training=True):
             optimizer=optimizer,
             lr_scheduler=lr_scheduler,
 
-            _cfg=model._cfg,
+            _cfg=_cfg,
             stats=stats
         )
 
     engines = Engines(engines)
     engines.setup()
@@ -372,15 +372,6 @@ def example_usage():
         'config': cfg.model
     }
     """
-    kwargs = {
-        'n_tokens': 1024,
-        'd_model': 256,
-        'n_heads': 4,
-        'n_layers': 12,
-        'n_experts': 8,
-    }
-    """
-
     """
     try:
 
@@ -390,7 +381,7 @@ def example_usage():
     """
 
     model = AR_NAR(**kwargs).to(device)
-    steps = 500
+    steps = 100
     optimizer = ml.Prodigy(model.parameters(), lr=1.0)
     #optimizer = ml.Adagrad(model.parameters(), lr=1.0e-2)
     #optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
@@ -183,4 +183,5 @@ def train():
     )
 
 if __name__ == "__main__":
+    # to-do: for DDP, spawn multiprocess instead of requiring `torchrun --nnodes=1 --nproc-per-node=4 -m vall_e.train yaml="./data/config.yaml"`
     train()
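A rough sketch of what that to-do could look like using torch.multiprocessing.spawn instead of torchrun; the `_worker` helper and the hard-coded master address/port are assumptions for illustration, not part of this commit:

# hypothetical launcher replacing the torchrun invocation mentioned above
import os
import torch
import torch.multiprocessing as mp

def _worker(local_rank, world_size):
    # mirror the environment variables torchrun would normally provide
    os.environ["LOCAL_RANK"] = str(local_rank)
    os.environ["RANK"] = str(local_rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    train()  # the existing entry point in this file

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(_worker, args=(world_size,), nprocs=world_size)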
@@ -8,6 +8,10 @@ import socket
 from functools import cache, wraps
 from typing import Callable
 
+import torch
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+
 def get_free_port():
     sock = socket.socket()
     sock.bind(("", 0))
@@ -17,6 +21,7 @@ def get_free_port():
 _distributed_initialized = False
 def init_distributed( fn, *args, **kwargs ):
     #print("Initializing distributed...")
+    torch.cuda.set_device(local_rank())
     fn(*args, **kwargs)
     _distributed_initialized = True
 
@@ -45,7 +50,6 @@ def fix_unset_envs():
 def local_rank():
     return int(os.getenv("LOCAL_RANK", 0))
 
-
 def global_rank():
     return int(os.getenv("RANK", 0))
 
@@ -90,4 +94,7 @@ def global_leader_only(fn: Callable | None = None, *, default=None) -> Callable:
     if fn is None:
         return wrapper
 
     return wrapper(fn)
+
+def ddp_model(model):
+    return DDP(model.to(device='cuda'), [local_rank()])
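Illustrative usage of the new `ddp_model` helper under torchrun, one process per GPU; the `nccl` backend choice and the toy Linear model are assumptions for the sketch, not taken from the commit:

# sketch only: how ddp_model() would typically be exercised per rank
import os
import torch
import torch.distributed as dist

def demo():
    dist.init_process_group(backend="nccl")  # torchrun supplies RANK/WORLD_SIZE/MASTER_*
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
    model = torch.nn.Linear(8, 8)
    model = ddp_model(model)  # wraps as DDP(model.to('cuda'), [local_rank()])
    # from here on, backward() all-reduces gradients across the NVLink'd GPUs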
@@ -1,7 +1,9 @@
 from contextlib import contextmanager
 
+import math
 import torch
+import torch.nn.functional as F
 
 from ..config import cfg
 
 Embedding = torch.nn.Embedding
@@ -99,12 +101,13 @@ if cfg.optimizations.injects and cfg.optimizations.bitsandbytes:
     torch.optim.SGD = SGD
 
 # disgusting kludge, but it works (just realized BitNet has its own replacement routine)
-def replace_linear( model ):
+def replace_linear( model, verbose=False ):
     bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet
-    klass = Linear
 
+    device = next(model.parameters()).device
     linears = [k.split('.') for k, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
+    klass = Linear
 
     for *parent, k in linears:
         name = '.'.join(parent)
 
@@ -112,6 +115,9 @@ def replace_linear( model ):
         # copy parameters
         m = getattr( model.get_submodule(name), k )
 
+        if isinstance(m, klass):
+            continue
+
         in_features = m.in_features
         out_features = m.out_features
         bias = m.bias is not None
@@ -123,6 +129,9 @@ def replace_linear( model ):
             model.get_submodule(name), k,
             klass( **kwargs ).to(device=device, dtype=cfg.trainer.dtype)
         )
 
+        if verbose:
+            print(f"Replacing {name}.{k} to", klass)
+
     return model
 
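A hypothetical call site for the updated helper: `verbose=True` only adds logging, while the new `isinstance` guard from the earlier hunk makes a repeated pass a no-op for layers that were already swapped.

# assumes a model whose torch.nn.Linear layers should be swapped for the injected class
model = replace_linear(model, verbose=True)  # prints one line per replaced Linear
model = replace_linear(model, verbose=True)  # second call skips them via `continue`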