2023-08-04 01:26:36 +00:00
|
|
|
"""
|
|
|
|
Reference: https://github.com/enhuiz/pytorch-training-utilities
|
|
|
|
"""
|
|
|
|
|
|
|
|
# to-do: replace this
|
|
|
|
# to-do: swap out deepspeed
|
|
|
|
|
|
|
|
from ..config import cfg
|
|
|
|
from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
|
|
|
|
|
|
|
|
import logging
|
|
|
|
import time
|
|
|
|
import torch
|
|
|
|
import torch.distributed
|
|
|
|
|
|
|
|
from torch import Tensor
|
|
|
|
from torch.distributed import all_reduce
|
|
|
|
from typing import Any, Protocol
|
|
|
|
|
|
|
|
from .base import TrainFeeder
|
|
|
|
|
|
|
|
# Logger scoped to this module's import path (used for engine diagnostics).
_logger = logging.getLogger(name=__name__)
|
|
|
|
|
2023-08-05 03:22:15 +00:00
|
|
|
from deepspeed import DeepSpeedEngine, DeepSpeedConfig, comm as dist, init_distributed as init_deepspeed_dist
|
2023-08-04 01:26:36 +00:00
|
|
|
from deepspeed.accelerator import get_accelerator
|
|
|
|
|
2023-08-05 03:22:15 +00:00
|
|
|
from ..utils.distributed import init_distributed, distributed_initialized
|
|
|
|
|
|
|
|
# Import-time setup: when `distributed_initialized()` reports False and the
# trainer backend is configured as "deepspeed", initialise the distributed
# environment through DeepSpeed's `init_distributed`.
if not distributed_initialized():
    if cfg.trainer.backend == "deepspeed":
        init_distributed(init_deepspeed_dist)
|
2023-08-04 01:26:36 +00:00
|
|
|
|
|
|
|
class Engine(DeepSpeedEngine):
    """DeepSpeedEngine subclass exposing the trainer's engine helpers.

    Adds parameter freezing/unfreezing, learning-rate overrides, attribute
    gathering/dispatching on the wrapped module, and a single-call
    forward/backward/step cycle (``traverse``).
    """

    def __init__(self, *args, **kwargs):
        """Construct the engine from the global DeepSpeed config.

        Args:
            *args, **kwargs: forwarded to ``DeepSpeedEngine.__init__``. An
                optional ``_cfg`` keyword is consumed here and stored on
                ``self._cfg``. Any ``config`` / ``config_class`` keywords are
                overwritten with values built from
                ``cfg.trainer.deepspeed.ds_cfg``.
        """
        # Remove the per-model config so it is not forwarded to DeepSpeed.
        self._cfg = kwargs.pop("_cfg", None)

        kwargs["config"] = cfg.trainer.deepspeed.ds_cfg
        kwargs["config_class"] = DeepSpeedConfig(kwargs["config"])

        # ``None`` is passed as DeepSpeedEngine's first positional argument
        # (its ``args`` namespace); configuration is supplied via ``config``.
        super().__init__(None, *args, **kwargs)

        # Parameters whose grads were disabled by ``freeze``; restored by ``unfreeze``.
        self._frozen_params = set()
        # Counter initialised here; not updated within this class
        # (presumably maintained by the training loop — confirm in caller).
        self.tokens_processed = 0

    def freeze(self, freeze_all=True):
        """Disable gradients on parameters and remember them for ``unfreeze``.

        Args:
            freeze_all: when True, freeze every parameter that currently
                requires grad. When False, freeze only parameters whose name
                appears in ``self._cfg.frozen_params``.

        Raises:
            Exception: if ``freeze_all`` is False and ``self._cfg`` is missing
                or has no (or a ``None``) ``frozen_params``.
        """
        frozen_names = None if self._cfg is None else getattr(self._cfg, "frozen_params", None)

        # The name list is only consulted when freeze_all is False, so it is only
        # required in that case (the error message already states this intent).
        if not freeze_all and frozen_names is None:
            raise Exception("freeze_all=False yet self._cfg.frozen_params is None")

        for name, param in self.module.named_parameters():
            # Short-circuiting keeps ``frozen_names`` untouched when freeze_all is True.
            if (freeze_all and param.requires_grad) or (not freeze_all and name in frozen_names):
                param.requires_grad_(False)
                self._frozen_params.add(param)

    def unfreeze(self):
        """Re-enable gradients on every parameter frozen by ``freeze``."""
        for param in self._frozen_params:
            param.requires_grad_(True)
        self._frozen_params.clear()

    @property
    def _training(self):
        """Return ``self._cfg.training`` (AttributeError if no ``_cfg`` was given)."""
        return self._cfg.training

    @property
    def global_step(self):
        """Alias for DeepSpeed's ``global_steps`` counter."""
        return self.global_steps

    @property
    def micro_step(self):
        """Alias for DeepSpeed's ``micro_steps`` counter."""
        return self.micro_steps

    @property
    def batch_size(self):
        """Batch size taken from the global ``cfg.hyperparameters``."""
        return cfg.hyperparameters.batch_size

    def gather_attribute(self, *args, **kwargs):
        """Delegate to the module-level ``gather_attribute`` on ``self.module``."""
        return gather_attribute(self.module, *args, **kwargs)

    def dispatch_attribute(self, *args, **kwargs):
        """Delegate to the module-level ``dispatch_attribute`` on ``self.module``."""
        return dispatch_attribute(self.module, *args, **kwargs)

    def set_lr(self, lr):
        """Set the learning rate on the wrapped optimizer (best effort).

        Writes ``lr`` into every ``param_groups`` entry when the optimizer
        exposes them, otherwise calls ``optimizer.set_lr(lr)``. Failures are
        logged and not raised.
        """
        try:
            if hasattr(self.optimizer, "param_groups"):
                for param_group in self.optimizer.param_groups:
                    param_group["lr"] = lr
            else:
                self.optimizer.set_lr(lr)
        except Exception as e:
            # Deliberately best-effort, but report through the module logger
            # rather than printing to stdout.
            _logger.warning("Failed to set learning rate to %s: %s", lr, e)

    def traverse(self, *args, **kwargs):
        """Run one forward pass, backward pass and optimizer step.

        Args:
            *args, **kwargs: forwarded to ``self.forward``.

        Returns:
            dict of per-loss scalar values (``.item()`` of each gathered
            ``loss``) merged with the gathered ``scalar`` attributes.
        """
        with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
            self.forward(*args, **kwargs)
            losses = self.gather_attribute("loss")
            # Total objective is the sum of every gathered loss term.
            loss = torch.stack([*losses.values()]).sum()

        stats = {}
        stats |= {k: v.item() for k, v in losses.items()}
        stats |= self.gather_attribute("scalar")

        self.backward(loss)
        self.step()

        return stats
|