Added module override, bnb.nn.Embedding #13 #15 #19

This commit is contained in:
Tim Dettmers 2021-11-29 09:32:13 -08:00
parent 3cff6795fb
commit 20e1677dfd
8 changed files with 140 additions and 6 deletions

View File

@@ -43,7 +43,13 @@ Docs:
Features:
- Added Adagrad (without grad clipping) as 32-bit and 8-bit block-wise optimizer.
- Added AdamW (copy of Adam with weight decay init 1e-2). #10
- Introduced ModuleConfig overrides, which can be seamlessly applied at module initialization time.
- Added `bnb.nn.Embedding` layer, which runs in 32-bit but without the layer norm. This works well if you need to fine-tune pretrained models that do not have an embedding layer norm. #19 (A short usage sketch follows this changelog excerpt.)
Bug fixes:
- Fixed a bug where weight decay was incorrectly applied to 32-bit Adam. #13
- Fixed an unsafe use of eval. #8
- Fixed a bug where the StableEmbedding layer 32-bit optimizer override would not work without registering the whole model first (`bnb.optim.GlobalOptimManager.get_instance().register_parameters(model.parameters())`). #13 #15
Docs:
- Added instructions on how to solve `__fatbinwrap_` errors.
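As a rough usage sketch of the two features above (the layer sizes, learning rate, and batch shape below are made up for illustration, not taken from this commit):

```python
# Minimal sketch: bnb.nn.Embedding skips the LayerNorm of StableEmbedding, but its
# weight is still registered at __init__ time to be optimized with 32-bit state,
# even when the rest of the model uses an 8-bit optimizer.
import torch
import bitsandbytes as bnb

emb = bnb.nn.Embedding(1000, 64).cuda()               # no embedding layer norm
adam = bnb.optim.Adam8bit(emb.parameters(), lr=1e-3)  # 8-bit optimizer overall

tokens = torch.randint(0, 1000, (8, 16)).cuda()
emb(tokens).mean().backward()
adam.step()                                           # emb.weight keeps 32-bit optimizer state
```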

View File

@@ -83,6 +83,7 @@ For upcoming features and changes and full history see [Patch Notes](CHANGELOG.md)
## Errors
1. RuntimeError: CUDA error: no kernel image is available for execution on the device. [Solution](errors_and_solutions.md#No-kernel-image-available)
2. __fatbinwrap_.. [Solution](errors_and_solutions.md#fatbinwrap_)
## Compile from source

View File

@@ -2,4 +2,4 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .modules import StableEmbedding
from .modules import StableEmbedding, Embedding

View File

@@ -18,8 +18,7 @@ class StableEmbedding(torch.nn.Embedding):
                 sparse: bool = False, _weight: Optional[Tensor] = None) -> None:
        super(StableEmbedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight)
        self.norm = torch.nn.LayerNorm(embedding_dim)
        GlobalOptimManager.get_instance().register_parameters(self.weight)
        GlobalOptimManager.get_instance().override_config(self.weight, 'optim_bits', 32)
        GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32})

    def reset_parameters(self) -> None:
        torch.nn.init.xavier_uniform_(self.weight)

@@ -42,3 +41,33 @@ class StableEmbedding(torch.nn.Embedding):
            self.norm_type, self.scale_grad_by_freq, self.sparse)

        return self.norm(emb)
class Embedding(torch.nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
                 sparse: bool = False, _weight: Optional[Tensor] = None) -> None:
        super(Embedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight)
        GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32})

    def reset_parameters(self) -> None:
        torch.nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    ''' !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
        to make the layer compatible with PyTorch < 1.9.
        This means that if it changes in a future PyTorch release, this method needs to
        change too, which is cumbersome. However, with this we can ensure compatibility
        with previous PyTorch releases.
    '''
    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def forward(self, input: Tensor) -> Tensor:
        emb = F.embedding(
            input, self.weight, self.padding_idx, self.max_norm,
            self.norm_type, self.scale_grad_by_freq, self.sparse)

        return emb
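For reference, a small comparison sketch of the two layers in this file (shapes are arbitrary): both register a 32-bit optimizer override for their weight at construction time, and only `StableEmbedding` applies a LayerNorm to the looked-up vectors.

```python
# Comparison sketch (arbitrary sizes); both layers use xavier-uniform initialization.
import torch
import bitsandbytes as bnb

ids = torch.randint(0, 100, (4, 8))

stable = bnb.nn.StableEmbedding(100, 32)  # lookup followed by LayerNorm
plain = bnb.nn.Embedding(100, 32)         # raw lookup, no LayerNorm

print(stable(ids).shape, plain(ids).shape)  # torch.Size([4, 8, 32]) for both
```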

View File

@@ -26,6 +26,7 @@ class GlobalOptimManager(object):
        self.index2config = {}
        self.optimizer = None
        self.uses_config_override = False
        self.module_weight_config_triple = []

    @classmethod
    def get_instance(cls):

@@ -77,12 +78,16 @@ class GlobalOptimManager(object):
            if id(p) in self.pid2config: self.pid2config[id(p)].update(key_value_dict)
            else: self.pid2config[id(p)] = key_value_dict

    def register_module_override(self, module, param_name, config):
        self.module_weight_config_triple.append((module, param_name, config))


class Optimizer8bit(torch.optim.Optimizer):
    def __init__(self, params, defaults, optim_bits=32):
        super(Optimizer8bit, self).__init__(params, defaults)
        self.checked_if_on_gpu = False
        self.initialized = False
        self.name2qmap = {}

        self.mng = GlobalOptimManager.get_instance()

@@ -172,7 +177,6 @@ class Optimizer8bit(torch.optim.Optimizer):
        self.__setstate__({'state': state, 'param_groups': param_groups})

    def to_gpu(self):
        self.checked_if_on_gpu = True
        for gindex, group in enumerate(self.param_groups):
            for pindex, p in enumerate(group['params']):
                if p in self.state:

@@ -181,6 +185,23 @@ class Optimizer8bit(torch.optim.Optimizer):
                    if isinstance(v, torch.Tensor):
                        self.state[p][k] = v.to(p.device)
    def check_overrides(self):
        for module, attr, config in self.mng.module_weight_config_triple:
            pmodule = getattr(module, attr)
            assert pmodule is not None
            assert isinstance(pmodule, torch.Tensor) or isinstance(pmodule, torch.nn.Parameter)
            found = False
            for gindex, group in enumerate(self.param_groups):
                if found: break
                for pindex, p in enumerate(group['params']):
                    if found: break
                    if id(p) == id(pmodule):
                        # found the matching parameter
                        # init override
                        self.mng.pid2config[id(p)] = config
                        self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[id(p)]
                        found = True
    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

@@ -196,7 +217,11 @@ class Optimizer8bit(torch.optim.Optimizer):
        overflows = []

        if not self.checked_if_on_gpu: self.to_gpu() # needed for fairseq pure fp16 training
        if not self.initialized:
            self.check_overrides()
            self.to_gpu() # needed for fairseq pure fp16 training
            self.initialized = True

        for gindex, group in enumerate(self.param_groups):
            for pindex, p in enumerate(group['params']):
                if p.grad is None:
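In other words, `register_module_override` only records a `(module, attribute name, config)` triple; the actual matching is deferred to the first `step()` call, where `check_overrides` looks the parameter up by identity in the optimizer's param groups. A standalone sketch of that matching logic (illustrative names, not the library's code):

```python
# Illustrative sketch of deferred override resolution (not bitsandbytes internals).
def resolve_overrides(param_groups, triples):
    pid2config = {}
    for module, attr, config in triples:
        target = getattr(module, attr)            # e.g. an embedding's weight Parameter
        for group in param_groups:
            for p in group['params']:
                if p is target:                   # identity match, like id(p) == id(target)
                    pid2config[id(p)] = config    # e.g. {'optim_bits': 32}
    return pid2config
```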

View File

@@ -6,3 +6,16 @@ If you are feeling lucky, you can also try to compile the library from source. T
__If you encounter any other error not listed here please create an issue. This will help resolve your problem and will help out others in the future.__

# fatbinwrap

This error occurs if there is a mismatch between the CUDA versions of the C++ library and the CUDA part. Make sure you have the right CUDA version in your $PATH and $LD_LIBRARY_PATH variables. In the conda base environment you can find the library under:
```bash
ls $CONDA_PREFIX/lib/*cudart*
```
Make sure this path is appended to `LD_LIBRARY_PATH` so bnb can find the CUDA runtime library (cudart).

If this does not fix the issue, please try [compilation from source](compile_from_source.md) next.

If this does not work, please open an issue and paste the environment printed when you call `make`, along with the error you get when running bnb.
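As a quick sanity check before opening an issue, here is a small diagnostic sketch (my own, not part of bnb; it assumes a conda environment) that prints the CUDA version PyTorch was built with, the cudart libraries in the active environment, and the current `LD_LIBRARY_PATH`:

```python
# Diagnostic sketch for spotting CUDA version mismatches; not part of bitsandbytes.
import glob
import os

import torch

print("PyTorch built with CUDA:", torch.version.cuda)

conda_prefix = os.environ.get("CONDA_PREFIX", "")
if conda_prefix:
    # cudart file names encode the runtime version, e.g. libcudart.so.11.1
    print("cudart in conda env:", glob.glob(os.path.join(conda_prefix, "lib", "*cudart*")))

print("LD_LIBRARY_PATH:", os.environ.get("LD_LIBRARY_PATH", "<not set>"))
```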

View File

@@ -2,6 +2,7 @@
If you want to optimize some unstable parameters with 32-bit Adam and others with 8-bit Adam, you can use the `GlobalOptimManager`. With this, we can also configure specific hyperparameters for particular layers, such as embedding layers. To do that, we need two things: (1) register the parameters while they are still on the CPU, and (2) override the config with the new desired hyperparameters (anytime, anywhere). See our [guide](howto_config_override.md) for more details.
For global overrides in many different places in your code you can do:
```python
import torch
import bitsandbytes as bnb
@@ -24,3 +25,16 @@ mng.override_config([model.special.weight, model.also_special.weight],
                    key_value_dict={'is_sparse': True, 'lr': 1e-5, 'betas': (0.9, 0.98)})
```
Possible options for the config override are: `betas, eps, weight_decay, lr, optim_bits, min_8bit_size, percentile_clipping, block_wise, max_unorm`
For overrides of particular layers, we recommend overriding locally in each module. You can do this by passing the module, the name of its parameter attribute, and the override config to the GlobalOptimManager:
```python
class MyModule(torch.nn.Module):
    def __init__(self, din, dout):
        super(MyModule, self).__init__()
        self.linear = torch.nn.Linear(din, dout)
        # optimization will happen in 32-bit and the learning rate will be
        # set to 0.0001, independent of the main learning rate
        config = {'optim_bits': 32, 'lr': 0.0001}
        GlobalOptimManager.get_instance().register_module_override(self.linear, 'weight', config)
```
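A hypothetical follow-up to the snippet above (sizes and the main learning rate are placeholders, and the imports from the first snippet are assumed): because the override is recorded when the module is constructed, the optimizer can be created at any later point and still picks it up on its first step.

```python
model = MyModule(128, 256).cuda()
adam = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)

# on the first adam.step(), model.linear.weight is matched against the recorded
# override and optimized with 32-bit state and lr=0.0001; all other parameters
# use the optimizer's defaults.
```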

tests/test_modules.py Normal file (46 lines added)
View File

@@ -0,0 +1,46 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch

import bitsandbytes as bnb

from itertools import product
from bitsandbytes import functional as F


@pytest.mark.parametrize("embcls", [bnb.nn.Embedding, bnb.nn.StableEmbedding], ids=['Embedding', 'StableEmbedding'])
def test_embeddings(embcls):
    bnb.optim.GlobalOptimManager.get_instance().initialize()
    emb1 = torch.nn.Embedding(100, 512).cuda()
    emb2 = embcls(100, 512).cuda()

    adam1 = bnb.optim.Adam8bit(emb1.parameters())
    adam2 = bnb.optim.Adam8bit(emb2.parameters())

    batches = torch.randint(1, 100, size=(100, 4, 32)).cuda()

    for i in range(100):
        batch = batches[i]

        embedded1 = emb1(batch)
        embedded2 = emb2(batch)

        l1 = embedded1.mean()
        l2 = embedded2.mean()

        l1.backward()
        l2.backward()

        adam1.step()
        adam2.step()

        adam1.zero_grad()
        adam2.zero_grad()

    # the plain torch.nn.Embedding gets the usual 8-bit optimizer state, while the
    # bnb embedding layers are overridden to 32-bit state at construction time
    assert adam1.state[emb1.weight]['state1'].dtype == torch.uint8
    assert adam2.state[emb2.weight]['state1'].dtype == torch.float32