ugh, finally got some form of offloading working (need to test if it works on different GPUs, but GPU and CPU offloading seems to work in the test trainer)
This commit is contained in:
parent c9ec6b28ef
commit 443422ecb5
@@ -527,8 +527,8 @@ def example_usage():
    cfg.optimizations.model_offloading = {
        "devices": ["cuda:0", "cpu"],
        # "limits": [ 0.9, -1 ],
        "assign": [[ f'layers.{i}.' for i in range(0,6) ], [ f'layers.{i}.' for i in range(6,12) ]],
        "limits": [ 256 * (1024 ** 2), -1 ]
        "assign": [[ f'layers.{i}.' for i in range(0,10) ], [ f'layers.{i}.' for i in range(11,12) ] + [ "model.norm" ]],
        # "limits": [ 256 * (1024 ** 2), -1 ]
    }
    """
@@ -229,15 +229,22 @@ def autocast_forward( func ):
    return wrapper

# handles migrating an input tensor to a given device
def auto_to_forward( module, device=None ):
    if device is None:
        device = next(module.parameters()).device

def auto_align_inputs_forward( module, device=None, name = None ):
    func = module.forward

    def wrapper( self, *args, **kwargs ):
        # search through args and kwargs for any Tensor arguments
    if device is None:
        if hasattr( module, 'device' ):
            device = module.device
        else:
            try:
                device = next(module.parameters() if [*module.parameters()] else module.buffers()).device
            except Exception as e:
                return func

    def wrapper( *args, **kwargs ):
        args = [*args]
        # search through args and kwargs for any Tensor arguments
        for i, arg in enumerate(args):
            if not isinstance( arg, torch.Tensor ):
                continue
@@ -248,7 +255,11 @@ def auto_to_forward( module, device=None ):
                continue
            kwargs[k] = v.to( device=device )

        return func( self, *args, **kwargs )
        # disgusting patch
        if "position_embeddings" in kwargs:
            kwargs["position_embeddings"] = tuple([ t.to(device=device) for t in kwargs["position_embeddings"] ])

        return func( *args, **kwargs )
    return wrapper

# disgusting kludge, but it works (just realized BitNet has its own replacement routine)
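
Note: auto_align_inputs_forward captures the original module.forward, resolves the module's device (attribute, parameters, or buffers), moves any tensor args/kwargs onto it, and then delegates. A standalone sketch of that capture-and-delegate pattern, with illustrative names rather than the repo's exact helpers:

import torch

def align_inputs_to_device(module: torch.nn.Module, device: torch.device):
    # Capture the original bound forward once, then return a closure that
    # migrates every tensor argument onto `device` before delegating.
    func = module.forward

    def wrapper(*args, **kwargs):
        args = [a.to(device=device) if isinstance(a, torch.Tensor) else a for a in args]
        kwargs = {k: v.to(device=device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
        return func(*args, **kwargs)

    return wrapper

# usage: layer.forward = align_inputs_to_device(layer, torch.device("cuda:0"))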
@@ -404,8 +415,6 @@ def get_model_offload_policy(module, policy=None):
    if "devices" not in policy:
        policy["devices"] = [f'{"cuda"}:{i}' for i in range(torch.cuda.device_count())] + ['cpu'] # + cpu to spill the remainder on CPU if overbudget

    print( policy )

    # create initial device info
    devices = [ get_device_properties(device) | {"modules": []} for device in policy["devices"] ]
    modules = [ (name, get_module_size(module)) for name, module in module.named_modules() if not [*module.named_children()] and passes_policy( policy, name ) ]
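
Note: get_module_size and get_device_properties are not shown in this hunk; as a rough illustration of the budgeting such a policy builder can do, here is a hedged sketch that measures leaf-module sizes and packs them first-fit into per-device byte limits (the names and the packing strategy are assumptions, not the repo's implementation):

import torch

def module_size_bytes(module: torch.nn.Module) -> int:
    # Parameters plus buffers, in bytes.
    return sum(t.numel() * t.element_size() for t in [*module.parameters(), *module.buffers()])

def first_fit_assign(model: torch.nn.Module, devices, limits):
    # Pack leaf modules into per-device buckets; a limit of -1 is unbounded.
    buckets = [{"name": d, "limit": l, "used": 0, "modules": []} for d, l in zip(devices, limits)]
    leaves = [(n, module_size_bytes(m)) for n, m in model.named_modules() if not [*m.named_children()]]
    for name, size in leaves:
        for bucket in buckets:
            if bucket["limit"] == -1 or bucket["used"] + size <= bucket["limit"]:
                bucket["modules"].append(name)
                bucket["used"] += size
                break
    return buckets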
@@ -498,23 +507,14 @@ def offload_model( model, policy=None ):
        for name in device["modules"]:
            module = model.get_submodule(name)
            module = module.to( device["name"] )
            module.device = device['name']

            """
            # in case the above doesn't actually do what's requested
            *parent, key = name.split(".")
            module = getattr( model.get_submodule(name), key )
            setattr( model.get_submodule(name), key, module.to( device["name"] ) )
            """

        # select next device to cast inputs to, or wrap to first if last device
        next_device = policy[i + 1]["name"] if i + 1 < len( policy ) else policy[0]["name"]
        # same device, don't bother wrapping
        if device["name"] == next_device:
    # wrap modules with forward to ensure all inputs are matched to its device
    for name, module in model.named_modules():
        if not hasattr( module, 'forward' ):
            continue

        # wrap forward call
        last_module = model.get_submodule( device["modules"][-1] )
        last_module.forward = auto_to_forward(last_module, next_device)
        module.forward = auto_align_inputs_forward(module)

    """
    # Validate that the layers are all in the right spot
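
Note: put together, the flow in this hunk is: move each assigned submodule to its device, record that device on the module, then wrap every module's forward with auto_align_inputs_forward so inputs follow the weights (rather than only wrapping the last module per device, as before). A condensed sketch of the placement half, under the same assumptions as above:

import torch

def place_submodules(model: torch.nn.Module, assignment: dict):
    # assignment maps submodule name -> device string,
    # e.g. {"layers.0": "cuda:0", "layers.11": "cpu", "model.norm": "cpu"}
    for name, device in assignment.items():
        submodule = model.get_submodule(name)
        submodule.to(device)
        # stash the device so a forward wrapper can look it up later
        submodule.device = torch.device(device)
    return model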
@@ -524,6 +524,7 @@ def offload_model( model, policy=None ):
        try:
            print( name, next(module.parameters()).device )
        except Exception as e:
            print( name, "?" )
            pass
    """
@@ -99,18 +99,17 @@ except Exception as e:
# backwards compat
from .utils import (
    autocast_forward,
    auto_to_forward,
    replace_linear,
    replace_linear as replace_linear_old,
    replace_embedding as replace_embedding_old,
    replace_attention as replace_attention_old,
    replace_attention,
    resize_weight,
    offload_model,
)

# wrapped here so we can maintain default args
def replace_linear( model, klass=Linear, target=torch.nn.Linear, verbose=False ):
    return replace_embedding_old( model, klass, target, verbose )
    return replace_linear_old( model, klass, target, verbose )
def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbose=False ):
    return replace_attention_old( model, klass, target, verbose )
    return replace_embedding_old( model, klass, target, verbose )

Embedding.forward = autocast_forward(Embedding.forward)
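
Note: this hunk appears to adjust the backwards-compat shims so each thin wrapper delegates to the matching *_old helper while baking in this module's default classes. A generic sketch of that aliasing pattern; the stand-in helper below is a placeholder, not the repo's actual replace_* utility:

import torch

# Placeholder for the generic helper this module imports as `replace_linear_old`.
def replace_linear_old(model, klass, target, verbose=False):
    for name, child in model.named_children():
        if isinstance(child, target):
            setattr(model, name, klass(child.in_features, child.out_features, bias=child.bias is not None))
        else:
            replace_linear_old(child, klass, target, verbose)
    return model

# Thin wrapper that keeps the old entry point but pins this module's preferred defaults.
def replace_linear(model, klass=torch.nn.Linear, target=torch.nn.Linear, verbose=False):
    return replace_linear_old(model, klass, target, verbose)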