fixes with the local backend

This commit is contained in:
mrq 2023-08-24 17:05:56 -05:00
parent 00ad4af651
commit 0517d620b8
3 changed files with 12 additions and 9 deletions

View File

@ -501,7 +501,7 @@ def _create_dataloader(dataset, training):
drop_last=training,
num_workers=cfg.dataset.workers,
collate_fn=collate_fn,
persistent_workers=True,
persistent_workers=cfg.dataset.workers > 1,
pin_memory=False, # True,
worker_init_fn=_seed_worker,
sampler=sampler,

View File

@ -28,7 +28,7 @@ def default_feeder(engine, batch):
from ..config import cfg
from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader
from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size
import logging
import time
@ -45,7 +45,7 @@ from .base import TrainFeeder
_logger = logging.getLogger(__name__)
if not distributed_initialized() and cfg.trainer.backend == "local":
if not distributed_initialized() and cfg.trainer.backend == "local" and world_size() > 1:
init_distributed(torch.distributed.init_process_group)
# A very naive engine implementation using barebones PyTorch
@ -104,7 +104,7 @@ class Engine():
open(save_dir / "latest", 'w').write( tag )
def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True):
def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True, load_module_only=False):
if tag is None:
tag_path = load_dir / "latest"
if not tag_path.exists():
@ -365,6 +365,7 @@ class Engines(dict[str, Engine]):
do_gc()
continue
if world_size() > 1:
all_reduce(n_ooms)
if n_ooms.item() > 0:
self.save_checkpoint()
@ -395,6 +396,7 @@ class Engines(dict[str, Engine]):
n_ooms += 1
if world_size() > 1:
all_reduce(n_ooms)
if n_ooms.item() > 0:
self.save_checkpoint()

View File

@ -19,7 +19,7 @@ from tqdm import tqdm
from typing import Protocol
from ..config import cfg
from .distributed import init_distributed, distributed_initialized
from .distributed import init_distributed, distributed_initialized, world_size
from .distributed import (
global_leader_only,
global_rank,
@ -73,7 +73,7 @@ def load_engines():
# yuck, should instead check be optimier == "adamw" and backend != "deepspeed"
# and then have ds_cfg pass in the config flag to use pytorch adamw
# I genuinely cannot validate if this ever actually gets used in DeepSpeed
if cfg.hyperparameters.optimizer.lower() == "adamw-torch":
if (cfg.trainer.backend == "local" and cfg.hyperparameters.optimizer.lower() == "adamw") or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.optimizer.lower() == "adamw-torch"):
optimizer = ml.AdamW(
model.parameters(),
lr=cfg.hyperparameters.learning_rate,
@ -187,6 +187,7 @@ def _non_blocking_input():
l[0] = s
if world_size() > 1:
broadcast_object_list(l, src=0)
_command = l[0]
return _command