diff --git a/vall_e/config.py b/vall_e/config.py
index f7cfc7b..1324762 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -469,6 +469,7 @@ class Trainer:
 
 	weight_dtype: str = "float16"
 	amp: bool = False
+	ddp: bool = False
 
 	load_webui: bool = False
 	no_logger: bool = False
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 6ee4860..faf927a 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -1,6 +1,6 @@
 from ..config import cfg
 
-from ..utils.distributed import fix_unset_envs
+from ..utils.distributed import fix_unset_envs, ddp_model
 fix_unset_envs()
 
 if cfg.trainer.backend == "deepspeed":
@@ -38,6 +38,7 @@ def load_engines(training=True):
 		dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
 		amp = cfg.inference.amp if inferencing else cfg.trainer.amp
 		loads_state_dict = cfg.trainer.load_state_dict or inferencing
+		ddp = cfg.trainer.ddp
 
 		engine_class = _Engine if backend == "local" or inferencing else Engine
 
@@ -117,10 +118,14 @@ def load_engines(training=True):
 
 			model.load_state_dict(state, strict=cfg.trainer.strict_loading)
 
+		_cfg = model._cfg
+		# wrap if DDP is requested
+		if ddp:
+			model = ddp_model(model)
 		# deepspeed inferencing
-		if backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
+		elif backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
 			engine_class = _Engine
 			model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module
 
@@ -130,9 +135,10 @@ def load_engines(training=True):
 			optimizer=optimizer,
 			lr_scheduler=lr_scheduler,
 
-			_cfg=model._cfg,
+			_cfg=_cfg,
 			stats=stats
 		)
+
 
 	engines = Engines(engines)
 	engines.setup()
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 99695e9..3705b3c 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -372,15 +372,6 @@ def example_usage():
 		'config': cfg.model
 	}
 
-	"""
-	kwargs = {
-		'n_tokens': 1024,
-		'd_model': 256,
-		'n_heads': 4,
-		'n_layers': 12,
-		'n_experts': 8,
-	}
-	"""
 	"""
 	try:
@@ -390,7 +381,7 @@ def example_usage():
 	"""
 
 	model = AR_NAR(**kwargs).to(device)
-	steps = 500
+	steps = 100
 	optimizer = ml.Prodigy(model.parameters(), lr=1.0)
 	#optimizer = ml.Adagrad(model.parameters(), lr=1.0e-2)
 	#optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
diff --git a/vall_e/train.py b/vall_e/train.py
index 75b2274..70506ba 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -183,4 +183,5 @@ def train():
 	)
 
 if __name__ == "__main__":
+	# to-do: for DDP, spawn multiprocess instead of requiring `torchrun --nnodes=1 --nproc-per-node=4 -m vall_e.train yaml="./data/config.yaml"`
 	train()
diff --git a/vall_e/utils/distributed.py b/vall_e/utils/distributed.py
index 889de53..88e534a 100755
--- a/vall_e/utils/distributed.py
+++ b/vall_e/utils/distributed.py
@@ -8,6 +8,10 @@ import socket
 from functools import cache, wraps
 from typing import Callable
 
+import torch
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+
 def get_free_port():
 	sock = socket.socket()
 	sock.bind(("", 0))
@@ -17,6 +21,7 @@ def get_free_port():
 
 _distributed_initialized = False
 def init_distributed( fn, *args, **kwargs ):
 	#print("Initializing distributed...")
+	torch.cuda.set_device(local_rank())
 	fn(*args, **kwargs)
 	_distributed_initialized = True
@@ -45,7 +50,6 @@ def fix_unset_envs():
 
 def local_rank():
 	return int(os.getenv("LOCAL_RANK", 0))
-
 def global_rank():
 	return int(os.getenv("RANK", 0))
 
@@ -90,4 +94,7 @@ def global_leader_only(fn: Callable | None = None, *, default=None) -> Callable:
 	if fn is None:
 		return wrapper
 
-	return wrapper(fn)
\ No newline at end of file
+	return wrapper(fn)
+
+def ddp_model(model):
+	return DDP(model.to(device='cuda'), [local_rank()])
diff --git a/vall_e/utils/wrapper.py b/vall_e/utils/wrapper.py
index 9d8af47..76c409c 100755
--- a/vall_e/utils/wrapper.py
+++ b/vall_e/utils/wrapper.py
@@ -1,7 +1,9 @@
 from contextlib import contextmanager
 
+import math
 import torch
 import torch.nn.functional as F
+
 from ..config import cfg
 
 Embedding = torch.nn.Embedding
@@ -99,12 +101,13 @@ if cfg.optimizations.injects and cfg.optimizations.bitsandbytes:
 		torch.optim.SGD = SGD
 
 # disgusting kludge, but it works (just realized BitNet has its own replacement routine)
-def replace_linear( model ):
+def replace_linear( model, verbose=False ):
 	bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet
-	klass = Linear
 	device = next(model.parameters()).device
 
 	linears = [k.split('.') for k, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
+	klass = Linear
+
 	for *parent, k in linears:
 		name = '.'.join(parent)
 
@@ -112,6 +115,9 @@ def replace_linear( model ):
 		# copy parameters
 		m = getattr( model.get_submodule(name), k )
 
+		if isinstance(m, klass):
+			continue
+
 		in_features = m.in_features
 		out_features = m.out_features
 		bias = m.bias is not None
@@ -123,6 +129,9 @@ def replace_linear( model ):
 			model.get_submodule(name), k, klass( **kwargs ).to(device=device, dtype=cfg.trainer.dtype)
 		)
+
+		if verbose:
+			print(f"Replacing {name}.{k} with", klass)
 
 	return model
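
Note on the `_cfg` hoisting in load_engines(): DistributedDataParallel does not forward custom attributes of the wrapped module, so after `model = ddp_model(model)` the engine could no longer read `model._cfg` directly (it would have to go through `model.module`). Capturing `_cfg` before wrapping avoids that. A minimal sketch of the gotcha, using a hypothetical Toy module and a single-process gloo group purely so DDP can be constructed on CPU without torchrun:

import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# single-process group, only so DDP can be constructed for the demo
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

class Toy(torch.nn.Module):  # hypothetical stand-in for the VALL-E model
	def __init__(self):
		super().__init__()
		self.fc = torch.nn.Linear(4, 4)
		self._cfg = {"name": "toy"}  # custom attribute, like the model's _cfg

wrapped = DDP(Toy())
print(hasattr(wrapped, "_cfg"))  # False: DDP does not proxy custom attributes
print(wrapped.module._cfg)       # {'name': 'toy'}: reachable only via .module
dist.destroy_process_group()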
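
On the to-do in vall_e/train.py: a rough sketch of what self-spawned workers might look like instead of requiring torchrun. The worker name, port choice, and where train() gets invoked are all hypothetical, not part of this patch; the point is that each spawned process must provide the environment variables torchrun would normally set (which fix_unset_envs() and local_rank() rely on) before initializing the process group:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def train_worker(rank: int, world_size: int):
	# stand in for the env vars torchrun would have exported
	os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
	os.environ.setdefault("MASTER_PORT", "29500")
	os.environ["RANK"] = str(rank)
	os.environ["LOCAL_RANK"] = str(rank)
	os.environ["WORLD_SIZE"] = str(world_size)

	dist.init_process_group("nccl", rank=rank, world_size=world_size)
	torch.cuda.set_device(rank)

	# ... call the existing vall_e.train entry point here ...

	dist.destroy_process_group()

if __name__ == "__main__":
	world_size = torch.cuda.device_count()  # one process per GPU, single node
	mp.spawn(train_worker, args=(world_size,), nprocs=world_size)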