fixes with the local backend
This commit is contained in:
parent 00ad4af651
commit 0517d620b8
```diff
@@ -501,7 +501,7 @@ def _create_dataloader(dataset, training):
 		drop_last=training,
 		num_workers=cfg.dataset.workers,
 		collate_fn=collate_fn,
-		persistent_workers=True,
+		persistent_workers=cfg.dataset.workers > 1,
 		pin_memory=False, # True,
 		worker_init_fn=_seed_worker,
 		sampler=sampler,
```
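Worth pausing on this first hunk: PyTorch's `DataLoader` refuses `persistent_workers=True` when `num_workers` is 0, so the flag has to follow the configured worker count rather than being hardcoded. A minimal standalone sketch of the constraint, with a toy dataset standing in for the project's own:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.zeros(8, 2))
workers = 0  # stands in for cfg.dataset.workers

# persistent_workers=True with num_workers=0 raises
# "ValueError: persistent_workers option needs num_workers > 0",
# so derive the flag from the worker count instead of hardcoding True.
loader = DataLoader(
    dataset,
    num_workers=workers,
    persistent_workers=workers > 1,
)
```

The commit gates on `> 1` rather than the `> 0` the DataLoader strictly requires; that is still valid, and keeping a single worker process alive between epochs buys little anyway.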
```diff
@@ -28,7 +28,7 @@ def default_feeder(engine, batch):
 
 from ..config import cfg
 from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
-from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader
+from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size
 
 import logging
 import time
```
```diff
@@ -45,7 +45,7 @@ from .base import TrainFeeder
 
 _logger = logging.getLogger(__name__)
 
-if not distributed_initialized() and cfg.trainer.backend == "local":
+if not distributed_initialized() and cfg.trainer.backend == "local" and world_size() > 1:
 	init_distributed(torch.distributed.init_process_group)
 
 # A very naive engine implementation using barebones PyTorch
```
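This guard keeps single-process local runs from touching `torch.distributed` at import time. A sketch of the intent, assuming `world_size()` follows the usual torchrun convention of reading `WORLD_SIZE` from the environment; the repo's `utils.distributed` helper may be implemented differently:

```python
import os
import torch
import torch.distributed

def world_size() -> int:
    # Assumption: mirrors the torchrun convention. The project's own
    # utils.distributed.world_size may differ.
    return int(os.environ.get("WORLD_SIZE", "1"))

# Only create a process group when a launcher actually spawned peers;
# init_process_group would otherwise fail (no MASTER_ADDR/MASTER_PORT)
# or block waiting for ranks that will never arrive.
if not torch.distributed.is_initialized() and world_size() > 1:
    torch.distributed.init_process_group(
        backend="nccl" if torch.cuda.is_available() else "gloo"
    )
```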
```diff
@@ -104,7 +104,7 @@ class Engine():
 
 		open(save_dir / "latest", 'w').write( tag )
 
-	def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True):
+	def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True, load_module_only=False):
 		if tag is None:
 			tag_path = load_dir / "latest"
 			if not tag_path.exists():
```
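The added `load_module_only=False` parameter matches the keyword DeepSpeed's own `load_checkpoint` accepts, so call sites can pass it regardless of backend. A hypothetical sketch of how the local engine might honor it; `self.module`, `self.optimizer`, and the checkpoint keys here are assumptions, not the project's actual layout:

```python
def _restore_state(self, state: dict, load_optimizer_states: bool = True,
                   load_module_only: bool = False):
    # Always restore the model weights...
    self.module.load_state_dict(state["module"])
    if load_module_only:
        # ...but skip optimizer/scheduler state entirely, e.g. when
        # warm-starting a fine-tune with a fresh optimizer.
        return
    if load_optimizer_states:
        self.optimizer.load_state_dict(state["optimizer"])
```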
```diff
@@ -365,7 +365,8 @@ class Engines(dict[str, Engine]):
 				do_gc()
 				continue
 
-		all_reduce(n_ooms)
+		if world_size() > 1:
+			all_reduce(n_ooms)
 		if n_ooms.item() > 0:
 			self.save_checkpoint()
 			raise RuntimeError("Out of memory during forward pass!")
```
```diff
@@ -395,7 +396,8 @@ class Engines(dict[str, Engine]):
 
 					n_ooms += 1
 
-		all_reduce(n_ooms)
+		if world_size() > 1:
+			all_reduce(n_ooms)
 		if n_ooms.item() > 0:
 			self.save_checkpoint()
 			raise RuntimeError("Out of memory during backwards pass!")
```
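These two hunks (forward-pass and backward-pass OOM accounting) apply the same fix: `all_reduce` assumes an initialized default process group and raises a `RuntimeError` ("Default process group has not been initialized") in a plain single-process run, so the reduction is skipped when there are no peers. The pattern, as a standalone sketch:

```python
import torch
import torch.distributed as dist

def global_oom_count(local_ooms: int) -> int:
    n_ooms = torch.tensor([local_ooms])
    # Without an initialized process group, all_reduce raises a
    # RuntimeError, so only reduce when actually running distributed.
    if dist.is_initialized() and dist.get_world_size() > 1:
        dist.all_reduce(n_ooms)
    return int(n_ooms.item())
```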
```diff
@@ -19,7 +19,7 @@ from tqdm import tqdm
 from typing import Protocol
 
 from ..config import cfg
-from .distributed import init_distributed, distributed_initialized
+from .distributed import init_distributed, distributed_initialized, world_size
 from .distributed import (
 	global_leader_only,
 	global_rank,
```
```diff
@@ -73,7 +73,7 @@ def load_engines():
 		# yuck, should instead check be optimier == "adamw" and backend != "deepspeed"
 		# and then have ds_cfg pass in the config flag to use pytorch adamw
 		# I genuinely cannot validate if this ever actually gets used in DeepSpeed
-		if cfg.hyperparameters.optimizer.lower() == "adamw-torch":
+		if (cfg.trainer.backend == "local" and cfg.hyperparameters.optimizer.lower() == "adamw") or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.optimizer.lower() == "adamw-torch"):
 			optimizer = ml.AdamW(
 				model.parameters(),
 				lr=cfg.hyperparameters.learning_rate,
```
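The rewritten condition lets a plain `"adamw"` select `torch.optim.AdamW` on the local backend (where nothing else could satisfy it), while DeepSpeed still requires the explicit `"adamw-torch"` spelling to opt out of its own fused optimizer. The predicate, pulled out as a sketch for readability:

```python
def use_torch_adamw(backend: str, optimizer: str) -> bool:
    # Local backend: "adamw" can only mean the PyTorch implementation.
    # DeepSpeed backend: plain "adamw" keeps DeepSpeed's optimizer;
    # "adamw-torch" explicitly requests torch.optim.AdamW instead.
    optimizer = optimizer.lower()
    return (backend == "local" and optimizer == "adamw") or (
        backend == "deepspeed" and optimizer == "adamw-torch"
    )
```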
```diff
@@ -187,7 +187,8 @@ def _non_blocking_input():
 
 		l[0] = s
 
-	broadcast_object_list(l, src=0)
+	if world_size() > 1:
+		broadcast_object_list(l, src=0)
 	_command = l[0]
 	return _command
 
```
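Same guard, last occurrence: `broadcast_object_list` needs a process group just like `all_reduce`, so the interactive command read on rank 0 is only broadcast when peers exist. A sketch of the surrounding flow:

```python
import torch.distributed as dist

def share_command(command):
    # Rank 0 reads the user's input; other ranks receive it by broadcast.
    # Single-process runs skip the collective entirely.
    box = [command]
    if dist.is_initialized() and dist.get_world_size() > 1:
        dist.broadcast_object_list(box, src=0)
    return box[0]
```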