import torch

from torch import Tensor
from typing import Any, Protocol

Stats = dict[str, float]

class TrainFeeder(Protocol):
    def __call__(
        self, *, engine: "Engine", batch: Any
    ) -> None | tuple[Tensor, Stats]:
        ...

def default_feeder(engine, batch):
    if isinstance(batch, list):
        engine(*batch)
    elif isinstance(batch, dict):
        engine(**batch)
    else:
        engine(batch)

    losses = engine.gather_attribute("loss")
    loss = torch.stack([*losses.values()]).sum()

    stats = {}
    stats |= {k: v.item() for k, v in losses.items()}

    return loss, stats
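# A feeder decides how a batch is fed to an engine and what (loss, stats) it returns.
# A minimal custom-feeder sketch (hypothetical; the "text"/"audio" field names are
# illustrative, not part of this codebase):
#
#   def my_feeder(*, engine, batch):
#       engine(text=batch["text"], audio=batch["audio"])
#       losses = engine.gather_attribute("loss")
#       loss = torch.stack([*losses.values()]).sum()
#       return loss, {k: v.item() for k, v in losses.items()}
#
# Returning None instead of a (loss, stats) tuple makes Engines.step() skip that engine.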
from ..config import cfg
from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size

import logging
import time
import torch
import torch.distributed
import os

from torch import Tensor
from torch.distributed import all_reduce
from typing import Any, Protocol
from functools import cached_property

from .base import TrainFeeder
from ..utils import wrapper as ml

_logger = logging.getLogger(__name__)

if not distributed_initialized() and cfg.trainer.backend == "local": # and world_size() > 1:
    init_distributed(torch.distributed.init_process_group)
# A very naive engine implementation using barebones PyTorch
class Engine():
    def __init__(self, *args, **kwargs):
        self.hyper_config = kwargs.pop("hyper_config", None)

        self.module = kwargs['model'].to(cfg.device).to(torch.float32 if cfg.trainer.amp else cfg.trainer.dtype)
        self.optimizer = kwargs.get('optimizer', None)
        self.lr_scheduler = kwargs.get('lr_scheduler', None)

        self.global_steps = kwargs.pop("global_steps", 0)
        self.micro_steps = kwargs.pop("micro_steps", 0)
        self.global_samples = kwargs.pop("global_samples", 0)
        self.tokens_processed = kwargs.pop("tokens_processed", 0)

        self._frozen_params = set()
        self._global_grad_norm = None # populated by _get_grad_norm() on optimizer steps

        self.max_nan_losses = 8
        self.loss_scaler = torch.cuda.amp.GradScaler() if cfg.trainer.scale_loss else None
    def freeze(self, freeze_all=True):
        # set to freeze
        if not freeze_all and (self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params")):
            raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")

        for name, param in self.module.named_parameters():
            if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
                param.requires_grad_(False)
                self._frozen_params.add(param)

    def unfreeze(self):
        for p in self._frozen_params:
            p.requires_grad_(True)
        self._frozen_params.clear()
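    # Hypothetical usage sketch (the parameter name is illustrative): with a hyper_config
    # whose frozen_params contains parameter names, e.g. {"text_emb.weight"},
    #   engine.freeze(freeze_all=False)   # freezes only the listed parameters
    #   engine.unfreeze()                 # restores requires_grad on everything frozen above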
    @property
    def _training(self):
        if getattr(self, "hyper_config", None) is None:
            return True
        return self.hyper_config.training

    @property
    def global_step(self):
        return self.global_steps

    @property
    def micro_step(self):
        return self.micro_steps

    @property
    def batch_size(self):
        return cfg.hyperparameters.batch_size

    @property
    def gradient_accumulation_steps(self):
        return cfg.hyperparameters.gradient_accumulation_steps

    @property
    def gradient_clipping(self):
        return cfg.hyperparameters.gradient_clipping

    def gather_attribute(self, *args, **kwargs):
        return gather_attribute(self.module, *args, **kwargs)

    def dispatch_attribute(self, *args, **kwargs):
        return dispatch_attribute(self.module, *args, **kwargs)
    def save_checkpoint(self, save_dir, tag):
        if is_global_leader():
            save_path = save_dir / tag / "state.pth"
            save_path.parent.mkdir(parents=True, exist_ok=True)
            torch.save({
                "module": self.module.state_dict(),
                "optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
                "lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
                "stats": {
                    "global_step": self.global_step,
                    "micro_step": self.micro_step,
                    "global_samples": self.global_samples,
                    "tokens_processed": self.tokens_processed,
                }
            }, save_path)

            open(save_dir / "latest", 'w').write(tag)

        torch.distributed.barrier()
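    # Resulting on-disk layout (from the code above):
    #   <save_dir>/<tag>/state.pth   -- module/optimizer/lr_scheduler state plus the "stats" dict
    #   <save_dir>/latest            -- a text file holding the most recent tag, read by load_checkpoint()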
    def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True, load_module_only=False):
        if tag is None:
            tag_path = load_dir / "latest"
            if not tag_path.exists():
                return
            tag = open(tag_path).read()

        load_path = load_dir / tag / "state.pth"
        if not load_path.exists():
            return

        state = torch.load(load_path, map_location=torch.device(cfg.device))

        self.global_steps = state['stats']['global_step'] if 'stats' in state else state['global_step']
        self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step']
        self.global_samples = state['stats']['global_samples'] if 'stats' in state else state['global_samples']
        self.tokens_processed = state['stats']['tokens_processed'] if 'stats' in state else state['tokens_processed']

        # honor the load_module_strict flag passed in above
        self.module.load_state_dict(state['module'], strict=load_module_strict)

        load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
        load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state

        if load_optimizer_states:
            self.optimizer.load_state_dict(state['optimizer'])

        if load_lr_scheduler_states:
            self.lr_scheduler.load_state_dict(state['lr_scheduler'])
    def eval(self):
        return self.module.eval()

    def train(self):
        return self.module.train()

    def to(self, *args, **kwargs):
        self.module = self.module.to(*args, **kwargs)
        if self.optimizer:
            self.optimizer = self.optimizer.to(*args, **kwargs)
        return self

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    @cached_property
    def device(self):
        return next(self.module.parameters()).device

    def forward(self, *args, **kwargs):
        return self.module.forward(*args, **kwargs)

    def backward(self, loss):
        if self.loss_scaler is not None:
            return self.loss_scaler.scale(loss / self.gradient_accumulation_steps).backward()
        return (loss / self.gradient_accumulation_steps).backward()
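    # Dividing the loss by gradient_accumulation_steps means the accumulated gradients
    # average (rather than sum) over the micro-batches of one optimizer step; when a
    # GradScaler is configured, the scaled loss is backpropagated instead so fp16
    # gradients don't underflow.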
    def step(self):
        with torch.set_grad_enabled(self.gradient_accumulation_steps > 1):
            self.micro_steps += 1
            self.global_samples += self.batch_size

            if (self.micro_steps + 1) % max(1, self.gradient_accumulation_steps) == 0:
                torch.nn.utils.clip_grad_norm_(self.module.parameters(), self.gradient_clipping)

                self.global_steps += 1
                if self.loss_scaler is not None:
                    self.loss_scaler.step(self.optimizer)
                    self.loss_scaler.update()
                else:
                    self.optimizer.step()

                # capture the gradient norm before the gradients are cleared
                self._get_grad_norm()

                self.optimizer.zero_grad()

    def _get_grad_norm(self):
        t = [param.grad.detach().flatten() for param in self.module.parameters() if param.grad is not None]
        self._global_grad_norm = torch.cat(t).norm().item() if len(t) else None
    def get_lr(self):
        lrs = []
        for param_group in self.optimizer.param_groups:
            if 'd_coeff' in param_group:
                lrs.append(param_group['d_coeff'])
            elif 'lr' in param_group:
                lrs.append(param_group['lr'])
        return lrs

    def set_lr(self, lr):
        for param_group in self.optimizer.param_groups:
            if 'd_coeff' in param_group:
                param_group['d_coeff'] = lr
            elif 'lr' in param_group:
                param_group['lr'] = lr

    def get_global_grad_norm(self):
        return self._global_grad_norm

    def traverse(self, *args, **kwargs):
        with ml.autocast():
            self.forward(*args, **kwargs)

        losses = self.gather_attribute("loss")
        loss = torch.stack([*losses.values()]).sum()

        if torch.isnan(loss).any():
            self.max_nan_losses = self.max_nan_losses - 1
            if self.max_nan_losses < 0:
                raise RuntimeError("Too many NaN losses detected.")

        stats = {}
        stats |= {k: v.item() for k, v in losses.items()}
        stats |= self.gather_attribute("scalar")

        self.backward(loss)
        self.step()

        return stats
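# A minimal single-engine sketch (assumes a model whose forward() stores its losses where
# gather_attribute("loss") can find them, and an already-populated cfg):
#
#   engine = Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4))
#   engine.train()
#   stats = engine.traverse(**batch)   # forward + backward + (possibly deferred) optimizer step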
# and now to ignore everything from the above
class Engines(dict[str, Engine]):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.setup()

    def setup(self):
        self._global_step = 0
        self._micro_step = 0
        self._batch_size = 0
        self._global_samples = 0

    @property
    def global_step(self):
        return self._global_step

    @property
    def micro_step(self):
        return self._micro_step

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def global_samples(self):
        return self._global_samples

    def gather_attribute(self, *args, **kwargs):
        ret = {}
        for engine in self.values():
            ret |= engine.gather_attribute(*args, **kwargs)
        return ret

    def dispatch_attribute(self, *args, **kwargs):
        for engine in self.values():
            engine.dispatch_attribute(*args, **kwargs)
    def export(self, userdata={}, callback=None):
        for name, engine in self.items():
            outpath = cfg.ckpt_dir / name / "fp32.pth"
            state_dict = {
                'module': engine.module.state_dict(),
                "stats": {
                    "global_step": engine.global_step,
                    "micro_step": engine.micro_step,
                    "global_samples": engine.global_samples,
                    "tokens_processed": engine.tokens_processed,
                },
                "userdata": userdata
            }
            if callback:
                state_dict = callback(state_dict, engine.hyper_config)
            torch.save(state_dict, outpath)
            print(f"Exported {name} to {outpath}")
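    # The exported file can be reloaded without the Engine machinery, e.g. (sketch):
    #   state = torch.load(cfg.ckpt_dir / name / "fp32.pth", map_location="cpu")
    #   model.load_state_dict(state["module"])
    # The optional callback lets a caller rewrite the state dict (strip keys, convert
    # weights, etc.) before it is written out.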
    def save_checkpoint(self, tag=None):
        if not tag:
            tag = cfg.trainer.save_tag
        tag = tag.lower()
        if tag[:2] == "it" or tag[:4] == "step":
            tag = f'{self.global_step}'

        cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
        for name, engine in self.items():
            if not engine._training:
                continue

            save_dir = cfg.ckpt_dir / name
            try:
                engine.save_checkpoint(save_dir, tag=tag)
            except Exception as e:
                print(f'Failed to save checkpoint for engine {name}:', str(e))

            # might be better to prune before saving for safety; [:0] returns an empty list, but I could do [:-cfg.trainer.keep_last_checkpoints - 1 if cfg.trainer.keep_last_checkpoints > 1 else None]
            if cfg.trainer.keep_last_checkpoints > 0 and is_global_leader():
                checkpoints = [d for d in list(save_dir.glob("*")) if d.is_dir()]
                checkpoints.sort(key=lambda x: x.stat().st_mtime)
                checkpoints = checkpoints[:-cfg.trainer.keep_last_checkpoints]
                for d in checkpoints:
                    if not d.is_dir() or not d.exists():
                        continue
                    print("Removing", d)
                    for p in d.iterdir():
                        p.unlink()
                    d.rmdir()
    def load_checkpoint(self, tag=None):
        if not tag:
            tag = cfg.trainer.load_tag

        for name, engine in self.items():
            load_dir = cfg.ckpt_dir / name
            engine.load_checkpoint(
                tag=tag,
                load_dir=load_dir,
                load_module_strict=cfg.trainer.strict_loading,
                load_optimizer_states=False if cfg.trainer.load_module_only else cfg.trainer.load_states,
                load_lr_scheduler_states=False if cfg.trainer.load_module_only else cfg.trainer.load_states,
                load_module_only=cfg.trainer.load_module_only,
            )
            if cfg.trainer.restart_step_count:
                engine.global_steps = 0
                engine.micro_steps = 0
                engine.global_samples = 0
                engine.tokens_processed = 0

        # update the LR, because for some god awful reason it gets overwritten when loading from a checkpoint, but only when it's not using a scheduler
        if cfg.hyperparameters.scheduler_type == "":
            self.set_lr(cfg.hyperparameters.learning_rate)

        self._update()
    def set_lr(self, lr):
        for engine in self.values():
            if not engine._training:
                continue
            engine.set_lr(lr)

    def _update(self):
        for engine in self.values():
            self._global_step = max(self._global_step, engine.global_step)
            self._micro_step = max(self._micro_step, engine.micro_step)
            self._batch_size = max(self._batch_size, engine.batch_size)
            self._global_samples = max(self._global_samples, engine.global_samples)

    def eval(self):
        for engine in self.values():
            engine.eval()

    def train(self):
        for engine in self.values():
            engine.train()

    def traverse(self):
        stats = {}
        for name, engine in self.items():
            stat = engine.traverse()
            stats.update(flatten_dict({name.split("-")[0]: stat}))
        return stats
    def step(self, batch, feeder: TrainFeeder = default_feeder):
        total_elapsed_time = 0

        stats: Any = dict()

        if cfg.trainer.gc_mode == 'step':
            do_gc()

        for name, engine in self.items():
            if not engine._training:
                continue

            device = engine.device

            if cfg.trainer.gc_mode == 'substep':
                do_gc()

            start_time = time.time()

            tries = 4
            n_ooms = torch.zeros([], device=device)

            batch = to_device(batch, device)

            if not cfg.trainer.check_for_oom:
                res = feeder(engine=engine, batch=batch)
            else:
                while tries >= 0:
                    try:
                        res = feeder(engine=engine, batch=batch)
                        break
                    except RuntimeError as e:
                        print("Forward", str(e))

                        if "out of memory" not in str(e):
                            self.save_checkpoint()
                            raise e

                        tries -= 1 # burn a retry so the loop terminates
                        # shrink batch size until it's happy
                        for k in batch:
                            batch[k] = batch[k][:-1]

                        if tries <= 0:
                            # trigger OOM
                            n_ooms += 1
                        else:
                            # also do GC
                            do_gc()
                        continue

                if world_size() > 1:
                    all_reduce(n_ooms)
                if n_ooms.item() > 0:
                    self.save_checkpoint()
                    raise RuntimeError("Out of memory during forward pass!")

            if res is None:
                continue

            loss, engine_stats = res
            engine_stats |= self.gather_attribute("scalar")

            n_ooms = torch.zeros([], device=device)

            if cfg.trainer.aggressive_optimizations:
                batch = to_device(batch, 'cpu')

            if not cfg.trainer.check_for_oom:
                engine.backward(loss)
            else:
                # to-do: properly handle when one GPU throws an OOM because it just halts
                try:
                    engine.backward(loss)
                except RuntimeError as e:
                    print("Backwards:", str(e))

                    if "out of memory" not in str(e):
                        self.save_checkpoint()
                        raise e

                    n_ooms += 1

                if world_size() > 1:
                    all_reduce(n_ooms)

                if n_ooms.item() > 0:
                    self.save_checkpoint()
                    raise RuntimeError("Out of memory during backwards pass!")

            engine.step()

            #torch.cuda.synchronize()

            elapsed_time = time.time() - start_time
            total_elapsed_time += elapsed_time

            grad_norm = engine.get_global_grad_norm()
            loss_scale = 1
            if hasattr(engine.optimizer, "loss_scale") and engine.optimizer.loss_scale is not None:
                loss_scale = engine.optimizer.loss_scale

            if grad_norm is not None:
                grad_norm /= loss_scale

            stats.update(
                flatten_dict(
                    {
                        name.split("-")[0]: dict(
                            **engine_stats,
                            lr=engine.get_lr()[0],
                            grad_norm=grad_norm,
                            loss_scale=loss_scale if loss_scale != 1 else None,
                            elapsed_time=elapsed_time,
                            engine_step=engine.global_step,
                            samples_processed=engine.global_samples,
                            tokens_processed=engine.tokens_processed,
                        )
                    }
                ),
            )

        self._update()

        if len(self.keys()) > 1:
            stats["elapsed_time"] = total_elapsed_time

        stats["it"] = self.global_step

        return stats
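# A rough end-to-end usage sketch (hypothetical variable names; assumes cfg and a DataLoader
# are already set up elsewhere):
#
#   engines = Engines({"model": Engine(model=model, optimizer=optimizer, hyper_config=hyper_config)})
#   engines.load_checkpoint()
#   engines.train()
#   for batch in dataloader:
#       stats = engines.step(batch)    # uses default_feeder unless a custom feeder is passed
#       if stats["it"] % 1000 == 0:    # illustrative save cadence
#           engines.save_checkpoint()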