it actually wasn't working because Engines.__init__() automatically moves the entire module to the requested device, which was being called after offloading the model in the test trainer (and it seems I cant do it without injecting a bunch of shit in modeling_llama.py)
This commit is contained in:
parent
b4c895114c
commit
c9ec6b28ef
|
@ -683,6 +683,7 @@ class Optimizations:
|
|||
model_offloading: dict | None = None # automatically splits the model over a list of devices
|
||||
# example: {"include":["model"], "limits": [ (6 * 1024) * (1024 ** 2), -1 ]} will have the GPU capped to 6GiB, and offload the remaining layers to CPU
|
||||
# example: {"include":["model"], "device": ["cuda:0", "cuda:1"], "limits": [ 0.5, 0.5 ]} will have the GPU 1 try and use 50% of the model, and GPU 2 try and use the other 50%
|
||||
# | {"assign": [[ f'layers.{i}.' for i in range(0,6) ], [ f'layers.{i}.' for i in range(6,12) ]]} will assign layers 0-5 to device 1, and 6-12 to device 2
|
||||
|
||||
@dataclass()
|
||||
class Config(BaseConfig):
|
||||
|
|
|
@ -526,18 +526,19 @@ def example_usage():
|
|||
"""
|
||||
cfg.optimizations.model_offloading = {
|
||||
"devices": ["cuda:0", "cpu"],
|
||||
"limits": [ 0.5, -1 ]
|
||||
# "limits": [ 256 * (1024 ** 2), -1 ]
|
||||
# "limits": [ 0.9, -1 ],
|
||||
"assign": [[ f'layers.{i}.' for i in range(0,6) ], [ f'layers.{i}.' for i in range(6,12) ]],
|
||||
"limits": [ 256 * (1024 ** 2), -1 ]
|
||||
}
|
||||
"""
|
||||
if cfg.optimizations.model_offloading:
|
||||
model = ml.offload_model( model, policy=cfg.optimizations.model_offloading )
|
||||
|
||||
engine = Engine(model=model, optimizer=optimizer)
|
||||
|
||||
engines = Engines({"ar+nar": engine})
|
||||
engines.setup()
|
||||
|
||||
if cfg.optimizations.model_offloading:
|
||||
model = ml.offload_model( model, policy=cfg.optimizations.model_offloading )
|
||||
|
||||
"""
|
||||
torch.save( {
|
||||
'module': model.state_dict()
|
||||
|
|
|
@ -394,12 +394,18 @@ def get_model_offload_policy(module, policy=None):
|
|||
# default to only include the core model, and not the other modules (embeddings) in the splitting policy
|
||||
if "include" not in policy:
|
||||
policy["include"] = ["model"]
|
||||
|
||||
if "limits" not in policy:
|
||||
policy["limits"] = []
|
||||
|
||||
if "assign" not in policy:
|
||||
policy["assign"] = []
|
||||
|
||||
if "devices" not in policy:
|
||||
policy["devices"] = [f'{"cuda"}:{i}' for i in range(torch.cuda.device_count())] + ['cpu'] # + cpu to spill the remainder on CPU if overbudget
|
||||
|
||||
print( policy )
|
||||
|
||||
# create initial device info
|
||||
devices = [ get_device_properties(device) | {"modules": []} for device in policy["devices"] ]
|
||||
modules = [ (name, get_module_size(module)) for name, module in module.named_modules() if not [*module.named_children()] and passes_policy( policy, name ) ]
|
||||
|
@ -422,8 +428,42 @@ def get_model_offload_policy(module, policy=None):
|
|||
# cap to requested size
|
||||
devices[i]["free"] = cap
|
||||
|
||||
# assign if specific parts of the model are requested for assignment
|
||||
if policy["assign"]:
|
||||
discarded = []
|
||||
# yuck, there has to be a better way
|
||||
for device_index, includes in enumerate( policy["assign"] ):
|
||||
device = devices[device_index]
|
||||
|
||||
buffered_modules = []
|
||||
buffered_size = device["free"]
|
||||
|
||||
# iterate through list of modules to compare against includes
|
||||
for name, size in modules:
|
||||
# doesn't pass policy
|
||||
if not passes_policy( {"include": includes}, name ):
|
||||
continue
|
||||
# check if within budget
|
||||
if buffered_size - size >= 0:
|
||||
# add to buffer
|
||||
buffered_modules.append( name )
|
||||
buffered_size -= size
|
||||
# budget exceeded, flush buffer
|
||||
else:
|
||||
discarded += buffered_modules
|
||||
buffered_modules = []
|
||||
buffered_size = 0
|
||||
break
|
||||
|
||||
if buffered_modules and buffered_size:
|
||||
device["modules"] += buffered_modules
|
||||
device["free"] = buffered_size
|
||||
|
||||
modules = discarded
|
||||
|
||||
device_index = 0
|
||||
module_index = 0
|
||||
# assign modules to each device
|
||||
while module_index < len(modules) and device_index < len(devices):
|
||||
device = devices[device_index]
|
||||
name, size = modules[module_index]
|
||||
|
|
Loading…
Reference in New Issue
Block a user