diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 83d4f9d..6570d0a 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -527,8 +527,8 @@ def example_usage():
 	cfg.optimizations.model_offloading = {
 		"devices": ["cuda:0", "cpu"],
 		# "limits": [ 0.9, -1 ],
-		"assign": [[ f'layers.{i}.' for i in range(0,6) ], [ f'layers.{i}.' for i in range(6,12) ]],
-		"limits": [ 256 * (1024 ** 2), -1 ]
+		"assign": [[ f'layers.{i}.' for i in range(0,10) ], [ f'layers.{i}.' for i in range(11,12) ] + [ "model.norm" ]],
+		# "limits": [ 256 * (1024 ** 2), -1 ]
 	}
 	"""

diff --git a/vall_e/utils/utils.py b/vall_e/utils/utils.py
index 4faead2..6227da4 100755
--- a/vall_e/utils/utils.py
+++ b/vall_e/utils/utils.py
@@ -229,15 +229,22 @@ def autocast_forward( func ):
 	return wrapper

 # handles migrating an input tensor to a given device
-def auto_to_forward( module, device=None ):
-	if device is None:
-		device = next(module.parameters()).device
-
+def auto_align_inputs_forward( module, device=None, name=None ):
 	func = module.forward

-	def wrapper( self, *args, **kwargs ):
-		# search through args and kwargs for any Tensor arguments
+	if device is None:
+		if hasattr( module, 'device' ):
+			device = module.device
+		else:
+			try:
+				device = next(module.parameters() if [*module.parameters()] else module.buffers()).device
+			except Exception as e:
+				return func
+
+
+	def wrapper( *args, **kwargs ):
 		args = [*args]
+		# search through args and kwargs for any Tensor arguments
 		for i, arg in enumerate(args):
 			if not isinstance( arg, torch.Tensor ):
 				continue
@@ -248,7 +255,11 @@
 				continue
 			kwargs[k] = v.to( device=device )
-		return func( self, *args, **kwargs )
+		# disgusting patch
+		if "position_embeddings" in kwargs:
+			kwargs["position_embeddings"] = tuple([ t.to(device=device) for t in kwargs["position_embeddings"] ])
+
+		return func( *args, **kwargs )

 	return wrapper

 # disgusting kludge, but it works (just realized BitNet has its own replacement routine)
@@ -404,8 +415,6 @@ def get_model_offload_policy(module, policy=None):
 	if "devices" not in policy:
 		policy["devices"] = [f'{"cuda"}:{i}' for i in range(torch.cuda.device_count())] + ['cpu'] # + cpu to spill the remainder on CPU if overbudget

-	print( policy )
-
 	# create initial device info
 	devices = [ get_device_properties(device) | {"modules": []} for device in policy["devices"] ]
 	modules = [ (name, get_module_size(module)) for name, module in module.named_modules() if not [*module.named_children()] and passes_policy( policy, name ) ]
@@ -498,23 +507,14 @@ def offload_model( model, policy=None ):
 		for name in device["modules"]:
 			module = model.get_submodule(name)
 			module = module.to( device["name"] )
+			module.device = device['name']

-		"""
-		# in case the above doesn't actually do what's requested
-		*parent, key = name.split(".")
-		module = getattr( model.get_submodule(name), key )
-		setattr( model.get_submodule(name), key, module.to( device["name"] ) )
-		"""
-
-		# select next device to cast inputs to, or wrap to first if last device
-		next_device = policy[i + 1]["name"] if i + 1 < len( policy ) else policy[0]["name"]
-		# same device, don't bother wrapping
-		if device["name"] == next_device:
+	# wrap each module's forward to ensure all inputs are moved to that module's device
+	for name, module in model.named_modules():
+		if not hasattr( module, 'forward' ):
 			continue
-		# wrap forward call
-		last_module = model.get_submodule( device["modules"][-1] )
-		last_module.forward = auto_to_forward(last_module, next_device)
+
+		module.forward = auto_align_inputs_forward(module)

 	"""
 	# Validate that the layers are all in the right spot
@@ -524,6 +524,7 @@
 		try:
 			print( name, next(module.parameters()).device )
 		except Exception as e:
+			print( name, "?" )
 			pass
 	"""
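For reference, a minimal standalone sketch of the technique the new auto_align_inputs_forward applies: each wrapped module moves incoming Tensor arguments onto its own device before invoking its original forward. The helper name align_inputs_forward and the two-layer toy model are illustrative only; the real function additionally resolves the device itself and special-cases the position_embeddings tuple, as the hunk above shows.

import torch
import torch.nn as nn

def align_inputs_forward( module, device ):
	func = module.forward
	def wrapper( *args, **kwargs ):
		# move Tensor positional and keyword arguments onto the module's device
		args = [ arg.to( device=device ) if isinstance( arg, torch.Tensor ) else arg for arg in args ]
		kwargs = { k: v.to( device=device ) if isinstance( v, torch.Tensor ) else v for k, v in kwargs.items() }
		return func( *args, **kwargs )
	return wrapper

# toy "split": on a real multi-device setup the second layer would live
# elsewhere, e.g. .to( "cuda:0" ); on cpu alone this is a no-op demonstration
model = nn.Sequential( nn.Linear( 4, 8 ).to( "cpu" ), nn.Linear( 8, 4 ).to( "cpu" ) )
for layer in model:
	layer.forward = align_inputs_forward( layer, next( layer.parameters() ).device )

# inputs and intermediate activations now follow each layer's device
out = model( torch.randn( 1, 4 ) )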
diff --git a/vall_e/utils/wrapper.py b/vall_e/utils/wrapper.py
index 03ab449..7d986e4 100755
--- a/vall_e/utils/wrapper.py
+++ b/vall_e/utils/wrapper.py
@@ -99,18 +99,17 @@ except Exception as e:
 # backwards compat
 from .utils import (
 	autocast_forward,
-	auto_to_forward,
-	replace_linear,
+	replace_linear as replace_linear_old,
 	replace_embedding as replace_embedding_old,
-	replace_attention as replace_attention_old,
+	replace_attention,
 	resize_weight,
 	offload_model,
 )

 # wrapped here so we can maintain default args
 def replace_linear( model, klass=Linear, target=torch.nn.Linear, verbose=False ):
-	return replace_embedding_old( model, klass, target, verbose )
+	return replace_linear_old( model, klass, target, verbose )
 def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbose=False ):
-	return replace_attention_old( model, klass, target, verbose )
+	return replace_embedding_old( model, klass, target, verbose )

 Embedding.forward = autocast_forward(Embedding.forward)
\ No newline at end of file
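For context on the ar_nar.py hunk, the offload policy comes in two forms: an explicit "assign" list (one list of module-name prefixes per entry in "devices") and a per-device byte "limits" budget. The sketch below mirrors example_usage() and assumes its surrounding cfg object; the comments, and reading -1 as "no cap, spill the remainder here", are inferred from get_model_offload_policy's spill-to-CPU comment rather than documented behavior.

# explicit assignment: module-name prefixes pinned per device
cfg.optimizations.model_offloading = {
	"devices": ["cuda:0", "cpu"],
	# layers 0-9 go to cuda:0; layer 11 and the final norm go to cpu
	"assign": [[ f'layers.{i}.' for i in range(0,10) ], [ f'layers.{i}.' for i in range(11,12) ] + [ "model.norm" ]],
}

# budget form (the one the patch comments out): cap cuda:0 at 256 MiB
# and spill whatever is left onto cpu (-1 presumably meaning "no limit")
# cfg.optimizations.model_offloading = {
# 	"devices": ["cuda:0", "cpu"],
# 	"limits": [ 256 * (1024 ** 2), -1 ],
# }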