ugh, finally got some form of offloading working (need to test if it works on different GPUs, but GPU and CPU offloading seems to work in the test trainer)

mrq 2024-08-01 22:43:39 -05:00
parent c9ec6b28ef
commit 443422ecb5
3 changed files with 30 additions and 30 deletions

View File

@@ -527,8 +527,8 @@ def example_usage():
 	cfg.optimizations.model_offloading = {
 		"devices": ["cuda:0", "cpu"],
 		# "limits": [ 0.9, -1 ],
-		"assign": [[ f'layers.{i}.' for i in range(0,6) ], [ f'layers.{i}.' for i in range(6,12) ]],
-		"limits": [ 256 * (1024 ** 2), -1 ]
+		"assign": [[ f'layers.{i}.' for i in range(0,10) ], [ f'layers.{i}.' for i in range(11,12) ] + [ "model.norm" ]],
+		# "limits": [ 256 * (1024 ** 2), -1 ]
 	}
 	"""

View File

@@ -229,15 +229,22 @@ def autocast_forward( func ):
 	return wrapper

 # handles migrating an input tensor to a given device
-def auto_to_forward( module, device=None ):
-	if device is None:
-		device = next(module.parameters()).device
+def auto_align_inputs_forward( module, device=None, name = None ):
 	func = module.forward

-	def wrapper( self, *args, **kwargs ):
-		# search through args and kwargs for any Tensor arguments
+	if device is None:
+		if hasattr( module, 'device' ):
+			device = module.device
+		else:
+			try:
+				device = next(module.parameters() if [*module.parameters()] else module.buffers()).device
+			except Exception as e:
+				return func
+
+	def wrapper( *args, **kwargs ):
+		args = [*args]
+		# search through args and kwargs for any Tensor arguments
 		for i, arg in enumerate(args):
 			if not isinstance( arg, torch.Tensor ):
 				continue
@@ -248,7 +255,11 @@ def auto_to_forward( module, device=None ):
 				continue
 			kwargs[k] = v.to( device=device )

-		return func( self, *args, **kwargs )
+		# disgusting patch
+		if "position_embeddings" in kwargs:
+			kwargs["position_embeddings"] = tuple([ t.to(device=device) for t in kwargs["position_embeddings"] ])
+
+		return func( *args, **kwargs )
 	return wrapper

 # disgusting kludge, but it works (just realized BitNet has its own replacement routine)
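
The signature change is the important fix here: module.forward is a bound method, so self is already baked into the captured func and must not be accepted or forwarded again. A minimal standalone sketch of the wrapping pattern, independent of this repo's code:

import torch

layer = torch.nn.Linear( 2, 2 )
func = layer.forward # bound method: `self` is already baked in

def wrapper( *args, **kwargs ): # no `self` parameter
	print( "aligning inputs to", layer.weight.device )
	return func( *args, **kwargs )

layer.forward = wrapper
layer( torch.randn(1, 2) ) # prints the device, then runs the original forward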
@@ -404,8 +415,6 @@ def get_model_offload_policy(module, policy=None):
 	if "devices" not in policy:
 		policy["devices"] = [f'{"cuda"}:{i}' for i in range(torch.cuda.device_count())] + ['cpu'] # + cpu to spill the remainder on CPU if overbudget

-	print( policy )
-
 	# create initial device info
 	devices = [ get_device_properties(device) | {"modules": []} for device in policy["devices"] ]
 	modules = [ (name, get_module_size(module)) for name, module in module.named_modules() if not [*module.named_children()] and passes_policy( policy, name ) ]
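
Note the leaf-module filter: a module with no named_children() is a leaf, so each parameter is counted exactly once when budgeting devices. A toy sketch of that enumeration; module_size is a simplified stand-in for the repo's get_module_size:

import torch

def module_size( module ):
	# stand-in for get_module_size: parameter bytes plus buffer bytes
	params = sum( p.numel() * p.element_size() for p in module.parameters() )
	buffers = sum( b.numel() * b.element_size() for b in module.buffers() )
	return params + buffers

model = torch.nn.Sequential(
	torch.nn.Linear( 8, 8 ),
	torch.nn.ReLU(),
	torch.nn.Linear( 8, 2 ),
)

# only leaf modules (no children) are listed, so nothing is double-counted
leaves = [ (name, module_size(module)) for name, module in model.named_modules() if not [*module.named_children()] ]
print( leaves ) # [('0', 288), ('1', 0), ('2', 72)]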
@@ -498,23 +507,14 @@ def offload_model( model, policy=None ):
 		for name in device["modules"]:
 			module = model.get_submodule(name)
 			module = module.to( device["name"] )
-			"""
-			# in case the above doesn't actually do what's requested
-			*parent, key = name.split(".")
-			module = getattr( model.get_submodule(name), key )
-			setattr( model.get_submodule(name), key, module.to( device["name"] ) )
-			"""
+			module.device = device['name']

-		# select next device to cast inputs to, or wrap to first if last device
-		next_device = policy[i + 1]["name"] if i + 1 < len( policy ) else policy[0]["name"]
-		# same device, don't bother wrapping
-		if device["name"] == next_device:
-			continue
-		# wrap forward call
-		last_module = model.get_submodule( device["modules"][-1] )
-		last_module.forward = auto_to_forward(last_module, next_device)
+	# wrap modules with forward to ensure all inputs are matched to its device
+	for name, module in model.named_modules():
+		if not hasattr( module, 'forward' ):
+			continue
+		# wrap forward call
+		module.forward = auto_align_inputs_forward(module)

 	"""
 	# Validate that the layers are all in the right spot
@@ -524,6 +524,7 @@ def offload_model( model, policy=None ):
 		try:
 			print( name, next(module.parameters()).device )
 		except Exception as e:
+			print( name, "?" )
 			pass
 	"""

View File

@@ -99,18 +99,17 @@ except Exception as e:
 # backwards compat
 from .utils import (
 	autocast_forward,
-	auto_to_forward,
-	replace_linear,
+	replace_linear as replace_linear_old,
 	replace_embedding as replace_embedding_old,
-	replace_attention,
+	replace_attention as replace_attention_old,
 	resize_weight,
 	offload_model,
 )

 # wrapped here so we can maintain default args
 def replace_linear( model, klass=Linear, target=torch.nn.Linear, verbose=False ):
-	return replace_embedding_old( model, klass, target, verbose )
+	return replace_linear_old( model, klass, target, verbose )

 def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbose=False ):
-	return replace_attention_old( model, klass, target, verbose )
+	return replace_embedding_old( model, klass, target, verbose )

 Embedding.forward = autocast_forward(Embedding.forward)
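
The shims above existed before this commit but delegated to the wrong *_old helpers (replace_linear called replace_embedding_old, replace_embedding called replace_attention_old); the fix points each wrapper at its matching import. For reference, a toy stand-in for what a replace_linear helper does; the real implementation lives in .utils and may differ:

import torch

def replace_linear( model, klass, target=torch.nn.Linear, verbose=False ):
	# recursively swap every `target` child for a `klass` of matching shape
	for name, child in model.named_children():
		if isinstance( child, target ):
			replacement = klass( child.in_features, child.out_features, bias=child.bias is not None )
			setattr( model, name, replacement )
			if verbose:
				print( f"replaced {name}: {target.__name__} => {klass.__name__}" )
		else:
			replace_linear( child, klass, target, verbose )
	return model

class ToyLinear( torch.nn.Linear ): # stand-in replacement class
	pass

model = torch.nn.Sequential( torch.nn.Linear(4, 4), torch.nn.ReLU() )
replace_linear( model, ToyLinear, verbose=True ) # replaced 0: Linear => ToyLinear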