ugh, finally got some form of offloading working (need to test if it works on different GPUs, but GPU and CPU offloading seems to work in the test trainer)
parent c9ec6b28ef
commit 443422ecb5
@@ -527,8 +527,8 @@ def example_usage():
 	cfg.optimizations.model_offloading = {
 		"devices": ["cuda:0", "cpu"],
 		# "limits": [ 0.9, -1 ],
-		"assign": [[ f'layers.{i}.' for i in range(0,6) ], [ f'layers.{i}.' for i in range(6,12) ]],
-		"limits": [ 256 * (1024 ** 2), -1 ]
+		"assign": [[ f'layers.{i}.' for i in range(0,10) ], [ f'layers.{i}.' for i in range(11,12) ] + [ "model.norm" ]],
+		# "limits": [ 256 * (1024 ** 2), -1 ]
 	}
 	"""
 
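Aside: the policy above pairs each entry in "assign" (a list of module-name prefixes) with the device at the same index in "devices". Below is a rough, self-contained toy of that mapping; it only illustrates the prefix-to-device idea with stand-in names and is not this repo's actual get_model_offload_policy / offload_model logic.

import torch

# toy model whose leaf names mirror the 'layers.N.' / norm prefixes used above
model = torch.nn.ModuleDict({
	"layers": torch.nn.ModuleList([ torch.nn.Linear(16, 16) for _ in range(12) ]),
	"norm": torch.nn.LayerNorm(16),
})

policy = {
	"devices": ["cuda:0" if torch.cuda.is_available() else "cpu", "cpu"],
	"assign": [
		[ f'layers.{i}.' for i in range(0, 10) ],                 # first ten layers on the GPU
		[ f'layers.{i}.' for i in range(10, 12) ] + [ "norm" ],   # remainder spills to CPU
	],
}

# move each leaf module to the device whose prefix list matches its name
for name, module in model.named_modules():
	if [*module.children()]:
		continue
	for device, prefixes in zip( policy["devices"], policy["assign"] ):
		if any( (name + ".").startswith(prefix) for prefix in prefixes ):
			module.to(device)
			break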
@@ -229,15 +229,22 @@ def autocast_forward( func ):
 	return wrapper
 
 # handles migrating an input tensor to a given device
-def auto_to_forward( module, device=None ):
-	if device is None:
-		device = next(module.parameters()).device
+def auto_align_inputs_forward( module, device=None, name = None ):
 
 	func = module.forward
 
-	def wrapper( self, *args, **kwargs ):
-		# search through args and kwargs for any Tensor arguments
+	if device is None:
+		if hasattr( module, 'device' ):
+			device = module.device
+		else:
+			try:
+				device = next(module.parameters() if [*module.parameters()] else module.buffers()).device
+			except Exception as e:
+				return func
+
+	def wrapper( *args, **kwargs ):
 		args = [*args]
+		# search through args and kwargs for any Tensor arguments
 		for i, arg in enumerate(args):
 			if not isinstance( arg, torch.Tensor ):
 				continue
@@ -248,7 +255,11 @@ def auto_to_forward( module, device=None ):
 				continue
 			kwargs[k] = v.to( device=device )
 
-		return func( self, *args, **kwargs )
+		# disgusting patch
+		if "position_embeddings" in kwargs:
+			kwargs["position_embeddings"] = tuple([ t.to(device=device) for t in kwargs["position_embeddings"] ])
+
+		return func( *args, **kwargs )
 	return wrapper
 
 # disgusting kludge, but it works (just realized BitNet has its own replacement routine)
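Aside: a minimal standalone sketch of what the wrapper in the two hunks above is doing, i.e. moving tensor inputs onto the wrapped module's device before calling the original forward. This is a simplified stand-in (it skips the position_embeddings tuple patch and the module.device attribute) and not the project's exact auto_align_inputs_forward.

import torch

def align_inputs_forward( module ):
	# resolve the module's device from its first parameter (or buffer) up front
	try:
		device = next( module.parameters() if [*module.parameters()] else module.buffers() ).device
	except StopIteration:
		return module.forward

	func = module.forward
	def wrapper( *args, **kwargs ):
		# migrate any tensor positional/keyword arguments to the module's device
		args = [ a.to(device=device) if isinstance(a, torch.Tensor) else a for a in args ]
		kwargs = { k: v.to(device=device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items() }
		return func( *args, **kwargs )
	return wrapper

# usage: a CPU tensor gets moved transparently if the layer lives on another device
layer = torch.nn.Linear(4, 4).to("cuda:0" if torch.cuda.is_available() else "cpu")
layer.forward = align_inputs_forward(layer)
print( layer(torch.randn(2, 4)).device )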
@@ -404,8 +415,6 @@ def get_model_offload_policy(module, policy=None):
 	if "devices" not in policy:
 		policy["devices"] = [f'{"cuda"}:{i}' for i in range(torch.cuda.device_count())] + ['cpu'] # + cpu to spill the remainder on CPU if overbudget
 
-	print( policy )
-
 	# create initial device info
 	devices = [ get_device_properties(device) | {"modules": []} for device in policy["devices"] ]
 	modules = [ (name, get_module_size(module)) for name, module in module.named_modules() if not [*module.named_children()] and passes_policy( policy, name ) ]
@@ -498,23 +507,14 @@ def offload_model( model, policy=None ):
 		for name in device["modules"]:
 			module = model.get_submodule(name)
 			module = module.to( device["name"] )
+			module.device = device['name']
 
-	"""
-	# in case the above doesn't actually do what's requested
-		*parent, key = name.split(".")
-		module = getattr( model.get_submodule(name), key )
-		setattr( model.get_submodule(name), key, module.to( device["name"] ) )
-	"""
-
-		# select next device to cast inputs to, or wrap to first if last device
-		next_device = policy[i + 1]["name"] if i + 1 < len( policy ) else policy[0]["name"]
-		# same device, don't bother wrapping
-		if device["name"] == next_device:
+	# wrap modules with forward to ensure all inputs are matched to its device
+	for name, module in model.named_modules():
+		if not hasattr( module, 'forward' ):
 			continue
 
-		# wrap forward call
-		last_module = model.get_submodule( device["modules"][-1] )
-		last_module.forward = auto_to_forward(last_module, next_device)
+		module.forward = auto_align_inputs_forward(module)
 
 	"""
 	# Validate that the layers are all in the right spot
@@ -524,6 +524,7 @@ def offload_model( model, policy=None ):
 		try:
 			print( name, next(module.parameters()).device )
 		except Exception as e:
+			print( name, "?" )
 			pass
 	"""
 
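Aside: get_model_offload_policy and offload_model aren't shown in full here, but the general shape (walk leaf modules, sum their parameter/buffer bytes, fill each device up to its limit, spill the rest to the last device) can be sketched independently. The helper names below are made up for illustration and are not this repo's get_device_properties / get_module_size.

import torch

def module_size_bytes( module ):
	# parameters plus buffers, in bytes
	return sum( t.numel() * t.element_size() for t in [*module.parameters()] + [*module.buffers()] )

def greedy_offload_plan( model, devices, limits ):
	# limits are per-device byte budgets; -1 means unlimited (catches whatever is left over)
	plan, used = {}, [0] * len(devices)
	for name, module in model.named_modules():
		if [*module.children()]:
			continue
		size = module_size_bytes( module )
		for i, (device, limit) in enumerate( zip(devices, limits) ):
			if limit < 0 or used[i] + size <= limit:
				plan[name], used[i] = device, used[i] + size
				break
	return plan

model = torch.nn.Sequential(*[ torch.nn.Linear(256, 256) for _ in range(8) ])
plan = greedy_offload_plan( model, ["cuda:0" if torch.cuda.is_available() else "cpu", "cpu"], [ 1 * (1024 ** 2), -1 ] )
for name, device in plan.items():
	model.get_submodule(name).to(device)
print( plan )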
@@ -99,18 +99,17 @@ except Exception as e:
 # backwards compat
 from .utils import (
 	autocast_forward,
-	auto_to_forward,
-	replace_linear,
+	replace_linear as replace_linear_old,
 	replace_embedding as replace_embedding_old,
-	replace_attention as replace_attention_old,
+	replace_attention,
 	resize_weight,
 	offload_model,
 )
 
 # wrapped here so we can maintain default args
 def replace_linear( model, klass=Linear, target=torch.nn.Linear, verbose=False ):
-	return replace_embedding_old( model, klass, target, verbose )
+	return replace_linear_old( model, klass, target, verbose )
 def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbose=False ):
-	return replace_attention_old( model, klass, target, verbose )
+	return replace_embedding_old( model, klass, target, verbose )
 
 Embedding.forward = autocast_forward(Embedding.forward)
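Aside: the "# wrapped here so we can maintain default args" pattern in the last hunk is just re-exporting the helper under its public name with the project's defaults pinned, which appears to be why the import is aliased (replace_linear as replace_linear_old) instead of being shadowed by the wrapper. A generic, self-contained sketch of the pattern with placeholder classes, not this package's actual Linear/Embedding:

import torch

# stand-in for the imported helper (the real code imports it as replace_linear_old)
def replace_linear_old( model, klass, target=torch.nn.Linear, verbose=False ):
	for name, child in model.named_children():
		if isinstance( child, target ):
			if verbose:
				print( f"replacing {name}: {target.__name__} -> {klass.__name__}" )
			setattr( model, name, klass( child.in_features, child.out_features, bias=child.bias is not None ) )
		else:
			replace_linear_old( child, klass, target, verbose )
	return model

# public wrapper that exists only to pin default arguments for callers
def replace_linear( model, klass=torch.nn.Linear, target=torch.nn.Linear, verbose=False ):
	return replace_linear_old( model, klass, target, verbose )

model = replace_linear( torch.nn.Sequential( torch.nn.Linear(8, 8), torch.nn.ReLU() ), verbose=True )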