Added is_available_triton guard to Triton SwitchBackLinear.
This commit is contained in:
parent
c3d87e4435
commit
5b612bc6df
|
@ -3,6 +3,8 @@ import torch.nn as nn
|
|||
import time
|
||||
from functools import partial
|
||||
|
||||
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||
|
||||
from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise
|
||||
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
|
||||
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
|
||||
|
@ -160,6 +162,10 @@ class SwitchBackLinear(nn.Linear):
|
|||
):
|
||||
super().__init__(in_features, out_features, bias, device, dtype)
|
||||
|
||||
if not is_triton_available:
|
||||
raise ImportError('''Could not import triton. Please install triton to use SwitchBackLinear.
|
||||
Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''')
|
||||
|
||||
# By default, we use the global quantization.
|
||||
self.vectorize = vectorize
|
||||
if self.vectorize:
|
||||
|
|
Loading…
Reference in New Issue
Block a user