diff --git a/vall_e/data.py b/vall_e/data.py index 629858d..6b85872 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -501,7 +501,7 @@ def _create_dataloader(dataset, training): drop_last=training, num_workers=cfg.dataset.workers, collate_fn=collate_fn, - persistent_workers=True, + persistent_workers=cfg.dataset.workers > 1, pin_memory=False, # True, worker_init_fn=_seed_worker, sampler=sampler, diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py index 41b2bc6..139628b 100755 --- a/vall_e/engines/base.py +++ b/vall_e/engines/base.py @@ -28,7 +28,7 @@ def default_feeder(engine, batch): from ..config import cfg from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device -from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader +from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size import logging import time @@ -45,7 +45,7 @@ from .base import TrainFeeder _logger = logging.getLogger(__name__) -if not distributed_initialized() and cfg.trainer.backend == "local": +if not distributed_initialized() and cfg.trainer.backend == "local" and world_size() > 1: init_distributed(torch.distributed.init_process_group) # A very naive engine implementation using barebones PyTorch @@ -104,7 +104,7 @@ class Engine(): open(save_dir / "latest", 'w').write( tag ) - def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True): + def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True, load_module_only=False): if tag is None: tag_path = load_dir / "latest" if not tag_path.exists(): @@ -365,7 +365,8 @@ class Engines(dict[str, Engine]): do_gc() continue - all_reduce(n_ooms) + if world_size() > 1: + all_reduce(n_ooms) if n_ooms.item() > 0: self.save_checkpoint() raise RuntimeError("Out of memory during forward pass!") @@ -395,7 +396,8 @@ class Engines(dict[str, Engine]): n_ooms += 1 - all_reduce(n_ooms) + if world_size() > 1: + all_reduce(n_ooms) if n_ooms.item() > 0: self.save_checkpoint() raise RuntimeError("Out of memory during backwards pass!") diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py index aa9d590..fef6546 100755 --- a/vall_e/utils/trainer.py +++ b/vall_e/utils/trainer.py @@ -19,7 +19,7 @@ from tqdm import tqdm from typing import Protocol from ..config import cfg -from .distributed import init_distributed, distributed_initialized +from .distributed import init_distributed, distributed_initialized, world_size from .distributed import ( global_leader_only, global_rank, @@ -73,7 +73,7 @@ def load_engines(): # yuck, should instead check be optimier == "adamw" and backend != "deepspeed" # and then have ds_cfg pass in the config flag to use pytorch adamw # I genuinely cannot validate if this ever actually gets used in DeepSpeed - if cfg.hyperparameters.optimizer.lower() == "adamw-torch": + if (cfg.trainer.backend == "local" and cfg.hyperparameters.optimizer.lower() == "adamw") or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.optimizer.lower() == "adamw-torch"): optimizer = ml.AdamW( model.parameters(), lr=cfg.hyperparameters.learning_rate, @@ -187,7 +187,8 @@ def _non_blocking_input(): l[0] = s - broadcast_object_list(l, src=0) + if world_size() > 1: + broadcast_object_list(l, src=0) _command = l[0] return _command