simple DDP wrapper (for my NVlink test)
This commit is contained in:
parent
783db3d2c5
commit
c494894261
|
@ -469,6 +469,7 @@ class Trainer:
|
||||||
|
|
||||||
weight_dtype: str = "float16"
|
weight_dtype: str = "float16"
|
||||||
amp: bool = False
|
amp: bool = False
|
||||||
|
ddp: bool = False
|
||||||
|
|
||||||
load_webui: bool = False
|
load_webui: bool = False
|
||||||
no_logger: bool = False
|
no_logger: bool = False
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from ..config import cfg
|
from ..config import cfg
|
||||||
|
|
||||||
from ..utils.distributed import fix_unset_envs
|
from ..utils.distributed import fix_unset_envs, ddp_model
|
||||||
fix_unset_envs()
|
fix_unset_envs()
|
||||||
|
|
||||||
if cfg.trainer.backend == "deepspeed":
|
if cfg.trainer.backend == "deepspeed":
|
||||||
|
@ -38,6 +38,7 @@ def load_engines(training=True):
|
||||||
dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
|
dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
|
||||||
amp = cfg.inference.amp if inferencing else cfg.trainer.amp
|
amp = cfg.inference.amp if inferencing else cfg.trainer.amp
|
||||||
loads_state_dict = cfg.trainer.load_state_dict or inferencing
|
loads_state_dict = cfg.trainer.load_state_dict or inferencing
|
||||||
|
ddp = cfg.trainer.ddp
|
||||||
|
|
||||||
engine_class = _Engine if backend == "local" or inferencing else Engine
|
engine_class = _Engine if backend == "local" or inferencing else Engine
|
||||||
|
|
||||||
|
@ -117,10 +118,14 @@ def load_engines(training=True):
|
||||||
|
|
||||||
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
|
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
|
||||||
|
|
||||||
|
_cfg = model._cfg
|
||||||
|
|
||||||
|
# wrap if DDP is requested
|
||||||
|
if ddp:
|
||||||
|
model = ddp_model(model)
|
||||||
|
|
||||||
# deepspeed inferencing
|
# deepspeed inferencing
|
||||||
if backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
|
elif backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
|
||||||
engine_class = _Engine
|
engine_class = _Engine
|
||||||
model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module
|
model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module
|
||||||
|
|
||||||
|
@ -130,10 +135,11 @@ def load_engines(training=True):
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
lr_scheduler=lr_scheduler,
|
lr_scheduler=lr_scheduler,
|
||||||
|
|
||||||
_cfg=model._cfg,
|
_cfg=_cfg,
|
||||||
stats=stats
|
stats=stats
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
engines = Engines(engines)
|
engines = Engines(engines)
|
||||||
engines.setup()
|
engines.setup()
|
||||||
|
|
||||||
|
|
|
@ -372,15 +372,6 @@ def example_usage():
|
||||||
|
|
||||||
'config': cfg.model
|
'config': cfg.model
|
||||||
}
|
}
|
||||||
"""
|
|
||||||
kwargs = {
|
|
||||||
'n_tokens': 1024,
|
|
||||||
'd_model': 256,
|
|
||||||
'n_heads': 4,
|
|
||||||
'n_layers': 12,
|
|
||||||
'n_experts': 8,
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
@ -390,7 +381,7 @@ def example_usage():
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model = AR_NAR(**kwargs).to(device)
|
model = AR_NAR(**kwargs).to(device)
|
||||||
steps = 500
|
steps = 100
|
||||||
optimizer = ml.Prodigy(model.parameters(), lr=1.0)
|
optimizer = ml.Prodigy(model.parameters(), lr=1.0)
|
||||||
#optimizer = ml.Adagrad(model.parameters(), lr=1.0e-2)
|
#optimizer = ml.Adagrad(model.parameters(), lr=1.0e-2)
|
||||||
#optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
|
#optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
|
||||||
|
|
|
@ -183,4 +183,5 @@ def train():
|
||||||
)
|
)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
# to-do: for DDP, spawn multiprocess instead of requiring `torchrun --nnodes=1 --nproc-per-node=4 -m vall_e.train yaml="./data/config.yaml"`
|
||||||
train()
|
train()
|
||||||
|
|
|
@ -8,6 +8,10 @@ import socket
|
||||||
from functools import cache, wraps
|
from functools import cache, wraps
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
def get_free_port():
|
def get_free_port():
|
||||||
sock = socket.socket()
|
sock = socket.socket()
|
||||||
sock.bind(("", 0))
|
sock.bind(("", 0))
|
||||||
|
@ -17,6 +21,7 @@ def get_free_port():
|
||||||
_distributed_initialized = False
|
_distributed_initialized = False
|
||||||
def init_distributed( fn, *args, **kwargs ):
|
def init_distributed( fn, *args, **kwargs ):
|
||||||
#print("Initializing distributed...")
|
#print("Initializing distributed...")
|
||||||
|
torch.cuda.set_device(local_rank())
|
||||||
fn(*args, **kwargs)
|
fn(*args, **kwargs)
|
||||||
_distributed_initialized = True
|
_distributed_initialized = True
|
||||||
|
|
||||||
|
@ -45,7 +50,6 @@ def fix_unset_envs():
|
||||||
def local_rank():
|
def local_rank():
|
||||||
return int(os.getenv("LOCAL_RANK", 0))
|
return int(os.getenv("LOCAL_RANK", 0))
|
||||||
|
|
||||||
|
|
||||||
def global_rank():
|
def global_rank():
|
||||||
return int(os.getenv("RANK", 0))
|
return int(os.getenv("RANK", 0))
|
||||||
|
|
||||||
|
@ -91,3 +95,6 @@ def global_leader_only(fn: Callable | None = None, *, default=None) -> Callable:
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
return wrapper(fn)
|
return wrapper(fn)
|
||||||
|
|
||||||
|
def ddp_model(model):
|
||||||
|
return DDP(model.to(device='cuda'), [local_rank()])
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
import math
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from ..config import cfg
|
from ..config import cfg
|
||||||
|
|
||||||
Embedding = torch.nn.Embedding
|
Embedding = torch.nn.Embedding
|
||||||
|
@ -99,12 +101,13 @@ if cfg.optimizations.injects and cfg.optimizations.bitsandbytes:
|
||||||
torch.optim.SGD = SGD
|
torch.optim.SGD = SGD
|
||||||
|
|
||||||
# disgusting kludge, but it works (just realized BitNet has its own replacement routine)
|
# disgusting kludge, but it works (just realized BitNet has its own replacement routine)
|
||||||
def replace_linear( model ):
|
def replace_linear( model, verbose=False ):
|
||||||
bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet
|
bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet
|
||||||
klass = Linear
|
|
||||||
|
|
||||||
device = next(model.parameters()).device
|
device = next(model.parameters()).device
|
||||||
linears = [k.split('.') for k, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
|
linears = [k.split('.') for k, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
|
||||||
|
klass = Linear
|
||||||
|
|
||||||
for *parent, k in linears:
|
for *parent, k in linears:
|
||||||
name = '.'.join(parent)
|
name = '.'.join(parent)
|
||||||
|
|
||||||
|
@ -112,6 +115,9 @@ def replace_linear( model ):
|
||||||
# copy parameters
|
# copy parameters
|
||||||
m = getattr( model.get_submodule(name), k )
|
m = getattr( model.get_submodule(name), k )
|
||||||
|
|
||||||
|
if isinstance(m, klass):
|
||||||
|
continue
|
||||||
|
|
||||||
in_features = m.in_features
|
in_features = m.in_features
|
||||||
out_features = m.out_features
|
out_features = m.out_features
|
||||||
bias = m.bias is not None
|
bias = m.bias is not None
|
||||||
|
@ -124,6 +130,9 @@ def replace_linear( model ):
|
||||||
klass( **kwargs ).to(device=device, dtype=cfg.trainer.dtype)
|
klass( **kwargs ).to(device=device, dtype=cfg.trainer.dtype)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print(f"Replacing {name}.{k} to", klass)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
# https://github.com/konstmish/prodigy
|
# https://github.com/konstmish/prodigy
|
||||||
|
|
Loading…
Reference in New Issue
Block a user