@@ -309,6 +309,7 @@ def example_usage():
     from ..engines import Engine
     from tqdm import tqdm
     from ..utils import wrapper as ml
+    import re
 
     device = "cuda"
     x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
@@ -367,6 +368,33 @@ def example_usage():
     #optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
     engine = Engine(model=model, optimizer=optimizer)
 
+    # copy embeddings if requested
+    if cfg.models._embeddings is not None:
+        embeddings_path = cfg.relpath / cfg.models._embeddings
+
+        if embeddings_path.exists():
+            embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))
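+            # trainer checkpoints may nest the weights under a "module" key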
if " module " in embeddings :
embeddings = embeddings [ " module " ]
frozen_params = set ( )
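+            # keep only the embedding-table weights ("*_emb." keys) and track them for freezing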
+            for k in list(embeddings.keys()):
+                if re.findall(r'_emb\.', k):
+                    frozen_params.add(k)
+                else:
+                    del embeddings[k]
+
+            engine.module.load_state_dict(embeddings, strict=False)
+
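+            # freeze the copied embeddings so training does not update them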
+            for name, param in engine.module.named_parameters():
+                if name not in frozen_params:
+                    continue
+                param.requires_grad_(False)
+                engine._frozen_params.add(param)
 
     if cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace:
         model.model = ml.replace_linear(model.model)
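
For reference, a minimal standalone sketch of the copy-and-freeze pattern the hunk above implements, in plain PyTorch without the Engine wrapper; `Toy`, `donor`, and `target` are hypothetical names for illustration:

    import re
    import torch
    from torch import nn

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.text_emb = nn.Embedding(8, 4)  # parameter key "text_emb.weight" matches "_emb\."
            self.proj = nn.Linear(4, 4)         # not an embedding; dropped by the filter

    donor, target = Toy(), Toy()

    # keep only the "*_emb." weights from the donor checkpoint
    state = {k: v for k, v in donor.state_dict().items() if re.findall(r'_emb\.', k)}
    target.load_state_dict(state, strict=False)

    # freeze the copied embeddings, mirroring the named_parameters() loop above
    for name, param in target.named_parameters():
        if name in state:
            param.requires_grad_(False)

    print(sorted(state))  # ['text_emb.weight']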