From d89568a96ec55f5255ea6d2cbc760207a3cdc7b2 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sat, 5 Aug 2023 03:22:15 +0000
Subject: [PATCH] some fixes for the local framework

---
 vall_e/config.py            |  8 ++++++++
 vall_e/engines/__init__.py  |  3 +++
 vall_e/engines/base.py      | 12 +++++++++---
 vall_e/engines/deepspeed.py | 11 +++++------
 vall_e/train.py             |  2 +-
 vall_e/utils/distributed.py |  9 +++++++++
 vall_e/utils/trainer.py     | 36 +++++++++++++++++++++++++++++++-----
 7 files changed, 66 insertions(+), 15 deletions(-)
 mode change 100644 => 100755 vall_e/engines/__init__.py
 mode change 100644 => 100755 vall_e/engines/deepspeed.py

diff --git a/vall_e/config.py b/vall_e/config.py
index fe92beb..9813164 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -375,6 +375,14 @@ class Trainer:
 
 	deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
 
+	@cached_property
+	def dtype(self):
+		if self.weight_dtype == "float16":
+			return torch.float16
+		if self.weight_dtype == "bfloat16":
+			return torch.bfloat16
+		return torch.float32
+
 @dataclass()
 class Inference:
 
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
old mode 100644
new mode 100755
index cf59d93..f0879ec
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -1,5 +1,8 @@
 from ..config import cfg
 
+from ..utils.distributed import fix_unset_envs
+fix_unset_envs()
+
 if cfg.trainer.backend == "deepspeed":
 	from .deepspeed import Engine
 elif cfg.trainer.backend == "local":
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 923b565..8b9dc04 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -28,6 +28,7 @@ def default_feeder(engine, batch):
 
 from ..config import cfg
 from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
+from ..utils.distributed import init_distributed, distributed_initialized
 
 import logging
 import time
@@ -43,10 +44,13 @@ from .base import TrainFeeder
 
 _logger = logging.getLogger(__name__)
 
+if not distributed_initialized() and cfg.trainer.backend == "local":
+	init_distributed(torch.distributed.init_process_group)
+
 # A very naive engine implementation using barebones PyTorch
 class Engine():
 	def __init__(self, *args, **kwargs):
-		self.module = kwargs['model'].to(cfg.device)
+		self.module = kwargs['model'].to(cfg.device).to(cfg.trainer.dtype)
 
 		self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
 		self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
@@ -93,6 +97,8 @@ class Engine():
 			"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
 		}, save_path)
 
+		open(save_dir / "latest", 'w').write( tag )
+
 	def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True):
 		if tag is None:
 			tag_path = load_dir / "latest"
@@ -105,8 +111,8 @@ class Engine():
 			return
 
 		state = torch.load(load_path)
-		self.global_step = state['global_step']
-		self.micro_step = state['micro_step']
+		self.global_steps = state['global_step']
+		self.micro_steps = state['micro_step']
 		self.module.load_state_dict(state['module'])
 
 		load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
diff --git a/vall_e/engines/deepspeed.py b/vall_e/engines/deepspeed.py
old mode 100644
new mode 100755
index 0bb8a7b..8458807
--- a/vall_e/engines/deepspeed.py
+++ b/vall_e/engines/deepspeed.py
@@ -21,14 +21,13 @@ from .base import TrainFeeder
 
 _logger = logging.getLogger(__name__)
 
-from deepspeed import DeepSpeedEngine, DeepSpeedConfig, comm as dist, init_distributed
+from deepspeed import DeepSpeedEngine, DeepSpeedConfig, comm as dist, init_distributed as init_deepspeed_dist
 from deepspeed.accelerator import get_accelerator
 
-#dist.init_distributed(dist_backend=get_accelerator().communication_backend_name())
-initialized_dist = False
-if not initialized_dist:
-	initialized_dist = True
-	init_distributed()
+from ..utils.distributed import init_distributed, distributed_initialized
+
+if not distributed_initialized() and cfg.trainer.backend == "deepspeed":
+	init_distributed(init_deepspeed_dist)
 
 class Engine(DeepSpeedEngine):
 	def __init__(self, *args, **kwargs):
diff --git a/vall_e/train.py b/vall_e/train.py
index 30a4c66..de6425a 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -134,7 +134,7 @@ def run_eval(engines, eval_name, dl):
 
 	iteration = engines.global_step
 	engines_stats['it'] = iteration
-	engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)
+	engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)
 
 	_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
 
diff --git a/vall_e/utils/distributed.py b/vall_e/utils/distributed.py
index 496e9a7..e80b0dd 100755
--- a/vall_e/utils/distributed.py
+++ b/vall_e/utils/distributed.py
@@ -14,6 +14,15 @@ def get_free_port():
 	return sock.getsockname()[1]
 
 
+_distributed_initialized = False
+def init_distributed( fn ):
+	global _distributed_initialized
+	fn()
+	_distributed_initialized = True
+
+def distributed_initialized():
+	return _distributed_initialized
+
 @cache
 def fix_unset_envs():
 	envs = dict(
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index c796025..f9d780d 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -18,8 +18,8 @@ from tqdm import tqdm
 from typing import Protocol
 
 from ..config import cfg
+from .distributed import init_distributed, distributed_initialized
 from .distributed import (
-	fix_unset_envs,
 	global_leader_only,
 	global_rank,
 	is_global_leader,
@@ -112,18 +112,46 @@ def _get_stdin_selector():
 	return selector
 
 
+if os.name == "nt":
+	import msvcrt
+	_buffer = []
+
 def _non_blocking_input():
 	global _command
+	global _buffer
 	l = [""]
-	if is_global_leader():
+
+	def _windows():
+		global _buffer
+
+		if msvcrt.kbhit():
+			s: str = msvcrt.getch().decode('utf-8')
+			if s == '\r':
+				s = "".join(_buffer)
+				_buffer = []
+				return s
+
+			_buffer.append(s)
+		return ""
+
+	def _linux():
 		s = ""
 		selector = _get_stdin_selector()
 		events = selector.select(timeout=0)
 		for key, _ in events:
 			s: str = key.fileobj.readline().strip()
+		return s
+
+	if is_global_leader():
+		s = _windows() if os.name == 'nt' else _linux()
+
+		if s != "":
 			_logger.info(f'Get stdin "{s}".')
+
 		l[0] = s
-	broadcast_object_list(l, src=0)
+
+	if distributed_initialized():
+		broadcast_object_list(l, src=0)
 	_command = l[0]
 	return _command
 
@@ -152,8 +180,6 @@ def train(
 	eval_fn: EvalFn = lambda x: ...,
 	logger: Logger = logger,
 ):
-	fix_unset_envs()
-
 	engines = load_engines()
 
 	"""
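
A note on the new Trainer.dtype property: it simply maps the weight_dtype string from the YAML config onto a torch.dtype, and @cached_property makes the lookup happen once per process; this is what lets the local engine cast its module with .to(cfg.trainer.dtype). A roughly equivalent standalone sketch (TrainerSketch is a made-up name for illustration, not part of the patch):

from functools import cached_property
import torch

class TrainerSketch:
	weight_dtype: str = "bfloat16"  # would normally come from the YAML config

	@cached_property
	def dtype(self):
		# map the config string onto a torch dtype, computed once per instance
		if self.weight_dtype == "float16":
			return torch.float16
		if self.weight_dtype == "bfloat16":
			return torch.bfloat16
		return torch.float32  # anything unrecognized falls back to fp32

print(TrainerSketch().dtype)  # torch.bfloat16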
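
On the init_distributed / distributed_initialized pair added in vall_e/utils/distributed.py: both engine backends now funnel process-group setup through a single module-level flag, so initialization runs at most once regardless of which backend module gets imported, and trainer.py can consult the same flag before broadcasting stdin commands. A minimal standalone sketch of the pattern (fake_init is a hypothetical stand-in for torch.distributed.init_process_group or deepspeed's init_distributed):

_distributed_initialized = False

def init_distributed( fn ):
	# the global declaration is load-bearing: without it the assignment
	# below would only bind a function-local and the flag would never flip
	global _distributed_initialized
	fn()
	_distributed_initialized = True

def distributed_initialized():
	return _distributed_initialized

def fake_init():
	print("process group initialized")

if not distributed_initialized():
	init_distributed(fake_init)  # runs, prints once

if not distributed_initialized():
	init_distributed(fake_init)  # skipped; the flag is already set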
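
The Windows branch of _non_blocking_input exists because the selectors-based polling used on Linux only works on sockets under Windows, so console stdin cannot be watched that way; instead msvcrt is polled for at most one keystroke per call, and a command is only surfaced once Enter arrives. A standalone sketch of that buffering behavior (poll_command is a hypothetical name for illustration; Windows-only by construction):

import os

if os.name == "nt":
	import msvcrt

	_buffer = []

	def poll_command() -> str:
		# return a complete line once Enter is hit, else ""; never blocks
		if msvcrt.kbhit():  # a keystroke is waiting
			ch = msvcrt.getch().decode('utf-8', errors='ignore')
			if ch == '\r':  # Enter: flush the accumulated keystrokes
				command = "".join(_buffer)
				_buffer.clear()
				return command
			_buffer.append(ch)  # otherwise keep accumulating
		return ""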