some fixes for the local framework
parent 5970f254e3
commit d89568a96e
@@ -375,6 +375,14 @@ class Trainer:
 	deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
 
+	@cached_property
+	def dtype(self):
+		if self.weight_dtype == "float16":
+			return torch.float16
+		if cfg.trainer.weight_dtype == "bfloat16":
+			return torch.bfloat16
+		return torch.float32
+
 
 
 @dataclass()
 class Inference:
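Note: the new cached dtype property is what the local engine casts the model to (see the engines change below). The bfloat16 branch reads cfg.trainer.weight_dtype rather than self.weight_dtype, which presumably resolves to the same Trainer instance. A self-contained sketch of the same string-to-dtype mapping, written as a dict lookup (a sketch, not the repo's code):

# Equivalent mapping expressed as a dict lookup (illustrative only).
import torch

_DTYPE_MAP = {"float16": torch.float16, "bfloat16": torch.bfloat16}

def resolve_dtype(weight_dtype: str) -> torch.dtype:
    # anything else (e.g. "float32" or an unset value) falls back to full precision
    return _DTYPE_MAP.get(weight_dtype, torch.float32)

assert resolve_dtype("bfloat16") is torch.bfloat16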
vall_e/engines/__init__.py (3 changes, Normal file → Executable file)
@@ -1,5 +1,8 @@
 from ..config import cfg
+
+from ..utils.distributed import fix_unset_envs
+fix_unset_envs()
 
 if cfg.trainer.backend == "deepspeed":
 	from .deepspeed import Engine
 elif cfg.trainer.backend == "local":
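Note: calling fix_unset_envs() at import time means the rendezvous environment variables exist before either backend tries to initialize a process group. Its body is truncated further down in this diff, so the exact variables are an assumption; a hedged, self-contained sketch of what such a helper typically does:

# Sketch only (the real body is not shown in this diff): default the
# rendezvous variables so torch.distributed.init_process_group() works
# for a plain single-process launch.
import os

def fix_unset_envs():
    defaults = {
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": "8888",   # the repo has a get_free_port() helper; this value is illustrative
        "RANK": "0",
        "LOCAL_RANK": "0",
        "WORLD_SIZE": "1",
    }
    for key, value in defaults.items():
        os.environ.setdefault(key, value)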
@@ -28,6 +28,7 @@ def default_feeder(engine, batch):
 
 from ..config import cfg
 from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
+from ..utils.distributed import init_distributed, distributed_initialized
 
 import logging
 import time
@@ -43,10 +44,13 @@ from .base import TrainFeeder
 
 _logger = logging.getLogger(__name__)
 
+if not distributed_initialized() and cfg.trainer.backend == "local":
+	init_distributed(torch.distributed.init_process_group)
+
 # A very naive engine implementation using barebones PyTorch
 class Engine():
 	def __init__(self, *args, **kwargs):
-		self.module = kwargs['model'].to(cfg.device)
+		self.module = kwargs['model'].to(cfg.device).to(cfg.trainer.dtype)
 		self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
 		self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
 
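Note: with the guard above, importing the local engine initializes torch.distributed exactly once, and the model is cast to the configured dtype up front. A minimal, self-contained sketch of that single-process initialization (the env defaults normally come from fix_unset_envs(); the values and the stand-in model here are illustrative):

import os
import torch
import torch.distributed as dist

# normally provided by fix_unset_envs(); illustrative values
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")

if not dist.is_initialized():
    dist.init_process_group(backend="gloo", rank=0, world_size=1)  # gloo works on CPU-only runs

model = torch.nn.Linear(4, 4).to("cpu").to(torch.float32)  # stand-in for kwargs['model']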
@@ -93,6 +97,8 @@ class Engine():
 			"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
 		}, save_path)
 
+		open(save_dir / "latest", 'w').write( tag )
+
 	def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True):
 		if tag is None:
 			tag_path = load_dir / "latest"
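Note: this mirrors the DeepSpeed convention of a "latest" marker file: save_checkpoint writes the tag, load_checkpoint reads it back when no tag is given. The committed line leaves the file handle to the garbage collector; an equivalent sketch using pathlib so the handle is closed deterministically (a style alternative, not the committed code):

from pathlib import Path

def write_latest(save_dir: Path, tag: str) -> None:
    # recorded on every save so a later load can resolve the newest checkpoint
    (save_dir / "latest").write_text(tag)

def read_latest(load_dir: Path) -> str:
    return (load_dir / "latest").read_text().strip()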
@@ -105,8 +111,8 @@ class Engine():
 			return
 
 		state = torch.load(load_path)
-		self.global_step = state['global_step']
-		self.micro_step = state['micro_step']
+		self.global_steps = state['global_step']
+		self.micro_steps = state['micro_step']
 		self.module.load_state_dict(state['module'])
 
 		load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
vall_e/engines/deepspeed.py (11 changes, Normal file → Executable file)
@@ -21,14 +21,13 @@ from .base import TrainFeeder
 
 _logger = logging.getLogger(__name__)
 
-from deepspeed import DeepSpeedEngine, DeepSpeedConfig, comm as dist, init_distributed
+from deepspeed import DeepSpeedEngine, DeepSpeedConfig, comm as dist, init_distributed as init_deepspeed_dist
 from deepspeed.accelerator import get_accelerator
 
-#dist.init_distributed(dist_backend=get_accelerator().communication_backend_name())
-initialized_dist = False
-if not initialized_dist:
-	initialized_dist = True
-	init_distributed()
+from ..utils.distributed import init_distributed, distributed_initialized
+
+if not distributed_initialized() and cfg.trainer.backend == "deepspeed":
+	init_distributed(init_deepspeed_dist)
 
 class Engine(DeepSpeedEngine):
 	def __init__(self, *args, **kwargs):
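Note: the removed block set initialized_dist = False and immediately tested it in the same module, so the guard always fired on import; the replacement routes both backends through the shared helper in utils/distributed, and DeepSpeed's own init is aliased to init_deepspeed_dist so it no longer shadows that helper. An illustration of why the old flag could not gate anything:

# Illustrative only: the flag is created as False right before the check,
# so the body always runs on the one and only import of the module.
initialized_dist = False
if not initialized_dist:      # always True here
    initialized_dist = True
    print("init_distributed() would run here")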
@@ -134,7 +134,7 @@ def run_eval(engines, eval_name, dl):
 
 	iteration = engines.global_step
 	engines_stats['it'] = iteration
-	engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)
+	engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)
 
 	_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
 
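Note: the fix ties the epoch estimate to the dataloader actually passed to run_eval rather than a hard-coded train_dl. A worked illustration of the formula with made-up numbers:

# Illustrative numbers only: epoch = optimizer steps * grad accumulation / batches per epoch.
iteration = 1000                    # engines.global_step
gradient_accumulation_steps = 4
batches_per_epoch = 2000            # len(dl)
epoch = iteration * gradient_accumulation_steps / batches_per_epoch
assert epoch == 2.0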
@@ -14,6 +14,14 @@ def get_free_port():
 	return sock.getsockname()[1]
 
 
+_distributed_initialized = False
+def init_distributed( fn ):
+	fn()
+	_distributed_initialized = True
+
+def distributed_initialized():
+	return _distributed_initialized
+
 @cache
 def fix_unset_envs():
 	envs = dict(
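Note: as committed, the assignment inside init_distributed creates a local variable, so the module-level _distributed_initialized flag never flips and distributed_initialized() keeps returning False. A minimal sketch with the flag actually updated via global (an assumption about the intent, not part of this commit):

# Sketch of the one-shot guard with the module-level flag really updated.
_distributed_initialized = False

def init_distributed(fn) -> None:
    global _distributed_initialized   # without this, the assignment below is local
    fn()
    _distributed_initialized = True

def distributed_initialized() -> bool:
    return _distributed_initialized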
@@ -18,8 +18,8 @@ from tqdm import tqdm
 from typing import Protocol
 
 from ..config import cfg
+from .distributed import init_distributed, distributed_initialized
 from .distributed import (
 	fix_unset_envs,
 	global_leader_only,
 	global_rank,
 	is_global_leader,
@@ -112,18 +112,46 @@ def _get_stdin_selector():
 	return selector
 
 
+
+if os.name == "nt":
+	import msvcrt
+	_buffer = []
+
 def _non_blocking_input():
 	global _command
+	global _buffer
 	l = [""]
-	if is_global_leader():
+
+	def _windows():
+		global _buffer
+
+		if msvcrt.kbhit():
+			s: str = msvcrt.getch().decode('utf-8')
+			if s == '\r':
+				s = "".join(_buffer)
+				_buffer = []
+				return s
+
+			_buffer.append(s)
+		return ""
+
+	def _linux():
 		s = ""
 		selector = _get_stdin_selector()
 		events = selector.select(timeout=0)
 		for key, _ in events:
 			s: str = key.fileobj.readline().strip()
+		return s
+
+	if is_global_leader():
+		s = _windows() if os.name == 'nt' else _linux()
+
 		if s != "":
 			_logger.info(f'Get stdin "{s}".')
+
 		l[0] = s
-	broadcast_object_list(l, src=0)
+
+	if distributed_initialized():
+		broadcast_object_list(l, src=0)
 	_command = l[0]
 	return _command
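Note: the rework adds a Windows polling path via msvcrt and only broadcasts the read line when a process group actually exists, which is what lets a purely local single-process run poll stdin without torch.distributed. A small, self-contained sketch of the same non-blocking pattern on POSIX (names are illustrative, not the repo's):

import selectors
import sys

selector = selectors.DefaultSelector()
selector.register(sys.stdin, selectors.EVENT_READ)

def poll_command() -> str:
    # returns "" immediately when no full line is waiting on stdin
    command = ""
    for key, _ in selector.select(timeout=0):
        command = key.fileobj.readline().strip()
    return command

# In a training loop one might check: if poll_command() == "quit": stop training.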
@@ -152,8 +180,6 @@ def train(
 	eval_fn: EvalFn = lambda x: ...,
 	logger: Logger = logger,
 ):
-	fix_unset_envs()
-
 	engines = load_engines()
 
 	"""