vall-e/vall_e/engines/deepspeed.py

157 lines
4.4 KiB
Python
Raw Normal View History

2023-08-04 01:26:36 +00:00
"""
# https://github.com/enhuiz/pytorch-training-utilities
"""
# to-do: replace this
# to-do: swap out deepspeed
from ..config import cfg
from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
import logging
import time
import torch
import torch.distributed
from torch import Tensor
from torch.distributed import all_reduce
from typing import Any, Protocol
from .base import TrainFeeder
_logger = logging.getLogger(__name__)
2023-08-05 03:22:15 +00:00
from deepspeed import DeepSpeedEngine, DeepSpeedConfig, comm as dist, init_distributed as init_deepspeed_dist
2023-08-04 01:26:36 +00:00
from deepspeed.accelerator import get_accelerator
2023-08-05 03:22:15 +00:00
from ..utils.distributed import init_distributed, distributed_initialized
from ..utils import wrapper as ml
2023-08-05 03:22:15 +00:00
from ..models.lora import freeze_non_lora_weights
2023-08-05 03:22:15 +00:00
# Import-time side effect: bring up the distributed backend through DeepSpeed's
# initializer, but only if nothing has initialized it yet and the trainer is
# actually configured to use the DeepSpeed backend.
if not distributed_initialized():
    if cfg.trainer.backend == "deepspeed":
        init_distributed(init_deepspeed_dist)
