diff --git a/vall_e/config.py b/vall_e/config.py
index fca5917..c0e0732 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -680,6 +680,10 @@ class Optimizations:
 	bitnet: bool = False # use bitnet
 	fp8: bool = False # use fp8
 
+	model_offloading: dict | None = None # automatically splits the model over a list of devices
+	# example: {"include":["model"], "limits": [ (6 * 1024) * (1024 ** 2), -1 ]} caps the GPU at 6GiB and offloads the remaining layers to the CPU
+	# example: {"include":["model"], "devices": ["cuda:0", "cuda:1"], "limits": [ 0.5, 0.5 ]} has cuda:0 try to hold 50% of the model and cuda:1 the other 50%
+
 @dataclass()
 class Config(BaseConfig):
 	device: str = "cuda" # target device
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 9aaf635..1a87c42 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -235,4 +235,8 @@ def load_engines(training=True):
 	for name, engine in engines.items():
 		engine.freeze(freeze_all=False)
 
+	# split models over requested devices
+	if cfg.optimizations.model_offloading:
+		engine.module = ml.offload_model( engine.module, policy=cfg.optimizations.model_offloading )
+
 	return engines
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 2320a2a..d44ea2a 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -522,6 +522,16 @@ def example_usage():
 
 	if cfg.optimizations.replace and cfg.optimizations.embedding:
 		model = ml.replace_embedding( model )
+	
+	"""
+	cfg.optimizations.model_offloading = {
+		"devices": ["cuda:0", "cpu"],
+		"limits": [ 0.5, -1 ]
+		# "limits": [ 256 * (1024 ** 2), -1 ]
+	}
+	"""
+	if cfg.optimizations.model_offloading:
+		model = ml.offload_model( model, policy=cfg.optimizations.model_offloading )
 
 	engine = Engine(model=model, optimizer=optimizer)
 
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index e99d0bf..e52decb 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -250,6 +250,7 @@ class AudioClassifier(nn.Module):
 		xi = [ self.proj[l]( x ) for x, l in zip(xi, levels) ]
 
 		# pad if needed
+		# to-do: validate that this causes ZERO issues
 		max_size = max([ x.shape[-1] for x in xi ])
 		xi = [
 			#x if l == 0 else
diff --git a/vall_e/models/lora.py b/vall_e/models/lora.py
index ec7314d..87e6d7c 100644
--- a/vall_e/models/lora.py
+++ b/vall_e/models/lora.py
@@ -11,6 +11,8 @@ from torch import Tensor, nn
 import math
 from typing import Optional, List
 
+from ..utils import passes_policy
+
 # LoRA Linear for replacement
 # Pros: simple, just needs to reuse the replace_linear and copy weights
 # Cons: does not work with other Linears (bnb, bitnet, te's fp8, etc), cannot apply multiple LoRAs (although for audio why would you)
@@ -144,22 +146,6 @@ class ParameterizedLoRA(nn.Module):
 		# M$'s LoRA class arranges things to where this isn't necessary
 		return cls( in_features = out_channels, out_features = in_channels, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
 
-def passes_policy( policy, name ):
-	if policy is None:
-		return True
-
-	if "exclude" in policy:
-		for term in policy["exclude"]:
-			if term in name:
-				return False
-
-	if "include" in policy:
-		for term in policy["include"]:
-			if term in name:
-				return True
-
-	return False
-
 def apply_lora( model, register = True, merge = False, policy = None, use_parametrize = False, **kwargs ):
 	device = next(model.parameters()).device
 	dtype = next(model.parameters()).dtype
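A quick sketch (not part of the patch) of how the relocated `passes_policy` helper behaves once it is re-exported from `vall_e.utils`: matching is plain substring containment, `exclude` terms are checked before `include` terms, and a name matching neither list falls through to `False` when a policy dict is given. The module names below are hypothetical, purely for illustration.

```python
# illustration only -- the module names here are made up
from vall_e.utils import passes_policy

policy = { "include": ["model"], "exclude": ["lm_head"] }

print( passes_policy( policy, "model.layers.0.mlp" ) )  # True: contains an "include" term
print( passes_policy( policy, "model.lm_head" ) )       # False: "exclude" is checked first and wins
print( passes_policy( policy, "proms_emb.weight" ) )    # False: matches neither list
print( passes_policy( None, "anything" ) )              # True: no policy means everything passes
```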
diff --git a/vall_e/utils/__init__.py b/vall_e/utils/__init__.py
index b2f2ef9..77c0698 100755
--- a/vall_e/utils/__init__.py
+++ b/vall_e/utils/__init__.py
@@ -8,4 +8,5 @@ from .utils import (
 	tree_map,
 	do_gc,
 	set_seed,
+	passes_policy
 )
\ No newline at end of file
diff --git a/vall_e/utils/utils.py b/vall_e/utils/utils.py
index d8199f3..366eb2f 100755
--- a/vall_e/utils/utils.py
+++ b/vall_e/utils/utils.py
@@ -12,6 +12,8 @@ import re
 import torch
 import random
 import time
+import psutil
+import math
 
 from coloredlogs import ColoredFormatter
 from logging import StreamHandler
@@ -19,7 +21,7 @@ from pathlib import Path
 from torch import Tensor, nn
 from tqdm.auto import tqdm
 from typing import Callable, TypeVar, overload
-
+from contextlib import contextmanager
 T = TypeVar("T")
 
 def truncate_json( str ):
@@ -180,4 +182,309 @@ def to_device(x: T | None, *args, **kwargs) -> T:
 	return tree_map(lambda t: t.to(*args, **kwargs), x)
 
 def coalese( *arg, return_last=True ):
-	return [ x for x in arg if x is not None ][-1 if return_last else 0]
\ No newline at end of file
+	return [ x for x in arg if x is not None ][-1 if return_last else 0]
+
+# checks if a module name is within a given whitelist/blacklist policy dict
+def passes_policy( policy, name ):
+	if policy is None:
+		return True
+
+	if "exclude" in policy:
+		for term in policy["exclude"]:
+			if term in name:
+				return False
+
+	if "include" in policy:
+		for term in policy["include"]:
+			if term in name:
+				return True
+
+	return False
+
+# handles generically converting to a specific tensor type and converting back (implemented solely for bfloat16)
+@contextmanager
+def autocast(input, from_dtype, to_dtype):
+	if input.dtype == from_dtype:
+		input = input.to(to_dtype)
+		yield input
+		input = input.to(from_dtype)
+	else:
+		yield input
+
+@contextmanager
+def autocasts(input, from_dtype, to_dtype):
+	if input.dtype in from_dtype:
+		from_dtype = input.dtype
+		input = input.to(to_dtype)
+		yield input
+		input = input.to(from_dtype)
+	else:
+		yield input
+
+# handles temporarily upcasting 'index tensors' so torch will stop bitching
+def autocast_forward( func ):
+	def wrapper( self, input, *args, **kwargs ):
+		with autocasts( input, [torch.int16, torch.int8, torch.uint8, torch.float16, torch.bfloat16], torch.int32 ) as k:
+			return func( self, k, *args, **kwargs )
+	return wrapper
+
+# handles migrating an input tensor to a given device
+def auto_to_forward( module, device=None ):
+	if device is None:
+		device = next(module.parameters()).device
+
+	func = module.forward
+
+	def wrapper( self, *args, **kwargs ):
+		# search through args and kwargs for any Tensor arguments
+		args = [*args]
+		for i, arg in enumerate(args):
+			if not isinstance( arg, torch.Tensor ):
+				continue
+			args[i] = arg.to( device=device )
+
+		for k, v in kwargs.items():
+			if not isinstance( v, torch.Tensor ):
+				continue
+			kwargs[k] = v.to( device=device )
+
+		return func( self, *args, **kwargs )
+	return wrapper
+
+# disgusting kludge, but it works (just realized BitNet has its own replacement routine)
+# generalizing this would be super sugoi but there's no catch-all for arguments
+def replace_linear( model, klass, target=torch.nn.Linear, verbose=False ):
+	bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet
+
+	device = next(model.parameters()).device
+	dtype = next(model.parameters()).dtype
+	modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
+
+	for *parent, k in modules:
+		name = '.'.join(parent)
+
+		m = getattr( model.get_submodule(name), k )
+
+		if isinstance(m, klass):
+			continue
+
+		kwargs = dict(
+			in_features = m.in_features,
+			out_features = m.out_features,
+			bias = m.bias is not None,
+		) if not bnb else dict(
+			input_features=m.in_features,
+			output_features=m.out_features,
+			bias=m.bias is not None,
+		)
+
+		# overwrite
+		setattr(
+			model.get_submodule(name), k,
+			klass( **kwargs ).to(device=device, dtype=dtype)
+		)
+
+		if verbose:
+			print(f"Replacing {name}.{k} to", klass)
+
+	return model
+
+def replace_embedding( model, klass, target=torch.nn.Embedding, verbose=False ):
+	device = next(model.parameters()).device
+	dtype = next(model.parameters()).dtype
+	modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
+
+	for *parent, k in modules:
+		name = '.'.join(parent)
+
+		m = getattr( model.get_submodule(name), k )
+
+		if isinstance(m, klass):
+			continue
+
+		kwargs = dict(
+			num_embeddings=m.num_embeddings,
+			embedding_dim=m.embedding_dim,
+			padding_idx=m.padding_idx,
+			max_norm=m.max_norm,
+			norm_type=m.norm_type,
+			scale_grad_by_freq=m.scale_grad_by_freq,
+			sparse=m.sparse,
+		)
+
+		# overwrite
+		setattr(
+			model.get_submodule(name), k,
+			klass( **kwargs ).to(device=device, dtype=dtype)
+		)
+
+		if verbose:
+			print(f"Replacing {name}.{k} to", klass)
+
+	return model
+
+# cannot feasibly do default arguments here sad
+def replace_attention( model, klass, target, mode="math", verbose=False ):
+	device = next(model.parameters()).device
+	dtype = next(model.parameters()).dtype
+	modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
+
+	for *parent, k in modules:
+		name = '.'.join(parent)
+
+		m = getattr( model.get_submodule(name), k )
+
+		if isinstance(m, klass):
+			continue
+
+		kwargs = dict(
+			config = m.config,
+			layer_idx = m.layer_idx,
+			mode = mode,
+		)
+		# overwrite
+		setattr(
+			model.get_submodule(name), k,
+			klass( **kwargs ).to(device=device, dtype=dtype)
+		)
+
+		if verbose:
+			print(f"Replacing {name}.{k} to", klass)
+
+	return model
+
+# trim/expand a tensor (for example, in a state dict)
+def resize_weight( weight, target, dim=0, random=True ):
+	# trim
+	if target < weight.shape[dim]:
+		return weight[:target]
+	# expand
+	if target > weight.shape[dim]:
+		fn = torch.rand if random else torch.zeros
+		return torch.stack(
+			[ x for x in weight ] +
+			[ fn( weight[0].shape ).to(device=weight[0].device, dtype=weight[0].dtype) for _ in range( target - weight.shape[dim] ) ]
+		)
+
+	return weight
+
+# grabs the memory properties of a given device
+def get_device_properties( device ):
+	if 'cuda' in device:
+		props = torch.cuda.get_device_properties(device)
+		free, total = torch.cuda.mem_get_info(device)
+	else:
+		props = psutil.virtual_memory()
+		free, total = props.available, props.total
+
+	return {"name": device, "total": total, "free": free, "props": props}
+
+# gets the rough size for a given module's parameters
+def get_module_size( module ):
+	param_size = sum([p.nelement() * p.element_size() for p in module.parameters()])
+	buffer_size = sum([b.nelement() * b.element_size() for b in module.buffers()])
+	return param_size + buffer_size
+
+# assigns modules to requested devices for a given policy
+def get_model_offload_policy(module, policy=None):
+	# handle any other weird values this is set to
+	if not isinstance(policy, dict):
+		policy = {}
+
+	# default to only include the core model, and not the other modules (embeddings) in the splitting policy
+	if "include" not in policy:
+		policy["include"] = ["model"]
+	if "limits" not in policy:
+		policy["limits"] = []
+
+	if "devices" not in policy:
policy["devices"] = [f'{"cuda"}:{i}' for i in range(torch.cuda.device_count())] + ['cpu'] # + cpu to spill the remainder on CPU if overbudget + + # create initial device info + devices = [ get_device_properties(device) | {"modules": []} for device in policy["devices"] ] + modules = [ (name, get_module_size(module)) for name, module in module.named_modules() if not [*module.named_children()] and passes_policy( policy, name ) ] + # filter + modules = [ (name, size) for name, size in modules if name and size ] + + total_size = sum([size for name, size in modules]) + + # set caps if requested in the policy + for i, cap in enumerate(policy["limits"]): + # no limit, skip + if cap <= 0: + continue + # is fractional, scale to total size + if cap < 1: + cap = math.floor(total_size * cap) + # available space is below cap, don't set + if devices[i]["free"] < cap: + continue + # cap to requested size + devices[i]["free"] = cap + + device_index = 0 + module_index = 0 + while module_index < len(modules) and device_index < len(devices): + device = devices[device_index] + name, size = modules[module_index] + + # fits within budget + if device["free"] - size >= 0: + device["modules"].append( name ) + device["free"] -= size + module_index += 1 + # does not fit in budget, increase device index + else: + device_index += 1 + print(f"Over budget for device: {device['name']}, shifting to next device: {name}, {size / (1024 ** 2)}MiB") + + # to-do: check that all modules are exhausted + assert module_index >= len(modules) + + # only return devices with modules assigned + return [ device for device in devices if device["modules"] ] + +# handles naively splitting a model's layers across multiple devices +# this apparently works for training too? the test trainer seemed fine with it split between GPU and CPU +def offload_model( model, policy=None ): + policy = get_model_offload_policy(model, policy=policy) + + # move modules to respective devices + for i, device in enumerate( policy ): + # nothing assigned, skip + if not device["modules"]: + continue + + for name in device["modules"]: + module = model.get_submodule(name) + module = module.to( device["name"] ) + + """ + # in case the above doesn't actually do what's requested + *parent, key = name.split(".") + module = getattr( model.get_submodule(name), key ) + setattr( model.get_submodule(name), key, module.to( device["name"] ) ) + """ + + # select next device to cast inputs to, or wrap to first if last device + next_device = policy[i + 1]["name"] if i + 1 < len( policy ) else policy[0]["name"] + # same device, don't bother wrapping + if device["name"] == next_device: + continue + + # wrap forward call + last_module = model.get_submodule( device["modules"][-1] ) + last_module.forward = auto_to_forward(last_module, next_device) + + """ + # Validate that the layers are all in the right spot + for name, module in model.named_modules(): + if not not [*module.named_children()]: + continue + try: + print( name, next(module.parameters()).device ) + except Exception as e: + pass + """ + + return model \ No newline at end of file diff --git a/vall_e/utils/wrapper.py b/vall_e/utils/wrapper.py index c2a6dff..03ab449 100755 --- a/vall_e/utils/wrapper.py +++ b/vall_e/utils/wrapper.py @@ -57,34 +57,6 @@ elif cfg.optimizations.dadaptation: SGD = dadaptation.DAdaptSGD AdaGrad = dadaptation.DAdaptAdaGrad -# handles generically converting to a specific tensor type and converting back (implemented solely for bfloat16) -@contextmanager -def autocast(input, from_dtype, to_dtype): - if 
diff --git a/vall_e/utils/wrapper.py b/vall_e/utils/wrapper.py
index c2a6dff..03ab449 100755
--- a/vall_e/utils/wrapper.py
+++ b/vall_e/utils/wrapper.py
@@ -57,34 +57,6 @@ elif cfg.optimizations.dadaptation:
 	SGD = dadaptation.DAdaptSGD
 	AdaGrad = dadaptation.DAdaptAdaGrad
 
-# handles generically converting to a specific tensor type and converting back (implemented solely for bfloat16)
-@contextmanager
-def autocast(input, from_dtype, to_dtype):
-	if input.dtype == from_dtype:
-		input = input.to(to_dtype)
-		yield input
-		input = input.to(from_dtype)
-	else:
-		yield input
-
-@contextmanager
-def autocasts(input, from_dtype, to_dtype):
-	if input.dtype in from_dtype:
-		from_dtype = input.dtype
-		input = input.to(to_dtype)
-		yield input
-		input = input.to(from_dtype)
-	else:
-		yield input
-
-# handles temporarily upcasting 'index tensors' so torch will stop bitching
-def autocast_forward( func ):
-	def wrapper( self, input, *args, **kwargs ):
-		with autocasts( input, [torch.int16, torch.int8, torch.uint8, torch.float16, torch.bfloat16], torch.int32 ) as k:
-			return func( self, k, *args, **kwargs )
-	return wrapper
-Embedding.forward = autocast_forward(Embedding.forward)
-
 if cfg.optimizations.fp8:
 	import transformer_engine.pytorch as te
 
@@ -110,123 +82,6 @@ if cfg.optimizations.injects:
 	torch.optim.AdamW = AdamW
 	torch.optim.SGD = SGD
 
-# disgusting kludge, but it works (just realized BitNet has its own replacement routine)
-# generalizing this would be super sugoi but the there's no catch all for arguments
-def replace_linear( model, klass=Linear, target=torch.nn.Linear, verbose=False ):
-	bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet
-
-	device = next(model.parameters()).device
-	dtype = next(model.parameters()).dtype
-	modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
-
-	for *parent, k in modules:
-		name = '.'.join(parent)
-
-		m = getattr( model.get_submodule(name), k )
-
-		if isinstance(m, klass):
-			continue
-
-		kwargs = dict(
-			in_features = m.in_features,
-			out_features = m.out_features,
-			bias = m.bias is not None,
-		) if not bnb else dict(
-			input_features=m.in_features,
-			output_features=m.out_features,
-			bias=m.bias is not None,
-		)
-
-		# overwrite
-		setattr(
-			model.get_submodule(name), k,
-			klass( **kwargs ).to(device=device, dtype=dtype)
-		)
-
-		if verbose:
-			print(f"Replacing {name}.{k} to", klass)
-
-	return model
-
-def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbose=False ):
-	device = next(model.parameters()).device
-	dtype = next(model.parameters()).dtype
-	modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
-
-	for *parent, k in modules:
-		name = '.'.join(parent)
-
-		m = getattr( model.get_submodule(name), k )
-
-		if isinstance(m, klass):
-			continue
-
-		kwargs = dict(
-			num_embeddings=m.num_embeddings,
-			embedding_dim=m.embedding_dim,
-			padding_idx=m.padding_idx,
-			max_norm=m.max_norm,
-			norm_type=m.norm_type,
-			scale_grad_by_freq=m.scale_grad_by_freq,
-			sparse=m.sparse,
-		)
-
-		# overwrite
-		setattr(
-			model.get_submodule(name), k,
-			klass( **kwargs ).to(device=device, dtype=dtype)
-		)
-
-		if verbose:
-			print(f"Replacing {name}.{k} to", klass)
-
-	return model
-
-# cannot feasibly do default arguments here sad
-def replace_attention( model, klass, target, mode="math", verbose=False ):
-	device = next(model.parameters()).device
-	dtype = next(model.parameters()).dtype
-	modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
-
-	for *parent, k in modules:
-		name = '.'.join(parent)
-
-		m = getattr( model.get_submodule(name), k )
-
-		if isinstance(m, klass):
-			continue
-
-		kwargs = dict(
-			config = m.config,
-			layer_idx = m.layer_idx,
-			mode = mode,
-		)
-		# overwrite
-		setattr(
-			model.get_submodule(name), k,
-			klass( **kwargs ).to(device=device, dtype=dtype)
-		)
-
-		if verbose:
-			print(f"Replacing {name}.{k} to", klass)
-
-	return model
-
-# trim/expand a tensor (for example, in a state dict)
-def resize_weight( weight, target, dim=0, random=True ):
-	# trim
-	if target < weight.shape[dim]:
-		return weight[:target]
-	# expand
-	if target > weight.shape[dim]:
-		fn = torch.rand if random else torch.zeros
-		return torch.stack(
-			[ x for x in weight ] +
-			[ fn( weight[0].shape ).to(device=weight[0].device, dtype=weight[0].dtype) for _ in range( target - weight.shape[dim] ) ]
-		)
-
-	return weight
-
 # https://github.com/konstmish/prodigy
 try:
 	from prodigyopt import Prodigy
@@ -239,4 +94,23 @@ try:
 	import schedulefree
 except Exception as e:
 	print('Error while importing Schedule_Free:', str(e))
-	pass
\ No newline at end of file
+	pass
+
+# backwards compat
+from .utils import (
+	autocast_forward,
+	auto_to_forward,
+	replace_linear as replace_linear_old,
+	replace_embedding as replace_embedding_old,
+	replace_attention,
+	resize_weight,
+	offload_model,
+)
+
+# wrapped here so we can maintain default args
+def replace_linear( model, klass=Linear, target=torch.nn.Linear, verbose=False ):
+	return replace_linear_old( model, klass, target, verbose )
+def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbose=False ):
+	return replace_embedding_old( model, klass, target, verbose )
+
+Embedding.forward = autocast_forward(Embedding.forward)
\ No newline at end of file
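Lastly, a small sketch (not from the patch) of what the relocated `autocast_forward` wrapper does: index tensors in dtypes an embedding lookup would reject are temporarily upcast to `int32` for the duration of the call. Here it is applied to a plain `torch.nn.Embedding` via a hypothetical subclass, rather than whichever `Embedding` class `wrapper.py` actually patches.

```python
# sketch: PatchedEmbedding is hypothetical, mirroring wrapper.py's class-level patch
import torch
from vall_e.utils.utils import autocast_forward

class PatchedEmbedding(torch.nn.Embedding):
	pass

PatchedEmbedding.forward = autocast_forward(torch.nn.Embedding.forward)

emb = PatchedEmbedding(256, 16)
indices = torch.arange(8, dtype=torch.uint8)  # uint8 indices would normally be rejected by the lookup
print( emb(indices).shape )                   # torch.Size([8, 16]); upcast to int32 on the way in
```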