some fixes for the local framework

commit d89568a96e
parent 5970f254e3
@@ -375,6 +375,14 @@ class Trainer:
 	deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
 
+	@cached_property
+	def dtype(self):
+		if self.weight_dtype == "float16":
+			return torch.float16
+		if cfg.trainer.weight_dtype == "bfloat16":
+			return torch.bfloat16
+		return torch.float32
+
 
 @dataclass()
 class Inference:
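The property added above maps the configured weight_dtype string to a torch dtype; the local engine hunk further down consumes it as cfg.trainer.dtype. A minimal standalone sketch of that mapping, with an illustrative function name (the real code is a cached_property on the Trainer dataclass):

import torch

def resolve_dtype(weight_dtype: str) -> torch.dtype:
	# "float16" and "bfloat16" are the two reduced-precision options in the hunk;
	# anything else falls back to full precision.
	if weight_dtype == "float16":
		return torch.float16
	if weight_dtype == "bfloat16":
		return torch.bfloat16
	return torch.float32

print(resolve_dtype("bfloat16"))  # torch.bfloat16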
vall_e/engines/__init__.py | 3  (Normal file → Executable file)
@@ -1,5 +1,8 @@
 from ..config import cfg
 
+from ..utils.distributed import fix_unset_envs
+fix_unset_envs()
+
 if cfg.trainer.backend == "deepspeed":
 	from .deepspeed import Engine
 elif cfg.trainer.backend == "local":
@@ -28,6 +28,7 @@ def default_feeder(engine, batch):
 
 from ..config import cfg
 from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
+from ..utils.distributed import init_distributed, distributed_initialized
 
 import logging
 import time
@@ -43,10 +44,13 @@ from .base import TrainFeeder
 
 _logger = logging.getLogger(__name__)
 
+if not distributed_initialized() and cfg.trainer.backend == "local":
+	init_distributed(torch.distributed.init_process_group)
+
 # A very naive engine implementation using barebones PyTorch
 class Engine():
 	def __init__(self, *args, **kwargs):
-		self.module = kwargs['model'].to(cfg.device)
+		self.module = kwargs['model'].to(cfg.device).to(cfg.trainer.dtype)
 		self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
 		self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
 
@@ -93,6 +97,8 @@ class Engine():
 			"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
 		}, save_path)
 
+		open(save_dir / "latest", 'w').write( tag )
+
 	def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True):
 		if tag is None:
 			tag_path = load_dir / "latest"
@@ -105,8 +111,8 @@ class Engine():
 			return
 
 		state = torch.load(load_path)
-		self.global_step = state['global_step']
-		self.micro_step = state['micro_step']
+		self.global_steps = state['global_step']
+		self.micro_steps = state['micro_step']
 		self.module.load_state_dict(state['module'])
 
 		load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
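Read together, the two hunks above describe the local engine's checkpoint format: a single torch.save() dict plus a "latest" text file naming the most recent tag. A round-trip sketch under that assumption (file names here are illustrative, not the repository's):

from pathlib import Path
import torch
import torch.nn as nn

save_dir = Path("ckpt")
save_dir.mkdir(exist_ok=True)
model = nn.Linear(4, 4)
tag = "step_100"

# save side: one dict with the keys the loader reads back, plus the tag file
torch.save({
	"global_step": 100,
	"micro_step": 400,
	"module": model.state_dict(),
	"optimizer": None,      # optimizer.state_dict() when one is attached
	"lr_scheduler": None,   # lr_scheduler.state_dict() when one is attached
}, save_dir / f"{tag}.pth")
open(save_dir / "latest", "w").write(tag)

# load side: resolve the tag from "latest", then restore the module weights
tag = open(save_dir / "latest").read()
state = torch.load(save_dir / f"{tag}.pth")
model.load_state_dict(state["module"])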
vall_e/engines/deepspeed.py | 11  (Normal file → Executable file)
@@ -21,14 +21,13 @@ from .base import TrainFeeder
 
 _logger = logging.getLogger(__name__)
 
-from deepspeed import DeepSpeedEngine, DeepSpeedConfig, comm as dist, init_distributed
+from deepspeed import DeepSpeedEngine, DeepSpeedConfig, comm as dist, init_distributed as init_deepspeed_dist
 from deepspeed.accelerator import get_accelerator
 
-#dist.init_distributed(dist_backend=get_accelerator().communication_backend_name())
-initialized_dist = False
-if not initialized_dist:
-	initialized_dist = True
-	init_distributed()
+from ..utils.distributed import init_distributed, distributed_initialized
+
+if not distributed_initialized() and cfg.trainer.backend == "deepspeed":
+	init_distributed(init_deepspeed_dist)
 
 class Engine(DeepSpeedEngine):
 	def __init__(self, *args, **kwargs):
@@ -134,7 +134,7 @@ def run_eval(engines, eval_name, dl):
 
 	iteration = engines.global_step
 	engines_stats['it'] = iteration
-	engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)
+	engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)
 
 	_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
 
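The len(train_dl) → len(dl) change makes the epoch estimate use the dataloader actually being iterated. A quick worked example of the formula, with illustrative numbers:

# epoch ≈ optimizer steps × gradient accumulation ÷ batches per pass over the data
iteration = 1000                    # engines.global_step (illustrative)
gradient_accumulation_steps = 4
batches = 500                       # len(dl) for the dataloader in question
print(iteration * gradient_accumulation_steps / batches)  # 8.0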
@@ -14,6 +14,14 @@ def get_free_port():
 	return sock.getsockname()[1]
 
 
+_distributed_initialized = False
+def init_distributed( fn ):
+	fn()
+	_distributed_initialized = True
+
+def distributed_initialized():
+	return _distributed_initialized
+
 @cache
 def fix_unset_envs():
 	envs = dict(
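Both backends now go through this shared guard instead of keeping their own flag. A usage sketch of the intended pattern; note that, as committed, the assignment inside init_distributed binds a local name, so a global declaration (shown below) would be needed for distributed_initialized() to actually flip:

_distributed_initialized = False

def init_distributed(fn):
	global _distributed_initialized  # required for the module-level flag to change
	fn()
	_distributed_initialized = True

def distributed_initialized():
	return _distributed_initialized

# call sites, as in the hunks above:
#   local backend:     init_distributed(torch.distributed.init_process_group)
#   deepspeed backend: init_distributed(init_deepspeed_dist)
init_distributed(lambda: print("backend-specific init goes here"))
assert distributed_initialized()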
@@ -18,8 +18,8 @@ from tqdm import tqdm
 from typing import Protocol
 
 from ..config import cfg
+from .distributed import init_distributed, distributed_initialized
 from .distributed import (
-	fix_unset_envs,
 	global_leader_only,
 	global_rank,
 	is_global_leader,
@@ -112,18 +112,46 @@ def _get_stdin_selector():
 	return selector
 
 
+if os.name == "nt":
+	import msvcrt
+	_buffer = []
+
 def _non_blocking_input():
 	global _command
+	global _buffer
 	l = [""]
-	if is_global_leader():
+
+	def _windows():
+		global _buffer
+
+		if msvcrt.kbhit():
+			s: str = msvcrt.getch().decode('utf-8')
+			if s == '\r':
+				s = "".join(_buffer)
+				_buffer = []
+				return s
+
+			_buffer.append(s)
+		return ""
+
+	def _linux():
 		s = ""
 		selector = _get_stdin_selector()
 		events = selector.select(timeout=0)
 		for key, _ in events:
 			s: str = key.fileobj.readline().strip()
-		_logger.info(f'Get stdin "{s}".')
+		return s
+
+	if is_global_leader():
+		s = _windows() if os.name == 'nt' else _linux()
+
+		if s != "":
+			_logger.info(f'Get stdin "{s}".')
+
 		l[0] = s
-		broadcast_object_list(l, src=0)
+
+	if distributed_initialized():
+		broadcast_object_list(l, src=0)
 	_command = l[0]
 	return _command
 
@@ -152,8 +180,6 @@ def train(
 	eval_fn: EvalFn = lambda x: ...,
 	logger: Logger = logger,
 ):
-	fix_unset_envs()
-
 	engines = load_engines()
 
 	"""