forked from mrq/bitsandbytes-rocm
parent
3cff6795fb
commit
20e1677dfd
|
@ -43,7 +43,13 @@ Docs:
|
|||
Features:
|
||||
- Added Adagrad (without grad clipping) as 32-bit and 8-bit block-wise optimizer.
|
||||
- Added AdamW (copy of Adam with weight decay init 1e-2). #10
|
||||
- Introduced ModuleConfig overrides which can be seamlessly be used at initialization time of a module.
|
||||
- Added `bnb.nn.Embedding` layer which runs at 32-bit but without the layernorm. This works well if you need to fine-tune pretrained models that do not have a embedding layer norm. #19
|
||||
|
||||
Bug fixes:
|
||||
- Fixed a bug where weight decay was incorrectly applied to 32-bit Adam. #13
|
||||
- Fixed an unsafe use of eval. #8
|
||||
- Fixed a bug where the StableEmbedding layer 32-bit optimizer override would not work without registering the whole model first (`bnb.optim.GlobalOptimManager.get_instance().register_parameters(model.parameters())`). #13 #15
|
||||
|
||||
Docs:
|
||||
- Added instructions how to solve "\_\_fatbinwrap_" errors.
|
||||
|
|
|
@ -83,6 +83,7 @@ For upcoming features and changes and full history see [Patch Notes](CHANGELOG.m
|
|||
## Errors
|
||||
|
||||
1. RuntimeError: CUDA error: no kernel image is available for execution on the device. [Solution](errors_and_solutions.md#No-kernel-image-available)
|
||||
2. __fatbinwrap_.. [Solution](errors_and_solutions.md#fatbinwrap_)
|
||||
|
||||
## Compile from source
|
||||
|
||||
|
|
|
@ -2,4 +2,4 @@
|
|||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
from .modules import StableEmbedding
|
||||
from .modules import StableEmbedding, Embedding
|
||||
|
|
|
@ -18,8 +18,7 @@ class StableEmbedding(torch.nn.Embedding):
|
|||
sparse: bool = False, _weight: Optional[Tensor] = None) -> None:
|
||||
super(StableEmbedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight)
|
||||
self.norm = torch.nn.LayerNorm(embedding_dim)
|
||||
GlobalOptimManager.get_instance().register_parameters(self.weight)
|
||||
GlobalOptimManager.get_instance().override_config(self.weight, 'optim_bits', 32)
|
||||
GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32})
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
torch.nn.init.xavier_uniform_(self.weight)
|
||||
|
@ -42,3 +41,33 @@ class StableEmbedding(torch.nn.Embedding):
|
|||
self.norm_type, self.scale_grad_by_freq, self.sparse)
|
||||
|
||||
return self.norm(emb)
|
||||
|
||||
|
||||
class Embedding(torch.nn.Embedding):
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
|
||||
max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
|
||||
sparse: bool = False, _weight: Optional[Tensor] = None) -> None:
|
||||
super(Embedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight)
|
||||
GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32})
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
torch.nn.init.xavier_uniform_(self.weight)
|
||||
self._fill_padding_idx_with_zero()
|
||||
|
||||
''' !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
|
||||
to make the Layer compatible with Pytorch < 1.9.
|
||||
This means that if this changes in future PyTorch releases this need to change too
|
||||
which is cumbersome. However, with this we can ensure compatibility with previous
|
||||
PyTorch releases.
|
||||
'''
|
||||
def _fill_padding_idx_with_zero(self) -> None:
|
||||
if self.padding_idx is not None:
|
||||
with torch.no_grad():
|
||||
self.weight[self.padding_idx].fill_(0)
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
emb = F.embedding(
|
||||
input, self.weight, self.padding_idx, self.max_norm,
|
||||
self.norm_type, self.scale_grad_by_freq, self.sparse)
|
||||
|
||||
return emb
|
||||
|
|
|
@ -26,6 +26,7 @@ class GlobalOptimManager(object):
|
|||
self.index2config = {}
|
||||
self.optimizer = None
|
||||
self.uses_config_override = False
|
||||
self.module_weight_config_triple = []
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
|
@ -77,12 +78,16 @@ class GlobalOptimManager(object):
|
|||
if id(p) in self.pid2config:self.pid2config[id(p)].update(key_value_dict)
|
||||
else: self.pid2config[id(p)] = key_value_dict
|
||||
|
||||
def register_module_override(self, module, param_name, config):
|
||||
self.module_weight_config_triple.append((module, param_name, config))
|
||||
|
||||
|
||||
|
||||
class Optimizer8bit(torch.optim.Optimizer):
|
||||
|
||||
def __init__(self, params, defaults, optim_bits=32):
|
||||
super(Optimizer8bit, self).__init__(params, defaults)
|
||||
self.checked_if_on_gpu = False
|
||||
self.initialized = False
|
||||
self.name2qmap = {}
|
||||
|
||||
self.mng = GlobalOptimManager.get_instance()
|
||||
|
@ -172,7 +177,6 @@ class Optimizer8bit(torch.optim.Optimizer):
|
|||
self.__setstate__({'state': state, 'param_groups': param_groups})
|
||||
|
||||
def to_gpu(self):
|
||||
self.checked_if_on_gpu = True
|
||||
for gindex, group in enumerate(self.param_groups):
|
||||
for pindex, p in enumerate(group['params']):
|
||||
if p in self.state:
|
||||
|
@ -181,6 +185,23 @@ class Optimizer8bit(torch.optim.Optimizer):
|
|||
if isinstance(v, torch.Tensor):
|
||||
self.state[p][k] = v.to(p.device)
|
||||
|
||||
def check_overrides(self):
|
||||
for module, attr, config in self.mng.module_weight_config_triple:
|
||||
pmodule = getattr(module, attr)
|
||||
assert pmodule is not None
|
||||
assert isinstance(pmodule, torch.Tensor) or isinstance(pmodule, torch.Parameter)
|
||||
found = False
|
||||
for gindex, group in enumerate(self.param_groups):
|
||||
if found: break
|
||||
for pindex, p in enumerate(group['params']):
|
||||
if found: break
|
||||
if id(p) == id(pmodule):
|
||||
# found the matching parameter
|
||||
# init override
|
||||
self.mng.pid2config[id(p)] = config
|
||||
self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[id(p)]
|
||||
found = True
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
|
@ -196,7 +217,11 @@ class Optimizer8bit(torch.optim.Optimizer):
|
|||
|
||||
overflows = []
|
||||
|
||||
if not self.checked_if_on_gpu: self.to_gpu() # needed for fairseq pure fp16 training
|
||||
if not self.initialized:
|
||||
self.check_overrides()
|
||||
self.to_gpu() # needed for fairseq pure fp16 training
|
||||
self.initialized = True
|
||||
|
||||
for gindex, group in enumerate(self.param_groups):
|
||||
for pindex, p in enumerate(group['params']):
|
||||
if p.grad is None:
|
||||
|
|
|
@ -6,3 +6,16 @@ If you are feeling lucky, you can also try to compile the library from source. T
|
|||
|
||||
|
||||
__If you encounter any other error not listed here please create an issue. This will help resolve your problem and will help out others in the future.
|
||||
|
||||
|
||||
# fatbinwrap
|
||||
|
||||
This error occurs if there is a mismatch between CUDA versions in the C++ library and the CUDA part. Make sure you have right CUDA in your $PATH and $LD_LIBRARY_PATH variable. In the conda base environment you can find the library under:
|
||||
```bash
|
||||
ls $CONDA_PREFIX/lib/*cudart*
|
||||
```
|
||||
Make sure this path is appended to the `LD_LIBRARY_PATH` so bnb can find the CUDA runtime environment library (cudart).
|
||||
|
||||
If this does not fix the issue, please try [compilation from source](compile_from_source.md) next.
|
||||
|
||||
If this does not work, please open an issue and paste the printed environment if you call `make` and the associated error when running bnb.
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
If you want to optimize some unstable parameters with 32-bit Adam and others with 8-bit Adam, you can use the `GlobalOptimManager`. With this, we can also configure specific hyperparameters for particular layers, such as embedding layers. To do that, we need two things: (1) register the parameter while they are still on the CPU, (2) override the config with the new desired hyperparameters (anytime, anywhere). See our [guide](howto_config_override.md) for more details
|
||||
|
||||
For global overrides in many different places in your code you can do:
|
||||
```python
|
||||
import torch
|
||||
import bitsandbytes as bnb
|
||||
|
@ -24,3 +25,16 @@ mng.override_config([model.special.weight, model.also_special.weight],
|
|||
key_value_dict ={'is_sparse': True, 'lr': 1e-5, 'betas'=(0.9, 0.98)})
|
||||
```
|
||||
Possible options for the config override are: `betas, eps, weight_decay, lr, optim_bits, min_8bit_size, percentile_clipping, block_wise, max_unorm`
|
||||
|
||||
For overrides for particular layers we recommend overriding locally in each module. You can do this by passing the module, the parameter, and its attribute name to the GlobalOptimManager:
|
||||
```python
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(din, dout):
|
||||
super(MyModule, self).__init__()
|
||||
self.linear = torch.nn.Linear(din, dout)
|
||||
# optimization will happen in 32-bit and
|
||||
# learning rate will be set to 0.0001 independent of the main learning rate
|
||||
config = {'optim_bits': 32, 'lr' : 0.0001}
|
||||
GlobalOptimManager.get_instance().register_module_override(self, 'weight', config)
|
||||
|
||||
```
|
||||
|
|
46
tests/test_modules.py
Normal file
46
tests/test_modules.py
Normal file
|
@ -0,0 +1,46 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import pytest
|
||||
import torch
|
||||
import bitsandbytes as bnb
|
||||
|
||||
from itertools import product
|
||||
|
||||
from bitsandbytes import functional as F
|
||||
|
||||
|
||||
@pytest.mark.parametrize("embcls", [bnb.nn.Embedding, bnb.nn.StableEmbedding], ids=['Embedding', 'StableEmbedding'])
|
||||
def test_embeddings(embcls):
|
||||
bnb.optim.GlobalOptimManager.get_instance().initialize()
|
||||
emb1 = torch.nn.Embedding(100, 512).cuda()
|
||||
emb2 = embcls(100, 512).cuda()
|
||||
|
||||
adam1 = bnb.optim.Adam8bit(emb1.parameters())
|
||||
adam2 = bnb.optim.Adam8bit(emb2.parameters())
|
||||
|
||||
batches = torch.randint(1, 100, size=(100, 4, 32)).cuda()
|
||||
|
||||
for i in range(100):
|
||||
batch = batches[i]
|
||||
|
||||
embedded1 = emb1(batch)
|
||||
embedded2 = emb2(batch)
|
||||
|
||||
l1 = embedded1.mean()
|
||||
l2 = embedded2.mean()
|
||||
|
||||
l1.backward()
|
||||
l2.backward()
|
||||
|
||||
adam1.step()
|
||||
adam2.step()
|
||||
|
||||
adam1.zero_grad()
|
||||
adam2.zero_grad()
|
||||
|
||||
assert adam1.state[emb1.weight]['state1'].dtype == torch.uint8
|
||||
assert adam2.state[emb2.weight]['state1'].dtype == torch.float32
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user