forked from mrq/bitsandbytes-rocm
commit 4e4668ab09
@@ -23,12 +23,12 @@ Resources:
 1. Comment out torch.nn.Linear: ``#linear = torch.nn.Linear(...)``
 2. Add bnb 8-bit linear light module: ``linear = bnb.nn.Linear8bitLt(...)`` (base arguments stay the same)
 3. There are two modes:
-   - Mixed 8-bit training with 16-bit main weights. Pass the argument ``use_fp16_weights=True`` (default)
-   - Int8 inference. Pass the argument ``use_fp16_weights=False``
+   - Mixed 8-bit training with 16-bit main weights. Pass the argument ``has_fp16_weights=True`` (default)
+   - Int8 inference. Pass the argument ``has_fp16_weights=False``
 4. To use the full LLM.int8() method, use the ``threshold=k`` argument. We recommend ``k=6.0``.
 ```python
 # LLM.int8()
-linear = bnb.nn.Linear8bitLt(dim1, dim2, bias=True, use_fp16_weights=False, threshold=6.0)
+linear = bnb.nn.Linear8bitLt(dim1, dim2, bias=True, has_fp16_weights=False, threshold=6.0)
 # inputs need to be fp16
 out = linear(x.to(torch.float16))
 ```
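To make the ``threshold=k`` argument from step 4 more concrete, here is a small pure-Python sketch of the general idea: input features whose magnitude exceeds the threshold are multiplied in floating point, while the remaining features go through an absmax int8 quantize, multiply, and dequantize path. This is a simplified illustration, not the bitsandbytes implementation. The function names ``absmax_quantize`` and ``mixed_int8_matvec`` are hypothetical, and the strict ``>`` comparison against the threshold is an assumption made for the example.

```python
# Simplified illustration only: NOT bitsandbytes code.
# All function names below are hypothetical.

def absmax_quantize(values):
    """Scale a list of floats into the int8 range [-127, 127].

    Returns the rounded integers and the scale needed to dequantize.
    """
    absmax = max((abs(v) for v in values), default=0.0)
    scale = absmax / 127.0 if absmax > 0 else 1.0
    return [round(v / scale) for v in values], scale


def mixed_int8_matvec(x, weight, threshold=6.0):
    """Compute y = x @ weight, keeping outlier features in floating point.

    x:      list of in_features floats
    weight: list of in_features rows, each with out_features floats
    """
    # Assumption: a feature is an "outlier" if its magnitude exceeds threshold.
    outlier = [i for i, v in enumerate(x) if abs(v) > threshold]
    regular = [i for i, v in enumerate(x) if abs(v) <= threshold]

    # Quantize the non-outlier part of the input once.
    xq, x_scale = absmax_quantize([x[i] for i in regular])

    out = []
    for j in range(len(weight[0])):
        # Quantize the matching weights for this output feature.
        wq, w_scale = absmax_quantize([weight[i][j] for i in regular])
        acc = sum(a * b for a, b in zip(xq, wq))  # integer accumulation
        y = acc * x_scale * w_scale               # dequantize
        # Outlier features bypass quantization entirely.
        y += sum(x[i] * weight[i][j] for i in outlier)
        out.append(y)
    return out


x = [0.5, -1.0, 8.0, 0.25]  # the third feature exceeds the 6.0 threshold
w = [[1.0, -0.5], [0.5, 0.25], [2.0, 1.0], [-1.0, 0.75]]
print(mixed_int8_matvec(x, w, threshold=6.0))
```

The printed values should land close to the exact product, because the large feature never passes through the 8-bit path and only the small features pick up rounding error.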