2023-08-04 01:26:36 +00:00
from torch import Tensor
from typing import Any , Protocol
Stats = dict [ str , float ]
class TrainFeeder ( Protocol ) :
def __call__ (
self , * , engine : " Engine " , batch : Any
) - > None | tuple [ Tensor , Stats ] :
. . .
2023-08-04 02:39:00 +00:00
def default_feeder ( engine , batch ) :
2023-08-04 01:26:36 +00:00
if isinstance ( batch , list ) :
engine ( * batch )
elif isinstance ( batch , dict ) :
engine ( * * batch )
else :
engine ( batch )
losses = engine . gather_attribute ( " loss " )
loss = torch . stack ( [ * losses . values ( ) ] ) . sum ( )
stats = { }
stats | = { k : v . item ( ) for k , v in losses . items ( ) }
return loss , stats
from . . config import cfg
from . . utils import dispatch_attribute , flatten_dict , gather_attribute , do_gc , to_device
2023-08-17 04:37:52 +00:00
from . . utils . distributed import init_distributed , distributed_initialized , is_global_leader
2023-08-04 01:26:36 +00:00
import logging
import time
import torch
import torch . distributed
import os
from torch import Tensor
from torch . distributed import all_reduce
from typing import Any , Protocol
from . base import TrainFeeder
_logger = logging . getLogger ( __name__ )
2023-08-05 03:22:15 +00:00
if not distributed_initialized ( ) and cfg . trainer . backend == " local " :
init_distributed ( torch . distributed . init_process_group )
2023-08-04 01:26:36 +00:00
# A very naive engine implementation using barebones PyTorch
class Engine ( ) :
def __init__ ( self , * args , * * kwargs ) :
2023-08-05 03:22:15 +00:00
self . module = kwargs [ ' model ' ] . to ( cfg . device ) . to ( cfg . trainer . dtype )
2023-08-04 01:26:36 +00:00
self . optimizer = kwargs [ ' optimizer ' ] if ' optimizer ' in kwargs else None
self . lr_scheduler = kwargs [ ' lr_scheduler ' ] if ' lr_scheduler ' in kwargs else None
self . global_steps = 0
self . micro_steps = 0
self . gradient_accumulation_steps = cfg . hyperparameters . gradient_accumulation_steps
def freeze ( self ) :
for p in self . module . parameters ( ) :
if p . requires_grad :
p . requires_grad_ ( False )
self . _frozen_params . add ( p )
def unfreeze ( self ) :
for p in self . _frozen_params :
p . requires_grad_ ( True )
self . _frozen_params . clear ( )
@property
def global_step ( self ) :
return self . global_steps
@property
def micro_step ( self ) :
return self . micro_steps
def train_batch_size ( self ) :
return cfg . hyperparameters . batch_size
def gather_attribute ( self , * args , * * kwargs ) :
return gather_attribute ( self . module , * args , * * kwargs )
def dispatch_attribute ( self , * args , * * kwargs ) :
return dispatch_attribute ( self . module , * args , * * kwargs )
def save_checkpoint ( self , save_dir , tag ) :
2023-08-05 02:17:30 +00:00
save_path = save_dir / tag / " state.pth "
save_path . parent . mkdir ( parents = True , exist_ok = True )
2023-08-04 01:26:36 +00:00
torch . save ( {
" global_step " : self . global_step ,
" micro_step " : self . micro_step ,
" module " : self . module . state_dict ( ) ,
" optimizer " : self . optimizer . state_dict ( ) if self . optimizer is not None else None ,
" lr_scheduler " : self . lr_scheduler . state_dict ( ) if self . lr_scheduler is not None else None ,
2023-08-05 02:17:30 +00:00
} , save_path )
2023-08-04 01:26:36 +00:00
2023-08-05 03:22:15 +00:00
open ( save_dir / " latest " , ' w ' ) . write ( tag )
2023-08-05 02:17:30 +00:00
def load_checkpoint ( self , load_dir , tag = None , load_module_strict = True , load_optimizer_states = True , load_lr_scheduler_states = True ) :
if tag is None :
tag_path = load_dir / " latest "
if not tag_path . exists ( ) :
return
tag = open ( tag_path ) . read ( )
load_path = load_dir / tag / " state.pth "
if not load_path . exists ( ) :
return
state = torch . load ( load_path )
2023-08-05 03:22:15 +00:00
self . global_steps = state [ ' global_step ' ]
self . micro_steps = state [ ' micro_step ' ]
2023-08-04 01:26:36 +00:00
self . module . load_state_dict ( state [ ' module ' ] )
load_optimizer_states = load_optimizer_states and self . optimizer is not None and ' optimizer ' in state
load_lr_scheduler_states = load_lr_scheduler_states and self . lr_scheduler is not None and ' lr_scheduler ' in state
if load_optimizer_states :
self . optimizer . load_state_dict ( state [ ' optimizer ' ] )
if load_lr_scheduler_states :
self . lr_scheduler . load_state_dict ( state [ ' lr_scheduler ' ] )
def eval ( self ) :
return self . module . eval ( )
def train ( self ) :
return self . module . train ( )
def to ( self , * args , * * kwargs ) :
self . module = self . module . to ( * args , * * kwargs )
return self . module
def __call__ ( self , * args , * * kwargs ) :
return self . forward ( * args , * * kwargs )
def forward ( self , * args , * * kwargs ) :
return self . module . forward ( * args , * * kwargs )
def backward ( self , loss ) :
return ( loss / self . gradient_accumulation_steps ) . backward ( )
def step ( self ) :
with torch . set_grad_enabled ( self . gradient_accumulation_steps > 1 ) :
self . micro_steps + = 1
if ( self . micro_steps + 1 ) % max ( 1 , self . gradient_accumulation_steps ) == 0 :
self . global_steps + = 1
self . optimizer . step ( )
self . optimizer . zero_grad ( )
def get_lr ( self ) :
lrs = [ ]
for param_group in self . optimizer . param_groups :
if ' lr ' in param_group :
lrs . append ( param_group [ ' lr ' ] )
return lrs
def set_lr ( self , lr ) :
for param_group in self . optimizer . param_groups :
if ' lr ' in param_group :
param_group [ ' lr ' ] = lr
def get_global_grad_norm ( self ) :
return 0.0
def traverse ( self , * args , * * kwargs ) :
self . forward ( * args , * * kwargs )
losses = self . gather_attribute ( " loss " )
loss = torch . stack ( [ * losses . values ( ) ] ) . sum ( )
stats = { }
stats | = { k : v . item ( ) for k , v in losses . items ( ) }
stats | = self . gather_attribute ( " scalar " )
self . backward ( loss )
self . step ( )
return stats
# and now to ignore everything from the above
class Engines ( dict [ str , Engine ] ) :
def __init__ ( self , * args , * * kwargs ) :
super ( ) . __init__ ( * args , * * kwargs )
self . setup ( )
def setup ( self ) :
self . _global_step = 0
self . _micro_step = 0
@property
def global_step ( self ) :
return self . _global_step
@property
def micro_step ( self ) :
return self . _micro_step
def gather_attribute ( self , * args , * * kwargs ) :
ret = { }
for engine in self . values ( ) :
ret | = engine . gather_attribute ( * args , * * kwargs )
return ret
def dispatch_attribute ( self , * args , * * kwargs ) :
for engine in self . values ( ) :
engine . dispatch_attribute ( * args , * * kwargs )
def save_checkpoint ( self , tag = None ) :
if not tag :
tag = cfg . trainer . save_tag
tag = tag . lower ( )
if tag [ : 2 ] == " it " or tag [ : 4 ] == " step " :
tag = f ' { self . global_step } '
cfg . ckpt_dir . mkdir ( parents = True , exist_ok = True )
for name , engine in self . items ( ) :
2023-08-17 01:12:12 +00:00
save_dir = cfg . ckpt_dir / name
engine . save_checkpoint ( save_dir , tag = tag )
2023-08-17 04:37:52 +00:00
# might be better to prune before saving for safety, but [:0] returns an empty list, but I could do [:-cfg.trainer.keep_last_checkpoints - 1 if cfg.trainer.keep_last_checkpoints > 1 else None]
if cfg . trainer . keep_last_checkpoints > 0 and is_global_leader ( ) :
2023-08-17 05:11:29 +00:00
checkpoints = [ d for d in list ( save_dir . glob ( " * " ) ) if d . is_dir ( ) ]
2023-08-17 01:12:12 +00:00
checkpoints . sort ( key = lambda x : x . stat ( ) . st_mtime )
checkpoints = checkpoints [ : - cfg . trainer . keep_last_checkpoints ]
for d in checkpoints :
2023-08-17 05:11:29 +00:00
if not d . is_dir ( ) or not d . exists ( ) :
continue
print ( " Removing " , d )
2023-08-17 01:12:12 +00:00
for p in d . iterdir ( ) :
p . unlink ( )
d . rmdir ( )
2023-08-04 01:26:36 +00:00
def load_checkpoint ( self , tag = None ) :
if not tag :
tag = cfg . trainer . load_tag
for name , engine in self . items ( ) :
load_dir = cfg . ckpt_dir / name
engine . load_checkpoint (
tag = tag ,
load_dir = load_dir ,
load_module_strict = cfg . trainer . strict_loading ,
load_optimizer_states = cfg . trainer . load_states ,
load_lr_scheduler_states = cfg . trainer . load_states ,
)
if cfg . trainer . restart_step_count :
engine . global_steps = 0
# update the LR because for some god awful reason it gets overwritten when loading from a checkpoint but only when it's not using a scheduler
if cfg . hyperparameters . scheduler_type == " " :
self . set_lr ( cfg . hyperparameters . learning_rate )
self . _update_global_step ( )
self . _update_micro_step ( )
def set_lr ( self , lr ) :
for engine in self . values ( ) :
engine . set_lr ( lr )
def _update_global_step ( self ) :
for engine in self . values ( ) :
self . _global_step = max ( self . _global_step , engine . global_step )
def _update_micro_step ( self ) :
for engine in self . values ( ) :
self . _micro_step = max ( self . _micro_step , engine . micro_step )
def train_batch_size ( self ) :
batch_size = 0
for engine in self . values ( ) :
batch_size = max ( batch_size , engine . train_batch_size ( ) )
def eval ( self ) :
for engine in self . values ( ) :
engine . eval ( )
def train ( self ) :
for engine in self . values ( ) :
engine . train ( )
def traverse ( self ) :
stats = { }
for name , engine in self . items ( ) :
stat = engine . traverse ( )
stats . update ( flatten_dict ( { name . split ( " - " ) [ 0 ] : stat } ) )
return stats
2023-08-14 03:07:45 +00:00
def step ( self , batch , feeder : TrainFeeder = default_feeder ) :
2023-08-04 01:26:36 +00:00
total_elapsed_time = 0
stats : Any = dict ( )
if cfg . trainer . gc_mode == ' step ' :
do_gc ( )
for name , engine in self . items ( ) :
2023-08-14 03:07:45 +00:00
device = engine . device
2023-08-04 01:26:36 +00:00
if cfg . trainer . gc_mode == ' substep ' :
do_gc ( )
start_time = time . time ( )
tries = 4
2023-08-14 03:07:45 +00:00
n_ooms = torch . zeros ( [ ] , device = device )
2023-08-04 01:26:36 +00:00
2023-08-14 03:07:45 +00:00
batch = to_device ( batch , device )
2023-08-04 01:26:36 +00:00
2023-08-04 01:36:19 +00:00
if not cfg . trainer . check_for_oom :
res = feeder ( engine = engine , batch = batch )
else :
while tries > = 0 :
try :
res = feeder ( engine = engine , batch = batch )
break
except RuntimeError as e :
print ( " Forward " , str ( e ) )
if " out of memory " not in str ( e ) :
self . save_checkpoint ( )
raise e
# shrink batch size until it's happy
for k in batch :
batch [ k ] = batch [ k ] [ : - 1 ]
if tries < = 0 :
# trigger OOM
n_ooms + = 1
else :
# also do GC
do_gc ( )
continue
all_reduce ( n_ooms )
if n_ooms . item ( ) > 0 :
self . save_checkpoint ( )
raise RuntimeError ( " Out of memory during forward pass! " )
2023-08-04 01:26:36 +00:00
if res is None :
continue
loss , engine_stats = res
engine_stats | = self . gather_attribute ( " scalar " )
2023-08-14 03:07:45 +00:00
n_ooms = torch . zeros ( [ ] , device = device )
2023-08-04 01:26:36 +00:00
if cfg . trainer . aggressive_optimizations :
batch = to_device ( batch , ' cpu ' )
2023-08-04 01:36:19 +00:00
if not cfg . trainer . check_for_oom :
2023-08-04 01:26:36 +00:00
engine . backward ( loss )
2023-08-04 01:36:19 +00:00
else :
try :
engine . backward ( loss )
except RuntimeError as e :
print ( " Backwards: " , str ( e ) )
if " out of memory " not in str ( e ) :
self . save_checkpoint ( )
raise e
n_ooms + = 1
2023-08-04 01:26:36 +00:00
2023-08-04 01:36:19 +00:00
all_reduce ( n_ooms )
if n_ooms . item ( ) > 0 :
2023-08-04 01:26:36 +00:00
self . save_checkpoint ( )
2023-08-04 01:36:19 +00:00
raise RuntimeError ( " Out of memory during backwards pass! " )
2023-08-04 01:26:36 +00:00
engine . step ( )
#torch.cuda.synchronize()
elapsed_time = time . time ( ) - start_time
total_elapsed_time + = elapsed_time
stats . update (
flatten_dict (
{
name . split ( " - " ) [ 0 ] : dict (
loss = loss . item ( ) ,
lr = engine . get_lr ( ) [ 0 ] ,
grad_norm = engine . get_global_grad_norm ( ) , # This norm is delayed but global and avoids extra computation
elapsed_time = elapsed_time ,
engine_step = engine . global_step ,
* * engine_stats ,
)
}
) ,
)
self . _update_global_step ( )
self . _update_micro_step ( )
stats [ " batch_size " ] = self . train_batch_size ( ) # len(batch["text"])
stats [ " elapsed_time " ] = total_elapsed_time
stats [ " wall_time " ] = time . time ( )
stats [ " global_step " ] = self . global_step
return stats