2023-08-05 03:40:14 +00:00
|
|
|
"""
|
|
|
|
# https://github.com/enhuiz/pytorch-training-utilities
|
|
|
|
"""
|
|
|
|
|
|
|
|
# to-do: replace this
|
|
|
|
# to-do: swap out deepspeed
|
|
|
|
|
|
|
|
from ..config import cfg
|
|
|
|
from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
|
|
|
|
|
|
|
|
import logging
|
|
|
|
import time
|
|
|
|
import torch
|
|
|
|
import torch.distributed
|
|
|
|
|
|
|
|
from torch import Tensor
|
|
|
|
from torch.distributed import all_reduce
|
|
|
|
from typing import Any, Protocol
|
|
|
|
|
|
|
|
from .base import TrainFeeder
|
|
|
|
|
|
|
|
_logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
from deepspeed import DeepSpeedEngine, DeepSpeedConfig, comm as dist, init_distributed as init_deepspeed_dist
|
|
|
|
from deepspeed.accelerator import get_accelerator
|
|
|
|
|
|
|
|
from ..utils.distributed import init_distributed, distributed_initialized
|
2024-09-04 20:48:29 +00:00
|
|
|
from ..utils import wrapper as ml
|
|
|
|
|
|
|
|
from ..models.lora import freeze_non_lora_weights
|
2023-08-05 03:40:14 +00:00
|
|
|
|
|
|
|
# One-time, import-time setup: when no process group exists yet and the
# configured trainer backend is DeepSpeed, initialize distributed state through
# the project's init_distributed helper, using DeepSpeed's initializer.
# The nesting keeps the original evaluation order: distributed_initialized() is
# checked first, and the backend is only inspected when it returns False.
if not distributed_initialized():
    if cfg.trainer.backend == "deepspeed":
        init_distributed(init_deepspeed_dist)
