fixes with the local backend

This commit is contained in:
mrq 2023-08-24 17:05:56 -05:00
parent 00ad4af651
commit 0517d620b8
3 changed files with 12 additions and 9 deletions

View File

@ -501,7 +501,7 @@ def _create_dataloader(dataset, training):
drop_last=training, drop_last=training,
num_workers=cfg.dataset.workers, num_workers=cfg.dataset.workers,
collate_fn=collate_fn, collate_fn=collate_fn,
persistent_workers=True, persistent_workers=cfg.dataset.workers > 1,
pin_memory=False, # True, pin_memory=False, # True,
worker_init_fn=_seed_worker, worker_init_fn=_seed_worker,
sampler=sampler, sampler=sampler,

View File

@ -28,7 +28,7 @@ def default_feeder(engine, batch):
from ..config import cfg from ..config import cfg
from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size
import logging import logging
import time import time
@ -45,7 +45,7 @@ from .base import TrainFeeder
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
if not distributed_initialized() and cfg.trainer.backend == "local": if not distributed_initialized() and cfg.trainer.backend == "local" and world_size() > 1:
init_distributed(torch.distributed.init_process_group) init_distributed(torch.distributed.init_process_group)
# A very naive engine implementation using barebones PyTorch # A very naive engine implementation using barebones PyTorch
@ -104,7 +104,7 @@ class Engine():
open(save_dir / "latest", 'w').write( tag ) open(save_dir / "latest", 'w').write( tag )
def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True): def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True, load_module_only=False):
if tag is None: if tag is None:
tag_path = load_dir / "latest" tag_path = load_dir / "latest"
if not tag_path.exists(): if not tag_path.exists():
@ -365,7 +365,8 @@ class Engines(dict[str, Engine]):
do_gc() do_gc()
continue continue
all_reduce(n_ooms) if world_size() > 1:
all_reduce(n_ooms)
if n_ooms.item() > 0: if n_ooms.item() > 0:
self.save_checkpoint() self.save_checkpoint()
raise RuntimeError("Out of memory during forward pass!") raise RuntimeError("Out of memory during forward pass!")
@ -395,7 +396,8 @@ class Engines(dict[str, Engine]):
n_ooms += 1 n_ooms += 1
all_reduce(n_ooms) if world_size() > 1:
all_reduce(n_ooms)
if n_ooms.item() > 0: if n_ooms.item() > 0:
self.save_checkpoint() self.save_checkpoint()
raise RuntimeError("Out of memory during backwards pass!") raise RuntimeError("Out of memory during backwards pass!")

View File

@ -19,7 +19,7 @@ from tqdm import tqdm
from typing import Protocol from typing import Protocol
from ..config import cfg from ..config import cfg
from .distributed import init_distributed, distributed_initialized from .distributed import init_distributed, distributed_initialized, world_size
from .distributed import ( from .distributed import (
global_leader_only, global_leader_only,
global_rank, global_rank,
@ -73,7 +73,7 @@ def load_engines():
# yuck, should instead check be optimier == "adamw" and backend != "deepspeed" # yuck, should instead check be optimier == "adamw" and backend != "deepspeed"
# and then have ds_cfg pass in the config flag to use pytorch adamw # and then have ds_cfg pass in the config flag to use pytorch adamw
# I genuinely cannot validate if this ever actually gets used in DeepSpeed # I genuinely cannot validate if this ever actually gets used in DeepSpeed
if cfg.hyperparameters.optimizer.lower() == "adamw-torch": if (cfg.trainer.backend == "local" and cfg.hyperparameters.optimizer.lower() == "adamw") or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.optimizer.lower() == "adamw-torch"):
optimizer = ml.AdamW( optimizer = ml.AdamW(
model.parameters(), model.parameters(),
lr=cfg.hyperparameters.learning_rate, lr=cfg.hyperparameters.learning_rate,
@ -187,7 +187,8 @@ def _non_blocking_input():
l[0] = s l[0] = s
broadcast_object_list(l, src=0) if world_size() > 1:
broadcast_object_list(l, src=0)
_command = l[0] _command = l[0]
return _command return _command