@@ -83,31 +83,27 @@ if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
 	torch.optim.AdamW = AdamW
 	torch.optim.SGD = SGD
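+	# torch.optim now hands back the replacement classes wherever AdamW/SGD are requested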
-# disgusting kludge, but it works
+# disgusting kludge, but it works (just realized BitNet has its own replacement routine)
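+# swaps every torch.nn.Linear in the model for the replacement Linear class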
 def replace_linear( model ):
 	device = next(model.parameters()).device
-	linears = [ k.split('.') for k, m in model.named_modules() if type(m).__name__ == 'Linear' ]
+	linears = [ k.split('.') for k, m in model.named_modules() if isinstance( m, torch.nn.Linear ) ]
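+	# isinstance() also matches nn.Linear subclasses, where the old check required the class to literally be named 'Linear'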
 	for *parent, k in linears:
 		name = '.'.join(parent)
 		# copy parameters
-		m = getattr(
-			model.get_submodule(name),
-			k
-		)
+		m = getattr( model.get_submodule(name), k )
 		in_features = m.in_features
 		out_features = m.out_features
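+		# keep a bias only when the original layer actually has one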
-		bias = False if cfg.bitsandbytes.bitnet else m.bias # errors out with BitNet
-		# overwright
+		bias = m.bias is not None
+		# overwrite
 		setattr(
-			model.get_submodule(name),
-			k,
-			Linear( in_features=in_features, out_features=out_features, bias=bias ).to(device)
+			model.get_submodule(name), k,
+			Linear( in_features=in_features, out_features=out_features, bias=bias )
 		)
-	return model
+	return model.to(device) # because our new Linear is created on the CPU......
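+
+	# a hypothetical call site (name assumed, not from this commit):
+	#   model = replace_linear(model)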
 # https://github.com/konstmish/prodigy
 try: