naive model offloading support (handles automatically splitting parts of the model across the requested devices per memory constraints, either inferred or requested in the yaml; input tensors are automatically migrated to the right device; it SEEMS to work for training under the test trainer when split between GPU and CPU) (this was added specifically because that Flux imagegen model released, so I can test it there)

mrq 2024-08-01 20:12:06 -05:00
parent 387358bc8a
commit b4c895114c
8 changed files with 351 additions and 164 deletions

View File

@@ -680,6 +680,10 @@ class Optimizations:
bitnet: bool = False # use bitnet
fp8: bool = False # use fp8
model_offloading: dict | None = None # automatically splits the model over a list of devices
# example: {"include":["model"], "limits": [ (6 * 1024) * (1024 ** 2), -1 ]} will have the GPU capped to 6GiB, and offload the remaining layers to CPU
# example: {"include":["model"], "device": ["cuda:0", "cuda:1"], "limits": [ 0.5, 0.5 ]} will have the GPU 1 try and use 50% of the model, and GPU 2 try and use the other 50%
@dataclass()
class Config(BaseConfig):
device: str = "cuda" # target device

View File

@@ -235,4 +235,8 @@ def load_engines(training=True):
for name, engine in engines.items():
engine.freeze(freeze_all=False)
# split models over requested devices
if cfg.optimizations.model_offloading:
engine.module = ml.offload_model( engine.module, policy=cfg.optimizations.model_offloading )
return engines

View File

@@ -522,6 +522,16 @@ def example_usage():
if cfg.optimizations.replace and cfg.optimizations.embedding:
model = ml.replace_embedding( model )
"""
cfg.optimizations.model_offloading = {
"devices": ["cuda:0", "cpu"],
"limits": [ 0.5, -1 ]
# "limits": [ 256 * (1024 ** 2), -1 ]
}
"""
if cfg.optimizations.model_offloading:
model = ml.offload_model( model, policy=cfg.optimizations.model_offloading )
engine = Engine(model=model, optimizer=optimizer)

View File

@@ -250,6 +250,7 @@ class AudioClassifier(nn.Module):
xi = [ self.proj[l]( x ) for x, l in zip(xi, levels) ]
# pad if needed
# to-do: validate that this causes ZERO issues
max_size = max([ x.shape[-1] for x in xi ])
xi = [
#x if l == 0 else

View File

@@ -11,6 +11,8 @@ from torch import Tensor, nn
import math
from typing import Optional, List
from ..utils import passes_policy
# LoRA Linear for replacement
# Pros: simple, just needs to reuse the replace_linear and copy weights
# Cons: does not work with other Linears (bnb, bitnet, te's fp8, etc), cannot apply multiple LoRAs (although for audio why would you)
@@ -144,22 +146,6 @@ class ParameterizedLoRA(nn.Module):
# M$'s LoRA class arranges things to where this isn't necessary
return cls( in_features = out_channels, out_features = in_channels, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
def passes_policy( policy, name ):
if policy is None:
return True
if "exclude" in policy:
for term in policy["exclude"]:
if term in name:
return False
if "include" in policy:
for term in policy["include"]:
if term in name:
return True
return False
def apply_lora( model, register = True, merge = False, policy = None, use_parametrize = False, **kwargs ):
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype

View File

@@ -8,4 +8,5 @@ from .utils import (
tree_map,
do_gc,
set_seed,
passes_policy
)

View File

@@ -12,6 +12,8 @@ import re
import torch
import random
import time
import psutil
import math
from coloredlogs import ColoredFormatter
from logging import StreamHandler
@@ -19,7 +21,7 @@ from pathlib import Path
from torch import Tensor, nn
from tqdm.auto import tqdm
from typing import Callable, TypeVar, overload
from contextlib import contextmanager
T = TypeVar("T")
def truncate_json( str ):
@@ -180,4 +182,309 @@ def to_device(x: T | None, *args, **kwargs) -> T:
return tree_map(lambda t: t.to(*args, **kwargs), x)
def coalese( *arg, return_last=True ):
return [ x for x in arg if x is not None ][-1 if return_last else 0]
# checks if a module name is within a given whitelist/blacklist policy dict
def passes_policy( policy, name ):
if policy is None:
return True
if "exclude" in policy:
for term in policy["exclude"]:
if term in name:
return False
if "include" in policy:
for term in policy["include"]:
if term in name:
return True
return False
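# usage sketch (hypothetical module names):
#   passes_policy( {"include": ["model"]}, "model.layers.0" )          # True (matches an include term)
#   passes_policy( {"include": ["model"]}, "classifier.proj" )         # False (include list given, no match)
#   passes_policy( {"exclude": ["classifier"]}, "classifier.proj" )    # False (matches an exclude term)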
# handles generically converting to a specific tensor type and converting back (implemented solely for bfloat16)
@contextmanager
def autocast(input, from_dtype, to_dtype):
if input.dtype == from_dtype:
input = input.to(to_dtype)
yield input
input = input.to(from_dtype)
else:
yield input
@contextmanager
def autocasts(input, from_dtype, to_dtype):
if input.dtype in from_dtype:
from_dtype = input.dtype
input = input.to(to_dtype)
yield input
input = input.to(from_dtype)
else:
yield input
# handles temporarily upcasting 'index tensors' so torch will stop bitching
def autocast_forward( func ):
def wrapper( self, input, *args, **kwargs ):
with autocasts( input, [torch.int16, torch.int8, torch.uint8, torch.float16, torch.bfloat16], torch.int32 ) as k:
return func( self, k, *args, **kwargs )
return wrapper
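# applied at the class level (so `self` is still passed), as done for the Embedding wrapper below:
#   Embedding.forward = autocast_forward(Embedding.forward)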
# handles migrating an input tensor to a given device
def auto_to_forward( module, device=None ):
	if device is None:
		device = next(module.parameters()).device
	func = module.forward
	# the wrapper is assigned directly to the module instance, so it must not expect an implicit `self`
	def wrapper( *args, **kwargs ):
		# search through args and kwargs for any Tensor arguments and migrate them
		args = [*args]
		for i, arg in enumerate(args):
			if not isinstance( arg, torch.Tensor ):
				continue
			args[i] = arg.to( device=device )
		for k, v in kwargs.items():
			if not isinstance( v, torch.Tensor ):
				continue
			kwargs[k] = v.to( device=device )
		return func( *args, **kwargs )
	return wrapper
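# usage sketch (hypothetical module/device): tensor arguments get moved to cuda:1 before the original forward runs
#   some_module.forward = auto_to_forward( some_module, "cuda:1" )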
# disgusting kludge, but it works (just realized BitNet has its own replacement routine)
# generalizing this would be super sugoi but there's no catch-all for arguments
def replace_linear( model, klass, target=torch.nn.Linear, verbose=False ):
bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
for *parent, k in modules:
name = '.'.join(parent)
m = getattr( model.get_submodule(name), k )
if isinstance(m, klass):
continue
kwargs = dict(
in_features = m.in_features,
out_features = m.out_features,
bias = m.bias is not None,
) if not bnb else dict(
input_features=m.in_features,
output_features=m.out_features,
bias=m.bias is not None,
)
# overwrite
setattr(
model.get_submodule(name), k,
klass( **kwargs ).to(device=device, dtype=dtype)
)
if verbose:
print(f"Replacing {name}.{k} to", klass)
return model
def replace_embedding( model, klass, target=torch.nn.Embedding, verbose=False ):
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
for *parent, k in modules:
name = '.'.join(parent)
m = getattr( model.get_submodule(name), k )
if isinstance(m, klass):
continue
kwargs = dict(
num_embeddings=m.num_embeddings,
embedding_dim=m.embedding_dim,
padding_idx=m.padding_idx,
max_norm=m.max_norm,
norm_type=m.norm_type,
scale_grad_by_freq=m.scale_grad_by_freq,
sparse=m.sparse,
)
# overwrite
setattr(
model.get_submodule(name), k,
klass( **kwargs ).to(device=device, dtype=dtype)
)
if verbose:
print(f"Replacing {name}.{k} to", klass)
return model
# cannot feasibly do default arguments here, sadly
def replace_attention( model, klass, target, mode="math", verbose=False ):
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
for *parent, k in modules:
name = '.'.join(parent)
m = getattr( model.get_submodule(name), k )
if isinstance(m, klass):
continue
kwargs = dict(
config = m.config,
layer_idx = m.layer_idx,
mode = mode,
)
# overwrite
setattr(
model.get_submodule(name), k,
klass( **kwargs ).to(device=device, dtype=dtype)
)
if verbose:
print(f"Replacing {name}.{k} to", klass)
return model
# trim/expand a tensor (for example, in a state dict)
def resize_weight( weight, target, dim=0, random=True ):
# trim
if target < weight.shape[dim]:
return weight[:target]
# expand
if target > weight.shape[dim]:
fn = torch.rand if random else torch.zeros
return torch.stack(
[ x for x in weight ] +
[ fn( weight[0].shape ).to(device=weight[0].device, dtype=weight[0].dtype) for _ in range( target - weight.shape[dim] ) ]
)
return weight
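# usage sketch: as written only dim=0 is handled; resize_weight( w, 6 ) on a [4, d] tensor
# trims nothing and appends 2 rows of torch.rand noise, while resize_weight( w, 2 ) returns w[:2]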
# grabs the memory properties of a given device
def get_device_properties( device ):
if 'cuda' in device:
props = torch.cuda.get_device_properties(device)
free, total = torch.cuda.mem_get_info(device)
else:
props = psutil.virtual_memory()
free, total = props.available, props.total
return {"name": device, "total": total, "free": free, "props": props}
# gets the rough size for a given module's parameters
def get_module_size( module ):
param_size = sum([p.nelement() * p.element_size() for p in module.parameters()])
buffer_size = sum([b.nelement() * b.element_size() for b in module.buffers()])
return param_size + buffer_size
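# e.g. a float32 nn.Linear(1024, 1024) reports 1024*1024*4 + 1024*4 bytes ≈ 4.004 MiB (weight + bias, no buffers)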
# assigns modules to requested devices for a given policy
def get_model_offload_policy(module, policy=None):
# handle any other weird values this is set to
if not isinstance(policy, dict):
policy = {}
# default to only include the core model, and not the other modules (embeddings) in the splitting policy
if "include" not in policy:
policy["include"] = ["model"]
if "limits" not in policy:
policy["limits"] = []
if "devices" not in policy:
policy["devices"] = [f'{"cuda"}:{i}' for i in range(torch.cuda.device_count())] + ['cpu'] # + cpu to spill the remainder on CPU if overbudget
# create initial device info
devices = [ get_device_properties(device) | {"modules": []} for device in policy["devices"] ]
modules = [ (name, get_module_size(module)) for name, module in module.named_modules() if not [*module.named_children()] and passes_policy( policy, name ) ]
# filter
modules = [ (name, size) for name, size in modules if name and size ]
total_size = sum([size for name, size in modules])
# set caps if requested in the policy
for i, cap in enumerate(policy["limits"]):
# no limit, skip
if cap <= 0:
continue
# is fractional, scale to total size
if cap < 1:
cap = math.floor(total_size * cap)
# available space is below cap, don't set
if devices[i]["free"] < cap:
continue
# cap to requested size
devices[i]["free"] = cap
device_index = 0
module_index = 0
while module_index < len(modules) and device_index < len(devices):
device = devices[device_index]
name, size = modules[module_index]
# fits within budget
if device["free"] - size >= 0:
device["modules"].append( name )
device["free"] -= size
module_index += 1
# does not fit in budget, increase device index
else:
device_index += 1
print(f"Over budget for device: {device['name']}, shifting to next device: {name}, {size / (1024 ** 2)}MiB")
	# to-do: handle running out of devices before every module is assigned (for now, just assert)
	assert module_index >= len(modules), "Ran out of devices while assigning modules"
# only return devices with modules assigned
return [ device for device in devices if device["modules"] ]
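# usage sketch (hypothetical module names; the actual split depends on reported free memory):
#   policy = get_model_offload_policy( model, {"devices": ["cuda:0", "cpu"], "limits": [ 0.5, -1 ]} )
#   # -> [ {"name": "cuda:0", "modules": ["model.layers.0", ...], ...},
#   #      {"name": "cpu",    "modules": ["model.layers.20", ...], ...} ]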
# handles naively splitting a model's layers across multiple devices
# this apparently works for training too? the test trainer seemed fine with it split between GPU and CPU
def offload_model( model, policy=None ):
policy = get_model_offload_policy(model, policy=policy)
# move modules to respective devices
for i, device in enumerate( policy ):
# nothing assigned, skip
if not device["modules"]:
continue
for name in device["modules"]:
module = model.get_submodule(name)
module = module.to( device["name"] )
"""
# in case the above doesn't actually do what's requested
*parent, key = name.split(".")
			module = getattr( model.get_submodule(".".join(parent)), key )
			setattr( model.get_submodule(".".join(parent)), key, module.to( device["name"] ) )
"""
# select next device to cast inputs to, or wrap to first if last device
next_device = policy[i + 1]["name"] if i + 1 < len( policy ) else policy[0]["name"]
# same device, don't bother wrapping
if device["name"] == next_device:
continue
# wrap forward call
last_module = model.get_submodule( device["modules"][-1] )
last_module.forward = auto_to_forward(last_module, next_device)
"""
# Validate that the layers are all in the right spot
for name, module in model.named_modules():
		if [*module.named_children()]:
continue
try:
print( name, next(module.parameters()).device )
except Exception as e:
pass
"""
return model
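# usage sketch (mirrors the commented-out example_usage policy above; devices/limits are illustrative):
#   model = offload_model( model, policy={"devices": ["cuda:0", "cpu"], "limits": [ 0.5, -1 ]} )
#   out = model( *inputs )  # tensors crossing a split boundary are migrated by the wrapped forward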

View File

@@ -57,34 +57,6 @@ elif cfg.optimizations.dadaptation:
SGD = dadaptation.DAdaptSGD
AdaGrad = dadaptation.DAdaptAdaGrad
# handles generically converting to a specific tensor type and converting back (implemented solely for bfloat16)
@contextmanager
def autocast(input, from_dtype, to_dtype):
if input.dtype == from_dtype:
input = input.to(to_dtype)
yield input
input = input.to(from_dtype)
else:
yield input
@contextmanager
def autocasts(input, from_dtype, to_dtype):
if input.dtype in from_dtype:
from_dtype = input.dtype
input = input.to(to_dtype)
yield input
input = input.to(from_dtype)
else:
yield input
# handles temporarily upcasting 'index tensors' so torch will stop bitching
def autocast_forward( func ):
def wrapper( self, input, *args, **kwargs ):
with autocasts( input, [torch.int16, torch.int8, torch.uint8, torch.float16, torch.bfloat16], torch.int32 ) as k:
return func( self, k, *args, **kwargs )
return wrapper
Embedding.forward = autocast_forward(Embedding.forward)
if cfg.optimizations.fp8:
import transformer_engine.pytorch as te
@@ -110,123 +82,6 @@ if cfg.optimizations.injects:
torch.optim.AdamW = AdamW
torch.optim.SGD = SGD
# disgusting kludge, but it works (just realized BitNet has its own replacement routine)
# generalizing this would be super sugoi but the there's no catch all for arguments
def replace_linear( model, klass=Linear, target=torch.nn.Linear, verbose=False ):
bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
for *parent, k in modules:
name = '.'.join(parent)
m = getattr( model.get_submodule(name), k )
if isinstance(m, klass):
continue
kwargs = dict(
in_features = m.in_features,
out_features = m.out_features,
bias = m.bias is not None,
) if not bnb else dict(
input_features=m.in_features,
output_features=m.out_features,
bias=m.bias is not None,
)
# overwrite
setattr(
model.get_submodule(name), k,
klass( **kwargs ).to(device=device, dtype=dtype)
)
if verbose:
print(f"Replacing {name}.{k} to", klass)
return model
def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbose=False ):
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
for *parent, k in modules:
name = '.'.join(parent)
m = getattr( model.get_submodule(name), k )
if isinstance(m, klass):
continue
kwargs = dict(
num_embeddings=m.num_embeddings,
embedding_dim=m.embedding_dim,
padding_idx=m.padding_idx,
max_norm=m.max_norm,
norm_type=m.norm_type,
scale_grad_by_freq=m.scale_grad_by_freq,
sparse=m.sparse,
)
# overwrite
setattr(
model.get_submodule(name), k,
klass( **kwargs ).to(device=device, dtype=dtype)
)
if verbose:
print(f"Replacing {name}.{k} to", klass)
return model
# cannot feasibly do default arguments here sad
def replace_attention( model, klass, target, mode="math", verbose=False ):
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
for *parent, k in modules:
name = '.'.join(parent)
m = getattr( model.get_submodule(name), k )
if isinstance(m, klass):
continue
kwargs = dict(
config = m.config,
layer_idx = m.layer_idx,
mode = mode,
)
# overwrite
setattr(
model.get_submodule(name), k,
klass( **kwargs ).to(device=device, dtype=dtype)
)
if verbose:
print(f"Replacing {name}.{k} to", klass)
return model
# trim/expand a tensor (for example, in a state dict)
def resize_weight( weight, target, dim=0, random=True ):
# trim
if target < weight.shape[dim]:
return weight[:target]
# expand
if target > weight.shape[dim]:
fn = torch.rand if random else torch.zeros
return torch.stack(
[ x for x in weight ] +
[ fn( weight[0].shape ).to(device=weight[0].device, dtype=weight[0].dtype) for _ in range( target - weight.shape[dim] ) ]
)
return weight
# https://github.com/konstmish/prodigy
try:
from prodigyopt import Prodigy
@@ -239,4 +94,23 @@ try:
import schedulefree
except Exception as e:
print('Error while importing Schedule_Free:', str(e))
pass
# backwards compat
from .utils import (
autocast_forward,
auto_to_forward,
	replace_linear as replace_linear_old,
replace_embedding as replace_embedding_old,
replace_attention as replace_attention_old,
resize_weight,
offload_model,
)
# wrapped here so we can maintain default args
def replace_linear( model, klass=Linear, target=torch.nn.Linear, verbose=False ):
	return replace_linear_old( model, klass, target, verbose )
def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbose=False ):
	return replace_embedding_old( model, klass, target, verbose )
Embedding.forward = autocast_forward(Embedding.forward)