fixes with the local backend
This commit is contained in:
parent
00ad4af651
commit
0517d620b8
|
@ -501,7 +501,7 @@ def _create_dataloader(dataset, training):
|
|||
drop_last=training,
|
||||
num_workers=cfg.dataset.workers,
|
||||
collate_fn=collate_fn,
|
||||
persistent_workers=True,
|
||||
persistent_workers=cfg.dataset.workers > 1,
|
||||
pin_memory=False, # True,
|
||||
worker_init_fn=_seed_worker,
|
||||
sampler=sampler,
|
||||
|
|
|
@ -28,7 +28,7 @@ def default_feeder(engine, batch):
|
|||
|
||||
from ..config import cfg
|
||||
from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
|
||||
from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader
|
||||
from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
@ -45,7 +45,7 @@ from .base import TrainFeeder
|
|||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
if not distributed_initialized() and cfg.trainer.backend == "local":
|
||||
if not distributed_initialized() and cfg.trainer.backend == "local" and world_size() > 1:
|
||||
init_distributed(torch.distributed.init_process_group)
|
||||
|
||||
# A very naive engine implementation using barebones PyTorch
|
||||
|
@ -104,7 +104,7 @@ class Engine():
|
|||
|
||||
open(save_dir / "latest", 'w').write( tag )
|
||||
|
||||
def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True):
|
||||
def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True, load_module_only=False):
|
||||
if tag is None:
|
||||
tag_path = load_dir / "latest"
|
||||
if not tag_path.exists():
|
||||
|
@ -365,6 +365,7 @@ class Engines(dict[str, Engine]):
|
|||
do_gc()
|
||||
continue
|
||||
|
||||
if world_size() > 1:
|
||||
all_reduce(n_ooms)
|
||||
if n_ooms.item() > 0:
|
||||
self.save_checkpoint()
|
||||
|
@ -395,6 +396,7 @@ class Engines(dict[str, Engine]):
|
|||
|
||||
n_ooms += 1
|
||||
|
||||
if world_size() > 1:
|
||||
all_reduce(n_ooms)
|
||||
if n_ooms.item() > 0:
|
||||
self.save_checkpoint()
|
||||
|
|
|
@ -19,7 +19,7 @@ from tqdm import tqdm
|
|||
from typing import Protocol
|
||||
|
||||
from ..config import cfg
|
||||
from .distributed import init_distributed, distributed_initialized
|
||||
from .distributed import init_distributed, distributed_initialized, world_size
|
||||
from .distributed import (
|
||||
global_leader_only,
|
||||
global_rank,
|
||||
|
@ -73,7 +73,7 @@ def load_engines():
|
|||
# yuck, should instead check be optimier == "adamw" and backend != "deepspeed"
|
||||
# and then have ds_cfg pass in the config flag to use pytorch adamw
|
||||
# I genuinely cannot validate if this ever actually gets used in DeepSpeed
|
||||
if cfg.hyperparameters.optimizer.lower() == "adamw-torch":
|
||||
if (cfg.trainer.backend == "local" and cfg.hyperparameters.optimizer.lower() == "adamw") or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.optimizer.lower() == "adamw-torch"):
|
||||
optimizer = ml.AdamW(
|
||||
model.parameters(),
|
||||
lr=cfg.hyperparameters.learning_rate,
|
||||
|
@ -187,6 +187,7 @@ def _non_blocking_input():
|
|||
|
||||
l[0] = s
|
||||
|
||||
if world_size() > 1:
|
||||
broadcast_object_list(l, src=0)
|
||||
_command = l[0]
|
||||
return _command
|
||||
|
|
Loading…
Reference in New Issue
Block a user