2023-08-04 01:26:36 +00:00
class Engine(DeepSpeedEngine):
    """Training engine built on top of ``DeepSpeedEngine``.

    On top of the DeepSpeed base class this adds:

    * step / sample / token counters that can be restored from a ``stats``
      mapping passed to the constructor,
    * parameter freezing and unfreezing (all parameters, a named subset from
      ``hyper_config.frozen_params``, or everything except LoRA weights),
    * checkpoint load/save redirection to the LoRA directory when a LoRA is
      configured,
    * ``traverse``: one forward + backward + optimizer step.
    """

    def __init__(self, *args, **kwargs):
        """Build the engine.

        Keyword arguments consumed here (not forwarded to DeepSpeed):

        * ``hyper_config``: per-model hyperparameter object, stored as
          ``self.hyper_config`` (``None`` when omitted).
        * ``stats``: mapping with any of ``global_step``, ``micro_step``,
          ``global_samples``, ``tokens_processed``; missing counters (or
          ``stats=None``) default to ``0``.

        ``config`` and ``config_class`` are always overwritten from
        ``cfg.trainer.deepspeed.ds_cfg``. Everything else is forwarded to
        ``DeepSpeedEngine.__init__`` with ``None`` as its first positional
        argument.
        """
        # Set before super().__init__(); equivalent to "None unless provided".
        self.hyper_config = kwargs.pop("hyper_config", None)

        kwargs['config'] = cfg.trainer.deepspeed.ds_cfg
        kwargs['config_class'] = DeepSpeedConfig(kwargs['config'])

        stats = {
            "global_step": 0,
            "micro_step": 0,
            "global_samples": 0,
            "tokens_processed": 0,
        }
        # An explicit stats=None is treated like an omitted argument. Restored
        # values are merged over the zero defaults so a mapping that lacks one
        # of the counters does not raise KeyError below.
        restored_stats = kwargs.pop('stats', None)
        if restored_stats is not None:
            stats.update(restored_stats)

        super().__init__(None, *args, **kwargs)
        self._frozen_params = set()

        # Assigned after the base constructor so the restored values are the
        # ones in effect once construction finishes.
        self.global_steps = stats["global_step"]
        self.micro_steps = stats["micro_step"]
        self.global_samples = stats["global_samples"]
        self.tokens_processed = stats["tokens_processed"]

        # Remaining budget of NaN losses tolerated by traverse().
        self.max_nan_losses = 8
        self.current_batch_size = 0

    def freeze(self, freeze_all=True):
        """Disable gradients on parameters and remember them for ``unfreeze``.

        * ``freeze_all=True``: freeze every parameter that currently requires
          grad.
        * ``freeze_all=False`` with a non-empty ``hyper_config.frozen_params``:
          freeze the parameters whose names are listed there.
        * ``freeze_all=False``, no names listed, and ``cfg.lora`` configured:
          freeze the non-LoRA weights via ``freeze_non_lora_weights``.

        Raises:
            Exception: ``freeze_all=False`` but no ``frozen_params`` is
                available (and the LoRA path does not apply).
        """
        # Read defensively: hyper_config may be None or lack the attribute.
        # Dereferencing it directly would raise AttributeError before the
        # explicit check below could report the problem.
        named_frozen = getattr(self.hyper_config, "frozen_params", None)

        # freeze non-LoRA params if requested
        if not named_frozen and not freeze_all and cfg.lora is not None:
            frozen_params = freeze_non_lora_weights(self.module, embeddings=cfg.lora.embeddings)
            for param in frozen_params:
                self._frozen_params.add(param)
            return

        # The name list is only needed when freezing a subset.
        if not freeze_all and named_frozen is None:
            raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")

        for name, param in self.module.named_parameters():
            if freeze_all:
                should_freeze = param.requires_grad
            else:
                should_freeze = name in named_frozen

            if should_freeze:
                param.requires_grad_(False)
                self._frozen_params.add(param)

    def unfreeze(self):
        """Re-enable gradients on every parameter frozen by ``freeze``."""
        for param in self._frozen_params:
            param.requires_grad_(True)
        self._frozen_params.clear()

    @property
    def _training(self):
        """Whether this engine is meant to be trained, per ``hyper_config.training``."""
        return self.hyper_config.training

    @property
    def global_step(self):
        """Alias for ``self.global_steps``."""
        return self.global_steps

    @property
    def micro_step(self):
        """Alias for ``self.micro_steps``."""
        return self.micro_steps

    @property
    def batch_size(self):
        """``current_batch_size`` when positive, else ``cfg.hyperparameters.batch_size``."""
        return self.current_batch_size if self.current_batch_size > 0 else cfg.hyperparameters.batch_size

    def gather_attribute(self, *args, **kwargs):
        """Call the ``gather_attribute`` utility on the wrapped module."""
        return gather_attribute(self.module, *args, **kwargs)

    def dispatch_attribute(self, *args, **kwargs):
        """Call the ``dispatch_attribute`` utility on the wrapped module."""
        return dispatch_attribute(self.module, *args, **kwargs)

    def set_lr(self, lr):
        """Best-effort learning-rate override.

        For optimizers exposing ``param_groups``, each group gets ``lr``
        written to its ``"d_coeff"`` key when that key exists, otherwise to
        ``"lr"``. Other optimizers get ``optimizer.set_lr(lr)``. Any failure
        is logged as a warning instead of being raised.
        """
        try:
            if hasattr(self.optimizer, 'param_groups'):
                for param_group in self.optimizer.param_groups:
                    # NOTE(review): "d_coeff" presumably belongs to
                    # D-Adaptation/Prodigy-style optimizers - confirm.
                    param_group["d_coeff" if "d_coeff" in param_group else "lr"] = lr
            else:
                self.optimizer.set_lr(lr)
        except Exception as e:
            _logger.warning(str(e))

    # we'll just have to live with the LoRA weights living within our main weights
    # they're easy to extract anyways
    def load_checkpoint(self, load_dir, **kwargs):
        """Load a checkpoint; with a LoRA configured, load from its directory instead."""
        # override to load the lora instead
        if cfg.lora is not None:
            load_dir = cfg.ckpt_dir / cfg.lora.full_name

        return super().load_checkpoint(load_dir, **kwargs)

    def save_checkpoint(self, save_dir, **kwargs):
        """Save a checkpoint; with a LoRA configured, save to its directory instead."""
        # override to save the lora instead
        if cfg.lora is not None:
            save_dir = cfg.ckpt_dir / cfg.lora.full_name

        return super().save_checkpoint(save_dir, **kwargs)

    def traverse(self, *args, **kwargs):
        """Run one training iteration: forward, loss reduction, backward, step.

        Arguments are forwarded to ``self.forward``. The per-name losses
        gathered from the module (attribute ``"loss"``) are summed into the
        scalar used for ``backward``.

        Returns:
            dict: each gathered loss as a Python number, merged with whatever
            the module exposes under the ``"scalar"`` attribute.

        Raises:
            RuntimeError: once more than the allowed number of NaN losses
                (``max_nan_losses``, initially 8) has been seen.
        """
        with ml.autocast():
            self.forward(*args, **kwargs)
            losses = self.gather_attribute("loss")
            loss = torch.stack([*losses.values()]).sum()

        if torch.isnan(loss).any():
            # A NaN loss only consumes budget; the step below still runs with
            # it until the budget is exhausted.
            self.max_nan_losses = self.max_nan_losses - 1
            if self.max_nan_losses < 0:
                raise RuntimeError("Too many NaN losses detected.")

        stats = {}
        stats |= {k: v.item() for k, v in losses.items()}
        stats |= self.gather_attribute("scalar")

        self.backward(loss)
        self.step()

        return stats