diff --git a/CHANGELOG.md b/CHANGELOG.md index d12af22..fa20b15 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,7 +43,13 @@ Docs: Features: - Added Adagrad (without grad clipping) as 32-bit and 8-bit block-wise optimizer. - Added AdamW (copy of Adam with weight decay init 1e-2). #10 + - Introduced ModuleConfig overrides which can be seamlessly be used at initialization time of a module. + - Added `bnb.nn.Embedding` layer which runs at 32-bit but without the layernorm. This works well if you need to fine-tune pretrained models that do not have a embedding layer norm. #19 Bug fixes: - Fixed a bug where weight decay was incorrectly applied to 32-bit Adam. #13 - Fixed an unsafe use of eval. #8 + - Fixed a bug where the StableEmbedding layer 32-bit optimizer override would not work without registering the whole model first (`bnb.optim.GlobalOptimManager.get_instance().register_parameters(model.parameters())`). #13 #15 + +Docs: + - Added instructions how to solve "\_\_fatbinwrap_" errors. diff --git a/README.md b/README.md index 4a731b0..4b7db17 100644 --- a/README.md +++ b/README.md @@ -83,6 +83,7 @@ For upcoming features and changes and full history see [Patch Notes](CHANGELOG.m ## Errors 1. RuntimeError: CUDA error: no kernel image is available for execution on the device. [Solution](errors_and_solutions.md#No-kernel-image-available) +2. __fatbinwrap_.. [Solution](errors_and_solutions.md#fatbinwrap_) ## Compile from source diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py index 177540f..27ad6ca 100644 --- a/bitsandbytes/nn/__init__.py +++ b/bitsandbytes/nn/__init__.py @@ -2,4 +2,4 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .modules import StableEmbedding +from .modules import StableEmbedding, Embedding diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index ce2f3a4..dc0a171 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -18,8 +18,7 @@ class StableEmbedding(torch.nn.Embedding): sparse: bool = False, _weight: Optional[Tensor] = None) -> None: super(StableEmbedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight) self.norm = torch.nn.LayerNorm(embedding_dim) - GlobalOptimManager.get_instance().register_parameters(self.weight) - GlobalOptimManager.get_instance().override_config(self.weight, 'optim_bits', 32) + GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32}) def reset_parameters(self) -> None: torch.nn.init.xavier_uniform_(self.weight) @@ -42,3 +41,33 @@ class StableEmbedding(torch.nn.Embedding): self.norm_type, self.scale_grad_by_freq, self.sparse) return self.norm(emb) + + +class Embedding(torch.nn.Embedding): + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False, + sparse: bool = False, _weight: Optional[Tensor] = None) -> None: + super(Embedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight) + GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32}) + + def reset_parameters(self) -> None: + torch.nn.init.xavier_uniform_(self.weight) + self._fill_padding_idx_with_zero() + + ''' !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding + to make the Layer compatible with Pytorch < 1.9. + This means that if this changes in future PyTorch releases this need to change too + which is cumbersome. However, with this we can ensure compatibility with previous + PyTorch releases. + ''' + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input: Tensor) -> Tensor: + emb = F.embedding( + input, self.weight, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse) + + return emb diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index cfbd72e..5a5bb1e 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -26,6 +26,7 @@ class GlobalOptimManager(object): self.index2config = {} self.optimizer = None self.uses_config_override = False + self.module_weight_config_triple = [] @classmethod def get_instance(cls): @@ -77,12 +78,16 @@ class GlobalOptimManager(object): if id(p) in self.pid2config:self.pid2config[id(p)].update(key_value_dict) else: self.pid2config[id(p)] = key_value_dict + def register_module_override(self, module, param_name, config): + self.module_weight_config_triple.append((module, param_name, config)) + + class Optimizer8bit(torch.optim.Optimizer): def __init__(self, params, defaults, optim_bits=32): super(Optimizer8bit, self).__init__(params, defaults) - self.checked_if_on_gpu = False + self.initialized = False self.name2qmap = {} self.mng = GlobalOptimManager.get_instance() @@ -172,7 +177,6 @@ class Optimizer8bit(torch.optim.Optimizer): self.__setstate__({'state': state, 'param_groups': param_groups}) def to_gpu(self): - self.checked_if_on_gpu = True for gindex, group in enumerate(self.param_groups): for pindex, p in enumerate(group['params']): if p in self.state: @@ -181,6 +185,23 @@ class Optimizer8bit(torch.optim.Optimizer): if isinstance(v, torch.Tensor): self.state[p][k] = v.to(p.device) + def check_overrides(self): + for module, attr, config in self.mng.module_weight_config_triple: + pmodule = getattr(module, attr) + assert pmodule is not None + assert isinstance(pmodule, torch.Tensor) or isinstance(pmodule, torch.Parameter) + found = False + for gindex, group in enumerate(self.param_groups): + if found: break + for pindex, p in enumerate(group['params']): + if found: break + if id(p) == id(pmodule): + # found the matching parameter + # init override + self.mng.pid2config[id(p)] = config + self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[id(p)] + found = True + @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. @@ -196,7 +217,11 @@ class Optimizer8bit(torch.optim.Optimizer): overflows = [] - if not self.checked_if_on_gpu: self.to_gpu() # needed for fairseq pure fp16 training + if not self.initialized: + self.check_overrides() + self.to_gpu() # needed for fairseq pure fp16 training + self.initialized = True + for gindex, group in enumerate(self.param_groups): for pindex, p in enumerate(group['params']): if p.grad is None: diff --git a/errors_and_solutions.md b/errors_and_solutions.md index dd99f7c..5e8b2d2 100644 --- a/errors_and_solutions.md +++ b/errors_and_solutions.md @@ -6,3 +6,16 @@ If you are feeling lucky, you can also try to compile the library from source. T __If you encounter any other error not listed here please create an issue. This will help resolve your problem and will help out others in the future. + + +# fatbinwrap + +This error occurs if there is a mismatch between CUDA versions in the C++ library and the CUDA part. Make sure you have right CUDA in your $PATH and $LD_LIBRARY_PATH variable. In the conda base environment you can find the library under: +```bash +ls $CONDA_PREFIX/lib/*cudart* +``` +Make sure this path is appended to the `LD_LIBRARY_PATH` so bnb can find the CUDA runtime environment library (cudart). + +If this does not fix the issue, please try [compilation from source](compile_from_source.md) next. + +If this does not work, please open an issue and paste the printed environment if you call `make` and the associated error when running bnb. diff --git a/howto_config_override.md b/howto_config_override.md index 11e9d49..4680776 100644 --- a/howto_config_override.md +++ b/howto_config_override.md @@ -2,6 +2,7 @@ If you want to optimize some unstable parameters with 32-bit Adam and others with 8-bit Adam, you can use the `GlobalOptimManager`. With this, we can also configure specific hyperparameters for particular layers, such as embedding layers. To do that, we need two things: (1) register the parameter while they are still on the CPU, (2) override the config with the new desired hyperparameters (anytime, anywhere). See our [guide](howto_config_override.md) for more details +For global overrides in many different places in your code you can do: ```python import torch import bitsandbytes as bnb @@ -24,3 +25,16 @@ mng.override_config([model.special.weight, model.also_special.weight], key_value_dict ={'is_sparse': True, 'lr': 1e-5, 'betas'=(0.9, 0.98)}) ``` Possible options for the config override are: `betas, eps, weight_decay, lr, optim_bits, min_8bit_size, percentile_clipping, block_wise, max_unorm` + +For overrides for particular layers we recommend overriding locally in each module. You can do this by passing the module, the parameter, and its attribute name to the GlobalOptimManager: +```python +class MyModule(torch.nn.Module): + def __init__(din, dout): + super(MyModule, self).__init__() + self.linear = torch.nn.Linear(din, dout) + # optimization will happen in 32-bit and + # learning rate will be set to 0.0001 independent of the main learning rate + config = {'optim_bits': 32, 'lr' : 0.0001} + GlobalOptimManager.get_instance().register_module_override(self, 'weight', config) + +``` diff --git a/tests/test_modules.py b/tests/test_modules.py new file mode 100644 index 0000000..6cbee7b --- /dev/null +++ b/tests/test_modules.py @@ -0,0 +1,46 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import pytest +import torch +import bitsandbytes as bnb + +from itertools import product + +from bitsandbytes import functional as F + + +@pytest.mark.parametrize("embcls", [bnb.nn.Embedding, bnb.nn.StableEmbedding], ids=['Embedding', 'StableEmbedding']) +def test_embeddings(embcls): + bnb.optim.GlobalOptimManager.get_instance().initialize() + emb1 = torch.nn.Embedding(100, 512).cuda() + emb2 = embcls(100, 512).cuda() + + adam1 = bnb.optim.Adam8bit(emb1.parameters()) + adam2 = bnb.optim.Adam8bit(emb2.parameters()) + + batches = torch.randint(1, 100, size=(100, 4, 32)).cuda() + + for i in range(100): + batch = batches[i] + + embedded1 = emb1(batch) + embedded2 = emb2(batch) + + l1 = embedded1.mean() + l2 = embedded2.mean() + + l1.backward() + l2.backward() + + adam1.step() + adam2.step() + + adam1.zero_grad() + adam2.zero_grad() + + assert adam1.state[emb1.weight]['state1'].dtype == torch.uint8 + assert adam2.state[emb2.weight]['state1'].dtype == torch.float32 + +