@@ -45,7 +45,10 @@ from .base import TrainFeeder
 _logger = logging.getLogger(__name__)
 
 if not distributed_initialized() and cfg.trainer.backend == "local":
-	init_distributed(torch.distributed.init_process_group)
+	def _nop():
+		...
+	fn = _nop if cfg.device == "cpu" else torch.distributed.init_process_group
+	init_distributed(fn)
 
 # A very naive engine implementation using barebones PyTorch
 # to-do: implement lr_scheduling
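
For context on the hunk above: initializing a real process group is unnecessary on a CPU-only host (and impossible with a CUDA-oriented backend such as NCCL), so the change hands `init_distributed` a no-op instead of `torch.distributed.init_process_group`. A minimal, self-contained sketch of the pattern follows; the `init_distributed` / `distributed_initialized` helpers here are stand-ins whose behavior is assumed for illustration, not taken from the project:

import torch

_initialized = False

def distributed_initialized() -> bool:
    return _initialized

def init_distributed(fn, **kwargs):
    # Run whatever initializer we were given, then mark setup as done.
    global _initialized
    fn(**kwargs)
    _initialized = True

def _nop(**kwargs):
    ...  # deliberately does nothing on CPU-only hosts

device = "cpu"  # stands in for cfg.device
fn = _nop if device == "cpu" else torch.distributed.init_process_group
init_distributed(fn)  # safe even when NCCL/CUDA is unavailable

Routing both cases through `init_distributed` keeps a single call site, so downstream code never has to branch on the device again.
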
@@ -276,7 +279,7 @@ class Engines(dict[str, Engine]):
 			stats.update(flatten_dict({ name.split("-")[0]: stat }))
 
 		return stats
 
-	def step(self, batch, feeder: TrainFeeder = default_feeder, device=torch.cuda.current_device()):
+	def step(self, batch, feeder: TrainFeeder = default_feeder, device=cfg.get_device()):
 		total_elapsed_time = 0
 		stats: Any = dict()
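
The second hunk fixes a classic Python pitfall: default argument values are evaluated once, when the `def` statement executes (i.e. at import time), so `device=torch.cuda.current_device()` raises on a machine without CUDA even if `step` is never called. A small illustration, with `pick_device` as a hypothetical stand-in for `cfg.get_device()`:

import torch

def pick_device() -> torch.device:
    # CPU-safe device selection; assumed to mirror what cfg.get_device() does.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The default below is computed exactly once, when this `def` runs.
# With torch.cuda.current_device() in its place, a CPU-only host would
# crash at import; pick_device() degrades gracefully instead.
def step(batch, device=pick_device()):
    return torch.as_tensor(batch, device=device)

print(step([1, 2, 3]))  # works with or without a GPU

Note that the new default is still frozen at import time; it is merely safe to evaluate everywhere, which is all this change needs.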