|
|
|
|
|
|
|
|
class Engine(DeepSpeedEngine):
    """DeepSpeedEngine subclass used by the trainer.

    On top of DeepSpeed it adds:
    - parameter freezing/unfreezing, including LoRA-only freezing;
    - learning-rate overrides;
    - LoRA-aware checkpoint directories;
    - a single ``traverse`` call that runs forward, loss gathering, backward
      and optimizer step.
    """

    def __init__(self, *args, **kwargs):
        """Build the engine from the global DeepSpeed config.

        Keyword Args:
            hyper_config: optional per-model hyperparameter object. It is
                read by ``freeze`` (``frozen_params``) and ``_training``
                (``training``). It is removed from kwargs before they are
                forwarded.
            stats: optional dict of resumed counters (``global_step``,
                ``micro_step``, ``global_samples``, ``tokens_processed``).
                Omitting it, or passing ``None``, starts every counter at 0.
                Individual missing keys also default to 0.

        Any caller-supplied ``config`` kwarg is replaced by
        ``cfg.trainer.deepspeed.ds_cfg``. A matching ``config_class`` is built
        from it. The remaining args/kwargs are forwarded to
        ``DeepSpeedEngine.__init__``, preceded by a ``None`` first positional
        argument.
        """
        # Popped so it is not forwarded to DeepSpeedEngine.__init__.
        self.hyper_config = kwargs.pop("hyper_config", None)

        kwargs["config"] = cfg.trainer.deepspeed.ds_cfg
        kwargs["config_class"] = DeepSpeedConfig(kwargs["config"])

        # Start from zeroed counters, then overlay whatever was resumed. This
        # way a stats dict from an older run that lacks a newer key (e.g.
        # tokens_processed) falls back to 0 instead of raising KeyError.
        stats = {
            "global_step": 0,
            "micro_step": 0,
            "global_samples": 0,
            "tokens_processed": 0,
        }
        # An explicit stats=None is treated the same as omitting it.
        resumed_stats = kwargs.pop("stats", None)
        if resumed_stats is not None:
            stats.update(resumed_stats)

        super().__init__(None, *args, **kwargs)

        # Parameters frozen by freeze(); unfreeze() re-enables exactly these.
        self._frozen_params = set()

        self.global_steps = stats["global_step"]
        self.micro_steps = stats["micro_step"]
        self.global_samples = stats["global_samples"]
        self.tokens_processed = stats["tokens_processed"]

        # NaN-loss budget: traverse() decrements it on every NaN loss and
        # raises once it drops below zero.
        self.max_nan_losses = 8
        # 0 means "not set"; the batch_size property then falls back to
        # cfg.hyperparameters.batch_size.
        self.current_batch_size = 0

    def freeze(self, freeze_all=True):
        """Disable gradients on some of the wrapped module's parameters.

        Args:
            freeze_all: if True, freeze every parameter that currently
                requires grad. ``hyper_config`` is not consulted.
                If False, freeze according to configuration:
                - When ``hyper_config.frozen_params`` is empty or unset and
                  ``cfg.lora`` is configured, freeze all non-LoRA weights via
                  ``freeze_non_lora_weights``.
                - Otherwise, freeze the parameters whose names appear in
                  ``hyper_config.frozen_params``.

        Every parameter frozen here is recorded, so ``unfreeze`` can restore it.

        Raises:
            Exception: if ``freeze_all`` is False, the LoRA path does not
                apply, and there is no ``hyper_config.frozen_params`` list.
        """
        # getattr tolerates both hyper_config being None and the attribute
        # being absent, so the explicit error below is actually reachable.
        frozen_names = getattr(self.hyper_config, "frozen_params", None)

        if freeze_all:
            for param in self.module.parameters():
                if param.requires_grad:
                    param.requires_grad_(False)
                    self._frozen_params.add(param)
            return

        # LoRA training with no explicit freeze list: freeze everything but
        # the LoRA weights.
        if not frozen_names and cfg.lora is not None:
            frozen = freeze_non_lora_weights(self.module, embeddings=cfg.lora.embeddings)
            for param in frozen:
                self._frozen_params.add(param)
            return

        if frozen_names is None:
            raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")

        for name, param in self.module.named_parameters():
            if name in frozen_names:
                param.requires_grad_(False)
                self._frozen_params.add(param)

    def unfreeze(self):
        """Re-enable gradients on every parameter frozen by ``freeze`` and forget them."""
        for param in self._frozen_params:
            param.requires_grad_(True)
        self._frozen_params.clear()

    @property
    def _training(self):
        """The ``training`` flag of ``hyper_config``.

        This raises AttributeError if no hyper_config was supplied.
        """
        return self.hyper_config.training

    @property
    def global_step(self):
        """Current global step counter (``self.global_steps``)."""
        return self.global_steps

    @property
    def micro_step(self):
        """Current micro-step counter (``self.micro_steps``)."""
        return self.micro_steps

    @property
    def batch_size(self):
        """Batch size currently in use.

        Returns ``current_batch_size`` when it has been set (> 0). Otherwise
        returns the configured ``cfg.hyperparameters.batch_size``.
        """
        return self.current_batch_size if self.current_batch_size > 0 else cfg.hyperparameters.batch_size

    def gather_attribute(self, *args, **kwargs):
        """Apply the module-level ``gather_attribute`` helper to the wrapped module."""
        return gather_attribute(self.module, *args, **kwargs)

    def dispatch_attribute(self, *args, **kwargs):
        """Apply the module-level ``dispatch_attribute`` helper to the wrapped module."""
        return dispatch_attribute(self.module, *args, **kwargs)

    def set_lr(self, lr):
        """Override the optimizer's learning rate. This is best-effort.

        If the optimizer exposes ``param_groups``, each group gets ``lr``
        written to its ``d_coeff`` entry when that key exists, otherwise to
        ``lr``. Presumably ``d_coeff`` belongs to an adaptive-LR optimizer;
        TODO confirm.

        If there are no ``param_groups``, the optimizer's own ``set_lr`` is
        called instead. Any exception is logged as a warning rather than
        raised.
        """
        try:
            if hasattr(self.optimizer, 'param_groups'):
                for param_group in self.optimizer.param_groups:
                    param_group["d_coeff" if "d_coeff" in param_group else "lr"] = lr
            else:
                self.optimizer.set_lr(lr)
        except Exception as e:
            _logger.warning(str(e))

    # LoRA weights are kept inside the main weights; they are easy to
    # extract afterwards.
    def load_checkpoint(self, load_dir, **kwargs):
        """Load a checkpoint via DeepSpeed.

        When ``cfg.lora`` is configured, the given ``load_dir`` is ignored.
        The checkpoint is loaded from ``cfg.ckpt_dir / cfg.lora.full_name``
        instead.
        """
        if cfg.lora is not None:
            load_dir = cfg.ckpt_dir / cfg.lora.full_name

        return super().load_checkpoint(load_dir, **kwargs)

    def save_checkpoint(self, save_dir, **kwargs):
        """Save a checkpoint via DeepSpeed.

        When ``cfg.lora`` is configured, the given ``save_dir`` is ignored.
        The checkpoint is saved to ``cfg.ckpt_dir / cfg.lora.full_name``
        instead.
        """
        if cfg.lora is not None:
            save_dir = cfg.ckpt_dir / cfg.lora.full_name

        return super().save_checkpoint(save_dir, **kwargs)

    def traverse(self, *args, **kwargs):
        """Run one training micro-step and return its stats.

        The step proceeds as follows:
        1. Forward under ``ml.autocast()``.
        2. Gather every submodule's ``loss`` attribute and sum them.
        3. Count NaN losses against ``max_nan_losses``.
        4. Run backward and step.

        Returns:
            dict: each gathered loss as a Python number, merged with the
            gathered ``scalar`` attributes.

        Raises:
            RuntimeError: once more NaN losses have been seen than
                ``max_nan_losses`` allowed.
        """
        with ml.autocast():
            self.forward(*args, **kwargs)

        losses = self.gather_attribute("loss")
        loss = torch.stack([*losses.values()]).sum()

        # NOTE(review): a NaN loss is only counted here; backward()/step()
        # still run with it. Skipping them on one rank alone could desync
        # collective ops across ranks. This presumably relies on DeepSpeed's
        # overflow handling — confirm for bf16/fp32 runs.
        if torch.isnan(loss).any():
            self.max_nan_losses = self.max_nan_losses - 1
            if self.max_nan_losses < 0:
                raise RuntimeError("Too many NaN losses detected.")

        stats = {}
        stats |= {k: v.item() for k, v in losses.items()}
        stats |= self.gather_attribute("scalar")

        self.backward(loss)
        self.step()

        return stats
|