naive model offloading support (handles automatically splitting parts of the model to requested device per memory constraints, either inferred or requested in the yaml, input tensors are automatically migrated to the right device, it SEEMS to work for training under the test trainer when split between GPU and CPU) (this was specifically only because that Flux imagegen model released so I can test it there)
This commit is contained in:
parent
387358bc8a
commit
b4c895114c
|
@ -680,6 +680,10 @@ class Optimizations:
|
||||||
bitnet: bool = False # use bitnet
|
bitnet: bool = False # use bitnet
|
||||||
fp8: bool = False # use fp8
|
fp8: bool = False # use fp8
|
||||||
|
|
||||||
|
model_offloading: dict | None = None # automatically splits the model over a list of devices
|
||||||
|
# example: {"include":["model"], "limits": [ (6 * 1024) * (1024 ** 2), -1 ]} will have the GPU capped to 6GiB, and offload the remaining layers to CPU
|
||||||
|
# example: {"include":["model"], "device": ["cuda:0", "cuda:1"], "limits": [ 0.5, 0.5 ]} will have the GPU 1 try and use 50% of the model, and GPU 2 try and use the other 50%
|
||||||
|
|
||||||
@dataclass()
|
@dataclass()
|
||||||
class Config(BaseConfig):
|
class Config(BaseConfig):
|
||||||
device: str = "cuda" # target device
|
device: str = "cuda" # target device
|
||||||
|
|
|
@ -235,4 +235,8 @@ def load_engines(training=True):
|
||||||
for name, engine in engines.items():
|
for name, engine in engines.items():
|
||||||
engine.freeze(freeze_all=False)
|
engine.freeze(freeze_all=False)
|
||||||
|
|
||||||
|
# split models over requested devices
|
||||||
|
if cfg.optimizations.model_offloading:
|
||||||
|
engine.module = ml.offload_model( engine.module, policy=cfg.optimizations.model_offloading )
|
||||||
|
|
||||||
return engines
|
return engines
|
||||||
|
|
|
@ -523,6 +523,16 @@ def example_usage():
|
||||||
if cfg.optimizations.replace and cfg.optimizations.embedding:
|
if cfg.optimizations.replace and cfg.optimizations.embedding:
|
||||||
model = ml.replace_embedding( model )
|
model = ml.replace_embedding( model )
|
||||||
|
|
||||||
|
"""
|
||||||
|
cfg.optimizations.model_offloading = {
|
||||||
|
"devices": ["cuda:0", "cpu"],
|
||||||
|
"limits": [ 0.5, -1 ]
|
||||||
|
# "limits": [ 256 * (1024 ** 2), -1 ]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
if cfg.optimizations.model_offloading:
|
||||||
|
model = ml.offload_model( model, policy=cfg.optimizations.model_offloading )
|
||||||
|
|
||||||
engine = Engine(model=model, optimizer=optimizer)
|
engine = Engine(model=model, optimizer=optimizer)
|
||||||
|
|
||||||
engines = Engines({"ar+nar": engine})
|
engines = Engines({"ar+nar": engine})
|
||||||
|
|
|
@ -250,6 +250,7 @@ class AudioClassifier(nn.Module):
|
||||||
|
|
||||||
xi = [ self.proj[l]( x ) for x, l in zip(xi, levels) ]
|
xi = [ self.proj[l]( x ) for x, l in zip(xi, levels) ]
|
||||||
# pad if needed
|
# pad if needed
|
||||||
|
# to-do: validate that this causes ZERO issues
|
||||||
max_size = max([ x.shape[-1] for x in xi ])
|
max_size = max([ x.shape[-1] for x in xi ])
|
||||||
xi = [
|
xi = [
|
||||||
#x if l == 0 else
|
#x if l == 0 else
|
||||||
|
|
|
@ -11,6 +11,8 @@ from torch import Tensor, nn
|
||||||
import math
|
import math
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
|
||||||
|
from ..utils import passes_policy
|
||||||
|
|
||||||
# LoRA Linear for replacement
|
# LoRA Linear for replacement
|
||||||
# Pros: simple, just needs to reuse the replace_linear and copy weights
|
# Pros: simple, just needs to reuse the replace_linear and copy weights
|
||||||
# Cons: does not work with other Linears (bnb, bitnet, te's fp8, etc), cannot apply multiple LoRAs (although for audio why would you)
|
# Cons: does not work with other Linears (bnb, bitnet, te's fp8, etc), cannot apply multiple LoRAs (although for audio why would you)
|
||||||
|
@ -144,22 +146,6 @@ class ParameterizedLoRA(nn.Module):
|
||||||
# M$'s LoRA class arranges things to where this isn't necessary
|
# M$'s LoRA class arranges things to where this isn't necessary
|
||||||
return cls( in_features = out_channels, out_features = in_channels, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
|
return cls( in_features = out_channels, out_features = in_channels, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
|
||||||
|
|
||||||
def passes_policy( policy, name ):
|
|
||||||
if policy is None:
|
|
||||||
return True
|
|
||||||
|
|
||||||
if "exclude" in policy:
|
|
||||||
for term in policy["exclude"]:
|
|
||||||
if term in name:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if "include" in policy:
|
|
||||||
for term in policy["include"]:
|
|
||||||
if term in name:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def apply_lora( model, register = True, merge = False, policy = None, use_parametrize = False, **kwargs ):
|
def apply_lora( model, register = True, merge = False, policy = None, use_parametrize = False, **kwargs ):
|
||||||
device = next(model.parameters()).device
|
device = next(model.parameters()).device
|
||||||
dtype = next(model.parameters()).dtype
|
dtype = next(model.parameters()).dtype
|
||||||
|
|
|
@ -8,4 +8,5 @@ from .utils import (
|
||||||
tree_map,
|
tree_map,
|
||||||
do_gc,
|
do_gc,
|
||||||
set_seed,
|
set_seed,
|
||||||
|
passes_policy
|
||||||
)
|
)
|
|
@ -12,6 +12,8 @@ import re
|
||||||
import torch
|
import torch
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
|
import psutil
|
||||||
|
import math
|
||||||
|
|
||||||
from coloredlogs import ColoredFormatter
|
from coloredlogs import ColoredFormatter
|
||||||
from logging import StreamHandler
|
from logging import StreamHandler
|
||||||
|
@ -19,7 +21,7 @@ from pathlib import Path
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
from typing import Callable, TypeVar, overload
|
from typing import Callable, TypeVar, overload
|
||||||
|
from contextlib import contextmanager
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
def truncate_json( str ):
|
def truncate_json( str ):
|
||||||
|
@ -181,3 +183,308 @@ def to_device(x: T | None, *args, **kwargs) -> T:
|
||||||
|
|
||||||
def coalese( *arg, return_last=True ):
|
def coalese( *arg, return_last=True ):
|
||||||
return [ x for x in arg if x is not None ][-1 if return_last else 0]
|
return [ x for x in arg if x is not None ][-1 if return_last else 0]
|
||||||
|
|
||||||
|
# checks if a module name is within a given whitelist/blacklist policy dict
|
||||||
|
def passes_policy( policy, name ):
|
||||||
|
if policy is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if "exclude" in policy:
|
||||||
|
for term in policy["exclude"]:
|
||||||
|
if term in name:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if "include" in policy:
|
||||||
|
for term in policy["include"]:
|
||||||
|
if term in name:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
# handles generically converting to a specific tensor type and converting back (implemented solely for bfloat16)
|
||||||
|
@contextmanager
|
||||||
|
def autocast(input, from_dtype, to_dtype):
|
||||||
|
if input.dtype == from_dtype:
|
||||||
|
input = input.to(to_dtype)
|
||||||
|
yield input
|
||||||
|
input = input.to(from_dtype)
|
||||||
|
else:
|
||||||
|
yield input
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def autocasts(input, from_dtype, to_dtype):
|
||||||
|
if input.dtype in from_dtype:
|
||||||
|
from_dtype = input.dtype
|
||||||
|
input = input.to(to_dtype)
|
||||||
|
yield input
|
||||||
|
input = input.to(from_dtype)
|
||||||
|
else:
|
||||||
|
yield input
|
||||||
|
|
||||||
|
# handles temporarily upcasting 'index tensors' so torch will stop bitching
|
||||||
|
def autocast_forward( func ):
|
||||||
|
def wrapper( self, input, *args, **kwargs ):
|
||||||
|
with autocasts( input, [torch.int16, torch.int8, torch.uint8, torch.float16, torch.bfloat16], torch.int32 ) as k:
|
||||||
|
return func( self, k, *args, **kwargs )
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
# handles migrating an input tensor to a given devicve
|
||||||
|
def auto_to_forward( module, device=None ):
|
||||||
|
if device is None:
|
||||||
|
device = next(module.parameters()).device
|
||||||
|
|
||||||
|
func = module.forward
|
||||||
|
|
||||||
|
def wrapper( self, *args, **kwargs ):
|
||||||
|
# search through args and kwargs for any Tensor arguments
|
||||||
|
args = [*args]
|
||||||
|
for i, arg in enumerate(args):
|
||||||
|
if not isinstance( arg, torch.Tensor ):
|
||||||
|
continue
|
||||||
|
args[i] = arg.to( device=device )
|
||||||
|
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
if not isinstance( v, torch.Tensor ):
|
||||||
|
continue
|
||||||
|
kwargs[k] = v.to( device=device )
|
||||||
|
|
||||||
|
return func( self, *args, **kwargs )
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
# disgusting kludge, but it works (just realized BitNet has its own replacement routine)
|
||||||
|
# generalizing this would be super sugoi but the there's no catch all for arguments
|
||||||
|
def replace_linear( model, klass, target=torch.nn.Linear, verbose=False ):
|
||||||
|
bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet
|
||||||
|
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
dtype = next(model.parameters()).dtype
|
||||||
|
modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
|
||||||
|
|
||||||
|
for *parent, k in modules:
|
||||||
|
name = '.'.join(parent)
|
||||||
|
|
||||||
|
m = getattr( model.get_submodule(name), k )
|
||||||
|
|
||||||
|
if isinstance(m, klass):
|
||||||
|
continue
|
||||||
|
|
||||||
|
kwargs = dict(
|
||||||
|
in_features = m.in_features,
|
||||||
|
out_features = m.out_features,
|
||||||
|
bias = m.bias is not None,
|
||||||
|
) if not bnb else dict(
|
||||||
|
input_features=m.in_features,
|
||||||
|
output_features=m.out_features,
|
||||||
|
bias=m.bias is not None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# overwrite
|
||||||
|
setattr(
|
||||||
|
model.get_submodule(name), k,
|
||||||
|
klass( **kwargs ).to(device=device, dtype=dtype)
|
||||||
|
)
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print(f"Replacing {name}.{k} to", klass)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def replace_embedding( model, klass, target=torch.nn.Embedding, verbose=False ):
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
dtype = next(model.parameters()).dtype
|
||||||
|
modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
|
||||||
|
|
||||||
|
for *parent, k in modules:
|
||||||
|
name = '.'.join(parent)
|
||||||
|
|
||||||
|
m = getattr( model.get_submodule(name), k )
|
||||||
|
|
||||||
|
if isinstance(m, klass):
|
||||||
|
continue
|
||||||
|
|
||||||
|
kwargs = dict(
|
||||||
|
num_embeddings=m.num_embeddings,
|
||||||
|
embedding_dim=m.embedding_dim,
|
||||||
|
padding_idx=m.padding_idx,
|
||||||
|
max_norm=m.max_norm,
|
||||||
|
norm_type=m.norm_type,
|
||||||
|
scale_grad_by_freq=m.scale_grad_by_freq,
|
||||||
|
sparse=m.sparse,
|
||||||
|
)
|
||||||
|
|
||||||
|
# overwrite
|
||||||
|
setattr(
|
||||||
|
model.get_submodule(name), k,
|
||||||
|
klass( **kwargs ).to(device=device, dtype=dtype)
|
||||||
|
)
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print(f"Replacing {name}.{k} to", klass)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
# cannot feasibly do default arguments here sad
|
||||||
|
def replace_attention( model, klass, target, mode="math", verbose=False ):
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
dtype = next(model.parameters()).dtype
|
||||||
|
modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
|
||||||
|
|
||||||
|
for *parent, k in modules:
|
||||||
|
name = '.'.join(parent)
|
||||||
|
|
||||||
|
m = getattr( model.get_submodule(name), k )
|
||||||
|
|
||||||
|
if isinstance(m, klass):
|
||||||
|
continue
|
||||||
|
|
||||||
|
kwargs = dict(
|
||||||
|
config = m.config,
|
||||||
|
layer_idx = m.layer_idx,
|
||||||
|
mode = mode,
|
||||||
|
)
|
||||||
|
# overwrite
|
||||||
|
setattr(
|
||||||
|
model.get_submodule(name), k,
|
||||||
|
klass( **kwargs ).to(device=device, dtype=dtype)
|
||||||
|
)
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print(f"Replacing {name}.{k} to", klass)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
# trim/expand a tensor (for example, in a state dict)
|
||||||
|
def resize_weight( weight, target, dim=0, random=True ):
|
||||||
|
# trim
|
||||||
|
if target < weight.shape[dim]:
|
||||||
|
return weight[:target]
|
||||||
|
# expand
|
||||||
|
if target > weight.shape[dim]:
|
||||||
|
fn = torch.rand if random else torch.zeros
|
||||||
|
return torch.stack(
|
||||||
|
[ x for x in weight ] +
|
||||||
|
[ fn( weight[0].shape ).to(device=weight[0].device, dtype=weight[0].dtype) for _ in range( target - weight.shape[dim] ) ]
|
||||||
|
)
|
||||||
|
|
||||||
|
return weight
|
||||||
|
|
||||||
|
# grabs the memory properties of a given device
|
||||||
|
def get_device_properties( device ):
|
||||||
|
if 'cuda' in device:
|
||||||
|
props = torch.cuda.get_device_properties(device)
|
||||||
|
free, total = torch.cuda.mem_get_info(device)
|
||||||
|
else:
|
||||||
|
props = psutil.virtual_memory()
|
||||||
|
free, total = props.available, props.total
|
||||||
|
|
||||||
|
return {"name": device, "total": total, "free": free, "props": props}
|
||||||
|
|
||||||
|
# gets the rough size for a given module's parameters
|
||||||
|
def get_module_size( module ):
|
||||||
|
param_size = sum([p.nelement() * p.element_size() for p in module.parameters()])
|
||||||
|
buffer_size = sum([b.nelement() * b.element_size() for b in module.buffers()])
|
||||||
|
return param_size + buffer_size
|
||||||
|
|
||||||
|
# assigns modules to requested devices for a given policy
|
||||||
|
def get_model_offload_policy(module, policy=None):
|
||||||
|
# handle any other weird values this is set to
|
||||||
|
if not isinstance(policy, dict):
|
||||||
|
policy = {}
|
||||||
|
|
||||||
|
# default to only include the core model, and not the other modules (embeddings) in the splitting policy
|
||||||
|
if "include" not in policy:
|
||||||
|
policy["include"] = ["model"]
|
||||||
|
if "limits" not in policy:
|
||||||
|
policy["limits"] = []
|
||||||
|
|
||||||
|
if "devices" not in policy:
|
||||||
|
policy["devices"] = [f'{"cuda"}:{i}' for i in range(torch.cuda.device_count())] + ['cpu'] # + cpu to spill the remainder on CPU if overbudget
|
||||||
|
|
||||||
|
# create initial device info
|
||||||
|
devices = [ get_device_properties(device) | {"modules": []} for device in policy["devices"] ]
|
||||||
|
modules = [ (name, get_module_size(module)) for name, module in module.named_modules() if not [*module.named_children()] and passes_policy( policy, name ) ]
|
||||||
|
# filter
|
||||||
|
modules = [ (name, size) for name, size in modules if name and size ]
|
||||||
|
|
||||||
|
total_size = sum([size for name, size in modules])
|
||||||
|
|
||||||
|
# set caps if requested in the policy
|
||||||
|
for i, cap in enumerate(policy["limits"]):
|
||||||
|
# no limit, skip
|
||||||
|
if cap <= 0:
|
||||||
|
continue
|
||||||
|
# is fractional, scale to total size
|
||||||
|
if cap < 1:
|
||||||
|
cap = math.floor(total_size * cap)
|
||||||
|
# available space is below cap, don't set
|
||||||
|
if devices[i]["free"] < cap:
|
||||||
|
continue
|
||||||
|
# cap to requested size
|
||||||
|
devices[i]["free"] = cap
|
||||||
|
|
||||||
|
device_index = 0
|
||||||
|
module_index = 0
|
||||||
|
while module_index < len(modules) and device_index < len(devices):
|
||||||
|
device = devices[device_index]
|
||||||
|
name, size = modules[module_index]
|
||||||
|
|
||||||
|
# fits within budget
|
||||||
|
if device["free"] - size >= 0:
|
||||||
|
device["modules"].append( name )
|
||||||
|
device["free"] -= size
|
||||||
|
module_index += 1
|
||||||
|
# does not fit in budget, increase device index
|
||||||
|
else:
|
||||||
|
device_index += 1
|
||||||
|
print(f"Over budget for device: {device['name']}, shifting to next device: {name}, {size / (1024 ** 2)}MiB")
|
||||||
|
|
||||||
|
# to-do: check that all modules are exhausted
|
||||||
|
assert module_index >= len(modules)
|
||||||
|
|
||||||
|
# only return devices with modules assigned
|
||||||
|
return [ device for device in devices if device["modules"] ]
|
||||||
|
|
||||||
|
# handles naively splitting a model's layers across multiple devices
|
||||||
|
# this apparently works for training too? the test trainer seemed fine with it split between GPU and CPU
|
||||||
|
def offload_model( model, policy=None ):
|
||||||
|
policy = get_model_offload_policy(model, policy=policy)
|
||||||
|
|
||||||
|
# move modules to respective devices
|
||||||
|
for i, device in enumerate( policy ):
|
||||||
|
# nothing assigned, skip
|
||||||
|
if not device["modules"]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for name in device["modules"]:
|
||||||
|
module = model.get_submodule(name)
|
||||||
|
module = module.to( device["name"] )
|
||||||
|
|
||||||
|
"""
|
||||||
|
# in case the above doesn't actually do what's requested
|
||||||
|
*parent, key = name.split(".")
|
||||||
|
module = getattr( model.get_submodule(name), key )
|
||||||
|
setattr( model.get_submodule(name), key, module.to( device["name"] ) )
|
||||||
|
"""
|
||||||
|
|
||||||
|
# select next device to cast inputs to, or wrap to first if last device
|
||||||
|
next_device = policy[i + 1]["name"] if i + 1 < len( policy ) else policy[0]["name"]
|
||||||
|
# same device, don't bother wrapping
|
||||||
|
if device["name"] == next_device:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# wrap forward call
|
||||||
|
last_module = model.get_submodule( device["modules"][-1] )
|
||||||
|
last_module.forward = auto_to_forward(last_module, next_device)
|
||||||
|
|
||||||
|
"""
|
||||||
|
# Validate that the layers are all in the right spot
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if not not [*module.named_children()]:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
print( name, next(module.parameters()).device )
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
"""
|
||||||
|
|
||||||
|
return model
|
|
@ -57,34 +57,6 @@ elif cfg.optimizations.dadaptation:
|
||||||
SGD = dadaptation.DAdaptSGD
|
SGD = dadaptation.DAdaptSGD
|
||||||
AdaGrad = dadaptation.DAdaptAdaGrad
|
AdaGrad = dadaptation.DAdaptAdaGrad
|
||||||
|
|
||||||
# handles generically converting to a specific tensor type and converting back (implemented solely for bfloat16)
|
|
||||||
@contextmanager
|
|
||||||
def autocast(input, from_dtype, to_dtype):
|
|
||||||
if input.dtype == from_dtype:
|
|
||||||
input = input.to(to_dtype)
|
|
||||||
yield input
|
|
||||||
input = input.to(from_dtype)
|
|
||||||
else:
|
|
||||||
yield input
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def autocasts(input, from_dtype, to_dtype):
|
|
||||||
if input.dtype in from_dtype:
|
|
||||||
from_dtype = input.dtype
|
|
||||||
input = input.to(to_dtype)
|
|
||||||
yield input
|
|
||||||
input = input.to(from_dtype)
|
|
||||||
else:
|
|
||||||
yield input
|
|
||||||
|
|
||||||
# handles temporarily upcasting 'index tensors' so torch will stop bitching
|
|
||||||
def autocast_forward( func ):
|
|
||||||
def wrapper( self, input, *args, **kwargs ):
|
|
||||||
with autocasts( input, [torch.int16, torch.int8, torch.uint8, torch.float16, torch.bfloat16], torch.int32 ) as k:
|
|
||||||
return func( self, k, *args, **kwargs )
|
|
||||||
return wrapper
|
|
||||||
Embedding.forward = autocast_forward(Embedding.forward)
|
|
||||||
|
|
||||||
if cfg.optimizations.fp8:
|
if cfg.optimizations.fp8:
|
||||||
import transformer_engine.pytorch as te
|
import transformer_engine.pytorch as te
|
||||||
|
|
||||||
|
@ -110,123 +82,6 @@ if cfg.optimizations.injects:
|
||||||
torch.optim.AdamW = AdamW
|
torch.optim.AdamW = AdamW
|
||||||
torch.optim.SGD = SGD
|
torch.optim.SGD = SGD
|
||||||
|
|
||||||
# disgusting kludge, but it works (just realized BitNet has its own replacement routine)
|
|
||||||
# generalizing this would be super sugoi but the there's no catch all for arguments
|
|
||||||
def replace_linear( model, klass=Linear, target=torch.nn.Linear, verbose=False ):
|
|
||||||
bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet
|
|
||||||
|
|
||||||
device = next(model.parameters()).device
|
|
||||||
dtype = next(model.parameters()).dtype
|
|
||||||
modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
|
|
||||||
|
|
||||||
for *parent, k in modules:
|
|
||||||
name = '.'.join(parent)
|
|
||||||
|
|
||||||
m = getattr( model.get_submodule(name), k )
|
|
||||||
|
|
||||||
if isinstance(m, klass):
|
|
||||||
continue
|
|
||||||
|
|
||||||
kwargs = dict(
|
|
||||||
in_features = m.in_features,
|
|
||||||
out_features = m.out_features,
|
|
||||||
bias = m.bias is not None,
|
|
||||||
) if not bnb else dict(
|
|
||||||
input_features=m.in_features,
|
|
||||||
output_features=m.out_features,
|
|
||||||
bias=m.bias is not None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# overwrite
|
|
||||||
setattr(
|
|
||||||
model.get_submodule(name), k,
|
|
||||||
klass( **kwargs ).to(device=device, dtype=dtype)
|
|
||||||
)
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
print(f"Replacing {name}.{k} to", klass)
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbose=False ):
|
|
||||||
device = next(model.parameters()).device
|
|
||||||
dtype = next(model.parameters()).dtype
|
|
||||||
modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
|
|
||||||
|
|
||||||
for *parent, k in modules:
|
|
||||||
name = '.'.join(parent)
|
|
||||||
|
|
||||||
m = getattr( model.get_submodule(name), k )
|
|
||||||
|
|
||||||
if isinstance(m, klass):
|
|
||||||
continue
|
|
||||||
|
|
||||||
kwargs = dict(
|
|
||||||
num_embeddings=m.num_embeddings,
|
|
||||||
embedding_dim=m.embedding_dim,
|
|
||||||
padding_idx=m.padding_idx,
|
|
||||||
max_norm=m.max_norm,
|
|
||||||
norm_type=m.norm_type,
|
|
||||||
scale_grad_by_freq=m.scale_grad_by_freq,
|
|
||||||
sparse=m.sparse,
|
|
||||||
)
|
|
||||||
|
|
||||||
# overwrite
|
|
||||||
setattr(
|
|
||||||
model.get_submodule(name), k,
|
|
||||||
klass( **kwargs ).to(device=device, dtype=dtype)
|
|
||||||
)
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
print(f"Replacing {name}.{k} to", klass)
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
# cannot feasibly do default arguments here sad
|
|
||||||
def replace_attention( model, klass, target, mode="math", verbose=False ):
|
|
||||||
device = next(model.parameters()).device
|
|
||||||
dtype = next(model.parameters()).dtype
|
|
||||||
modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
|
|
||||||
|
|
||||||
for *parent, k in modules:
|
|
||||||
name = '.'.join(parent)
|
|
||||||
|
|
||||||
m = getattr( model.get_submodule(name), k )
|
|
||||||
|
|
||||||
if isinstance(m, klass):
|
|
||||||
continue
|
|
||||||
|
|
||||||
kwargs = dict(
|
|
||||||
config = m.config,
|
|
||||||
layer_idx = m.layer_idx,
|
|
||||||
mode = mode,
|
|
||||||
)
|
|
||||||
# overwrite
|
|
||||||
setattr(
|
|
||||||
model.get_submodule(name), k,
|
|
||||||
klass( **kwargs ).to(device=device, dtype=dtype)
|
|
||||||
)
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
print(f"Replacing {name}.{k} to", klass)
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
# trim/expand a tensor (for example, in a state dict)
|
|
||||||
def resize_weight( weight, target, dim=0, random=True ):
|
|
||||||
# trim
|
|
||||||
if target < weight.shape[dim]:
|
|
||||||
return weight[:target]
|
|
||||||
# expand
|
|
||||||
if target > weight.shape[dim]:
|
|
||||||
fn = torch.rand if random else torch.zeros
|
|
||||||
return torch.stack(
|
|
||||||
[ x for x in weight ] +
|
|
||||||
[ fn( weight[0].shape ).to(device=weight[0].device, dtype=weight[0].dtype) for _ in range( target - weight.shape[dim] ) ]
|
|
||||||
)
|
|
||||||
|
|
||||||
return weight
|
|
||||||
|
|
||||||
# https://github.com/konstmish/prodigy
|
# https://github.com/konstmish/prodigy
|
||||||
try:
|
try:
|
||||||
from prodigyopt import Prodigy
|
from prodigyopt import Prodigy
|
||||||
|
@ -240,3 +95,22 @@ try:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print('Error while importing Schedule_Free:', str(e))
|
print('Error while importing Schedule_Free:', str(e))
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# backwards compat
|
||||||
|
from .utils import (
|
||||||
|
autocast_forward,
|
||||||
|
auto_to_forward,
|
||||||
|
replace_linear,
|
||||||
|
replace_embedding as replace_embedding_old,
|
||||||
|
replace_attention as replace_attention_old,
|
||||||
|
resize_weight,
|
||||||
|
offload_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
# wrapped here so we can maintain default args
|
||||||
|
def replace_linear( model, klass=Linear, target=torch.nn.Linear, verbose=False ):
|
||||||
|
return replace_embedding_old( model, klass, target, verbose )
|
||||||
|
def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbose=False ):
|
||||||
|
return replace_attention_old( model, klass, target, verbose )
|
||||||
|
|
||||||
|
Embedding.forward = autocast_forward(Embedding.forward)
|
Loading…
Reference in New Issue
Block a user