some fixes for the local framework

mrq committed on 2023-08-05 03:22:15 +00:00
parent 5970f254e3, commit d89568a96e
7 changed files with 65 additions and 15 deletions


@@ -375,6 +375,14 @@ class Trainer:
 	deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
+	@cached_property
+	def dtype(self):
+		if self.weight_dtype == "float16":
+			return torch.float16
+		if cfg.trainer.weight_dtype == "bfloat16":
+			return torch.bfloat16
+		return torch.float32
 @dataclass()
 class Inference:
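For reference, here is a minimal, self-contained sketch of the string-to-dtype mapping this hunk adds; TrainerSketch is a hypothetical stand-in for the real Trainer dataclass. The added property mixes self.weight_dtype and cfg.trainer.weight_dtype, which resolve to the same value when the property is reached through cfg.trainer, so the sketch normalizes both to self.

# Hypothetical, minimal stand-in for the Trainer config; the real dataclass
# in vall_e/config.py has many more fields.
from dataclasses import dataclass
from functools import cached_property

import torch

@dataclass
class TrainerSketch:
    weight_dtype: str = "float16"

    @cached_property
    def dtype(self) -> torch.dtype:
        # Unrecognized strings fall through to full precision.
        if self.weight_dtype == "float16":
            return torch.float16
        if self.weight_dtype == "bfloat16":
            return torch.bfloat16
        return torch.float32

trainer = TrainerSketch(weight_dtype="bfloat16")
assert trainer.dtype is torch.bfloat16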

vall_e/engines/__init__.py (3 changed lines; file mode changed from normal to executable)

@@ -1,5 +1,8 @@
 from ..config import cfg
+from ..utils.distributed import fix_unset_envs
+fix_unset_envs()
 if cfg.trainer.backend == "deepspeed":
 	from .deepspeed import Engine
 elif cfg.trainer.backend == "local":
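As a sketch of where the two added lines sit, the whole of vall_e/engines/__init__.py plausibly reads as below; the "local" branch body and the error fallback are assumptions about the part of the file this hunk does not show.

from ..config import cfg
from ..utils.distributed import fix_unset_envs

# Fill in any missing MASTER_ADDR/MASTER_PORT/RANK-style environment defaults
# before either backend import touches torch.distributed or deepspeed.
fix_unset_envs()

if cfg.trainer.backend == "deepspeed":
    from .deepspeed import Engine
elif cfg.trainer.backend == "local":
    from .base import Engine  # the barebones PyTorch Engine patched below
else:
    raise ValueError(f"Unsupported trainer backend: {cfg.trainer.backend}")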


@@ -28,6 +28,7 @@ def default_feeder(engine, batch):
 from ..config import cfg
 from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
+from ..utils.distributed import init_distributed, distributed_initialized
 import logging
 import time
@@ -43,10 +44,13 @@ from .base import TrainFeeder
 _logger = logging.getLogger(__name__)
+if not distributed_initialized() and cfg.trainer.backend == "local":
+    init_distributed(torch.distributed.init_process_group)
 # A very naive engine implementation using barebones PyTorch
 class Engine():
     def __init__(self, *args, **kwargs):
-        self.module = kwargs['model'].to(cfg.device)
+        self.module = kwargs['model'].to(cfg.device).to(cfg.trainer.dtype)
         self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
         self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
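The one-line change above casts the model to the configured precision right after moving it to the target device. Below is a standalone sketch of the same pattern, with device, weight_dtype and the toy model as hypothetical stand-ins for cfg.device, cfg.trainer.dtype and kwargs['model']:

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only makes sense on the GPU in this toy example.
weight_dtype = torch.float16 if device == "cuda" else torch.float32

model = nn.Linear(16, 4)
# Equivalent to kwargs['model'].to(cfg.device).to(cfg.trainer.dtype):
# first place the parameters on the device, then cast them.
model = model.to(device).to(weight_dtype)
assert next(model.parameters()).dtype is weight_dtype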
@@ -93,6 +97,8 @@ class Engine():
             "lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
         }, save_path)
+        open(save_dir / "latest", 'w').write( tag )
     def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True):
         if tag is None:
             tag_path = load_dir / "latest"
@@ -105,8 +111,8 @@ class Engine():
             return
         state = torch.load(load_path)
-        self.global_step = state['global_step']
-        self.micro_step = state['micro_step']
+        self.global_steps = state['global_step']
+        self.micro_steps = state['micro_step']
         self.module.load_state_dict(state['module'])
         load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
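The load_checkpoint fix restores the counters into the plural attributes (global_steps, micro_steps) while the keys stored in the checkpoint stay singular. A hedged round-trip sketch of that mapping follows; TinyEngine, the path and the toy module are placeholders, not the real Engine:

import os
import tempfile

import torch
import torch.nn as nn

class TinyEngine:
    def __init__(self, module: nn.Module):
        self.module = module
        self.global_steps = 0
        self.micro_steps = 0

    def save_checkpoint(self, path):
        # On-disk keys stay singular, matching the save_checkpoint hunk above.
        torch.save({
            "global_step": self.global_steps,
            "micro_step": self.micro_steps,
            "module": self.module.state_dict(),
        }, path)

    def load_checkpoint(self, path):
        state = torch.load(path)
        self.global_steps = state["global_step"]
        self.micro_steps = state["micro_step"]
        self.module.load_state_dict(state["module"])

path = os.path.join(tempfile.gettempdir(), "tiny_engine_ckpt.pth")
engine = TinyEngine(nn.Linear(4, 4))
engine.global_steps, engine.micro_steps = 100, 400
engine.save_checkpoint(path)

restored = TinyEngine(nn.Linear(4, 4))
restored.load_checkpoint(path)
assert (restored.global_steps, restored.micro_steps) == (100, 400)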

vall_e/engines/deepspeed.py (11 changed lines; file mode changed from normal to executable)

@@ -21,14 +21,13 @@ from .base import TrainFeeder
 _logger = logging.getLogger(__name__)
-from deepspeed import DeepSpeedEngine, DeepSpeedConfig, comm as dist, init_distributed
+from deepspeed import DeepSpeedEngine, DeepSpeedConfig, comm as dist, init_distributed as init_deepspeed_dist
 from deepspeed.accelerator import get_accelerator
 #dist.init_distributed(dist_backend=get_accelerator().communication_backend_name())
-initialized_dist = False
-if not initialized_dist:
-    initialized_dist = True
-    init_distributed()
+from ..utils.distributed import init_distributed, distributed_initialized
+if not distributed_initialized() and cfg.trainer.backend == "deepspeed":
+    init_distributed(init_deepspeed_dist)
 class Engine(DeepSpeedEngine):
     def __init__(self, *args, **kwargs):
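deepspeed exports its own init_distributed, which would shadow the project's wrapper of the same name; the hunk aliases it on import and hands it to that wrapper instead of calling it directly. A short sketch of the resulting call site, assuming deepspeed is installed and the package is importable as vall_e:

# deepspeed's initializer is aliased so it cannot shadow the project's wrapper.
from deepspeed import init_distributed as init_deepspeed_dist
from vall_e.utils.distributed import distributed_initialized, init_distributed

if not distributed_initialized():
    # The wrapper runs the backend-specific initializer once and records it,
    # so the local and deepspeed import paths cannot both initialize.
    init_distributed(init_deepspeed_dist)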


@@ -134,7 +134,7 @@ def run_eval(engines, eval_name, dl):
     iteration = engines.global_step
     engines_stats['it'] = iteration
-    engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)
+    engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)
     _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
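The corrected line measures the epoch estimate against the dataloader actually handed to run_eval (dl) rather than the training loader. A tiny worked example of the arithmetic, with made-up numbers:

iteration = 500                      # engines.global_step
gradient_accumulation_steps = 4      # cfg.hyperparameters.gradient_accumulation_steps
dl_length = 250                      # len(dl), batches per pass over this loader

# Each global step consumes gradient_accumulation_steps batches, so:
epoch = iteration * gradient_accumulation_steps / dl_length
print(epoch)                         # 8.0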


@@ -14,6 +14,14 @@ def get_free_port():
     return sock.getsockname()[1]
+_distributed_initialized = False
+def init_distributed( fn ):
+    fn()
+    _distributed_initialized = True
+def distributed_initialized():
+    return _distributed_initialized
 @cache
 def fix_unset_envs():
     envs = dict(
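One caveat with the helper as added: the assignment inside init_distributed has no global statement, so it binds a function-local name and the module-level flag never becomes True, which means distributed_initialized() keeps returning False. Below is a corrected sketch under that reading; the environment defaults are illustrative assumptions, since the hunk truncates at "envs = dict(".

import os
from functools import cache

_distributed_initialized = False

def init_distributed(fn) -> None:
    # Without this global statement the assignment below would only create a
    # local variable and the guard would never trip.
    global _distributed_initialized
    fn()
    _distributed_initialized = True

def distributed_initialized() -> bool:
    return _distributed_initialized

@cache
def fix_unset_envs() -> None:
    # Hypothetical single-process defaults; the real key/value list is not
    # shown in this hunk.
    defaults = dict(
        MASTER_ADDR="localhost",
        MASTER_PORT="8888",
        RANK="0",
        LOCAL_RANK="0",
        WORLD_SIZE="1",
    )
    for key, value in defaults.items():
        os.environ.setdefault(key, value)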


@@ -18,8 +18,8 @@ from tqdm import tqdm
 from typing import Protocol
 from ..config import cfg
+from .distributed import init_distributed, distributed_initialized
 from .distributed import (
-    fix_unset_envs,
     global_leader_only,
     global_rank,
     is_global_leader,
@@ -112,18 +112,46 @@ def _get_stdin_selector():
     return selector
+if os.name == "nt":
+    import msvcrt
+    _buffer = []
 def _non_blocking_input():
     global _command
+    global _buffer
     l = [""]
-    if is_global_leader():
+    def _windows():
+        global _buffer
+        if msvcrt.kbhit():
+            s: str = msvcrt.getch().decode('utf-8')
+            if s == '\r':
+                s = "".join(_buffer)
+                _buffer = []
+                return s
+            _buffer.append(s)
+        return ""
+    def _linux():
         s = ""
         selector = _get_stdin_selector()
         events = selector.select(timeout=0)
         for key, _ in events:
             s: str = key.fileobj.readline().strip()
+        return s
+    if is_global_leader():
+        s = _windows() if os.name == 'nt' else _linux()
         if s != "":
             _logger.info(f'Get stdin "{s}".')
         l[0] = s
-    broadcast_object_list(l, src=0)
+    if distributed_initialized():
+        broadcast_object_list(l, src=0)
     _command = l[0]
     return _command
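The reworked _non_blocking_input has the global leader poll stdin (msvcrt on Windows, a selector elsewhere) and only broadcasts the result when a process group actually exists, so a single-process local run no longer trips over broadcast_object_list. A self-contained sketch of that guard, using torch.distributed directly instead of the distributed_initialized() helper; read_command is a hypothetical callable standing in for the platform-specific poll:

import torch.distributed as dist

def share_command(read_command) -> str:
    payload = [""]
    # Rank 0 is the leader; with no process group there is only one process.
    is_leader = (not dist.is_initialized()) or dist.get_rank() == 0
    if is_leader:
        payload[0] = read_command()
    if dist.is_initialized():
        # Skipped for single-process runs, where the broadcast would raise.
        dist.broadcast_object_list(payload, src=0)
    return payload[0]

# Single-process usage: no process group, so the leader's value is returned directly.
print(share_command(lambda: "save"))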
@@ -152,8 +180,6 @@ def train(
     eval_fn: EvalFn = lambda x: ...,
     logger: Logger = logger,
 ):
-    fix_unset_envs()
     engines = load_engines()
     """