vall-e/vall_e/engines/deepspeed.py

160 lines
4.5 KiB
Python
Executable File

"""
# https://github.com/enhuiz/pytorch-training-utilities
"""
# to-do: replace this
# to-do: swap out deepspeed
from ..config import cfg
from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
import logging
import time
import torch
import torch.distributed
from torch import Tensor
from torch.distributed import all_reduce
from typing import Any, Protocol
from .base import TrainFeeder
_logger = logging.getLogger(__name__)
from deepspeed import DeepSpeedEngine, DeepSpeedConfig, comm as dist, init_distributed as init_deepspeed_dist
from deepspeed.accelerator import get_accelerator
from ..utils.distributed import init_distributed, distributed_initialized
from ..utils import wrapper as ml
from ..models.lora import freeze_non_lora_weights
# Import-time side effect: bring up the distributed backend exactly once, and
# only when the trainer is configured to run on DeepSpeed.
if not distributed_initialized():
    if cfg.trainer.backend == "deepspeed":
        init_distributed(init_deepspeed_dist)
class Engine(DeepSpeedEngine):
    """``DeepSpeedEngine`` subclass used as the project's training engine.

    On top of the base engine it adds:

    * step / sample / token counters seeded from an optional ``stats`` mapping,
    * parameter freezing and unfreezing (optionally LoRA-aware),
    * a learning-rate override (``set_lr``),
    * checkpoint directories redirected to the LoRA's own folder when a LoRA
      is configured,
    * ``traverse``: one forward / backward / optimizer step with NaN-loss
      skipping.
    """

    def __init__(self, *args, **kwargs):
        """Create the engine.

        Two keyword arguments are consumed here; every other positional and
        keyword argument is forwarded to ``DeepSpeedEngine.__init__``.

        Args:
            hyper_config: Optional per-model configuration object, stored as
                ``self.hyper_config`` (``None`` when omitted).
            stats: Optional mapping providing ``global_step``, ``micro_step``,
                ``global_samples`` and ``tokens_processed``. When omitted or
                ``None`` all four counters start at zero.

        Any caller-supplied ``config`` / ``config_class`` keyword is
        overwritten with values derived from ``cfg.trainer.deepspeed.ds_cfg``.
        """
        # Assigned before super().__init__(), same as the original ordering.
        self.hyper_config = kwargs.pop("hyper_config", None)

        kwargs['config'] = cfg.trainer.deepspeed.ds_cfg
        kwargs['config_class'] = DeepSpeedConfig(kwargs['config'])

        # An explicit ``stats=None`` is treated the same as "not provided".
        stats = kwargs.pop("stats", None)
        if stats is None:
            stats = {
                "global_step": 0,
                "micro_step": 0,
                "global_samples": 0,
                "tokens_processed": 0,
            }

        # NOTE(review): the leading ``None`` presumably fills DeepSpeedEngine's
        # first positional parameter (``args``) - confirm against the installed
        # DeepSpeed version.
        super().__init__(None, *args, **kwargs)

        # Parameters whose gradients were disabled by freeze(); unfreeze()
        # re-enables exactly these.
        self._frozen_params = set()

        # Restore the counters after the base initializer has run, so whatever
        # it set for them is replaced by the supplied stats.
        self.global_steps = stats["global_step"]
        self.micro_steps = stats["micro_step"]
        self.global_samples = stats["global_samples"]
        self.tokens_processed = stats["tokens_processed"]

        # Remaining number of NaN losses traverse() tolerates before raising.
        self.max_nan_losses = 8
        self.current_batch_size = 0
        # NOTE(review): not read anywhere in this class; traverse() always
        # skips the update on a NaN loss.
        self.skip_on_nan = True

    def freeze(self, freeze_all=True):
        """Disable gradients on parameters and remember them for ``unfreeze``.

        Args:
            freeze_all: When true, every parameter that currently requires a
                gradient is frozen. When false, only the parameters named in
                ``self.hyper_config.frozen_params`` are frozen; if that list is
                empty or missing and a LoRA is configured (``cfg.lora``), the
                non-LoRA weights are frozen via ``freeze_non_lora_weights``.

        Raises:
            ValueError: ``freeze_all`` is false, no LoRA fallback applies, and
                there is no ``hyper_config.frozen_params`` to select from.
        """
        # Read the name list defensively: hyper_config may be None or may not
        # define frozen_params at all.
        frozen_names = getattr(self.hyper_config, "frozen_params", None)

        # Freeze non-LoRA params if requested.
        if not freeze_all and not frozen_names and cfg.lora is not None:
            lora_frozen = freeze_non_lora_weights(self.module, embeddings=cfg.lora.embeddings)
            for param in lora_frozen:
                self._frozen_params.add(param)
            return

        # The name list is only needed for a selective freeze.
        if not freeze_all and frozen_names is None:
            raise ValueError("freeze_all=False yet self.hyper_config.frozen_params is None")

        for name, param in self.module.named_parameters():
            should_freeze = param.requires_grad if freeze_all else name in frozen_names
            if should_freeze:
                param.requires_grad_(False)
                self._frozen_params.add(param)

    def unfreeze(self):
        """Re-enable gradients on everything ``freeze`` recorded, then forget it."""
        for param in self._frozen_params:
            param.requires_grad_(True)
        self._frozen_params.clear()

    @property
    def _training(self):
        """``hyper_config.training``; raises ``AttributeError`` if ``hyper_config`` is ``None``."""
        return self.hyper_config.training

    @property
    def global_step(self):
        """Alias for the ``global_steps`` counter."""
        return self.global_steps

    @property
    def micro_step(self):
        """Alias for the ``micro_steps`` counter."""
        return self.micro_steps

    @property
    def batch_size(self):
        """``current_batch_size`` when positive, else ``cfg.hyperparameters.batch_size``."""
        return self.current_batch_size if self.current_batch_size > 0 else cfg.hyperparameters.batch_size

    def gather_attribute(self, *args, **kwargs):
        """Call the module-level ``gather_attribute`` helper on ``self.module``."""
        return gather_attribute(self.module, *args, **kwargs)

    def dispatch_attribute(self, *args, **kwargs):
        """Call the module-level ``dispatch_attribute`` helper on ``self.module``."""
        return dispatch_attribute(self.module, *args, **kwargs)

    def set_lr(self, lr):
        """Best-effort override of the optimizer's learning rate.

        For optimizers exposing ``param_groups``, each group gets ``lr``
        written to its ``"d_coeff"`` entry when it has one, otherwise to
        ``"lr"``. Other optimizers are asked via ``optimizer.set_lr(lr)``.
        Any failure is logged as a warning and not raised.
        """
        try:
            if hasattr(self.optimizer, 'param_groups'):
                for param_group in self.optimizer.param_groups:
                    # NOTE(review): "d_coeff" presumably belongs to optimizers
                    # that adapt their own step size - confirm.
                    param_group["d_coeff" if "d_coeff" in param_group else "lr"] = lr
            else:
                self.optimizer.set_lr(lr)
        except Exception as e:
            # Deliberately non-fatal: a failed LR override must not stop training.
            _logger.warning(str(e))

    # we'll just have to live with the LoRA weights living within our main weights
    # they're easy to extract anyways
    def load_checkpoint(self, load_dir, **kwargs):
        """Load a checkpoint; with a LoRA configured, ``load_dir`` is replaced
        by ``cfg.ckpt_dir / cfg.lora.full_name``."""
        # override to load the lora instead
        if cfg.lora is not None:
            load_dir = cfg.ckpt_dir / cfg.lora.full_name

        return super().load_checkpoint(load_dir, **kwargs)

    def save_checkpoint(self, save_dir, **kwargs):
        """Save a checkpoint; with a LoRA configured, ``save_dir`` is replaced
        by ``cfg.ckpt_dir / cfg.lora.full_name``."""
        # override to save the lora instead
        if cfg.lora is not None:
            save_dir = cfg.ckpt_dir / cfg.lora.full_name

        return super().save_checkpoint(save_dir, **kwargs)

    def traverse(self, *args, **kwargs):
        """Run one forward pass and, unless the loss is NaN, backward + step.

        All arguments are forwarded to ``self.forward``. The module's gathered
        ``"loss"`` values are summed into the training loss.

        Returns:
            dict: each gathered loss converted with ``.item()``, merged with
            the module's gathered ``"scalar"`` values.

        Raises:
            RuntimeError: a NaN loss was seen after the tolerance budget
                (``max_nan_losses``, initially 8 and never replenished here)
                was already used up.
        """
        with ml.autocast():
            self.forward(*args, **kwargs)
            losses = self.gather_attribute("loss")
            loss = torch.stack([*losses.values()]).sum()

        stats = {}
        stats |= {k: v.item() for k, v in losses.items()}
        stats |= self.gather_attribute("scalar")

        if torch.isnan(loss).any():
            self.max_nan_losses = self.max_nan_losses - 1
            if self.max_nan_losses < 0:
                raise RuntimeError("Too many NaN losses detected.")
            # NOTE(review): this early return skips backward()/step() on this
            # rank only; confirm that is safe when other ranks saw no NaN.
            return stats

        self.backward(loss)
        self.step()

        return stats