Added is_available_triton guard to Triton SwitchBackLinear.
This commit is contained in:
parent
c3d87e4435
commit
5b612bc6df
|
@ -3,6 +3,8 @@ import torch.nn as nn
|
||||||
import time
|
import time
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
|
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||||
|
|
||||||
from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise
|
from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise
|
||||||
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
|
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
|
||||||
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
|
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
|
||||||
|
@ -160,6 +162,10 @@ class SwitchBackLinear(nn.Linear):
|
||||||
):
|
):
|
||||||
super().__init__(in_features, out_features, bias, device, dtype)
|
super().__init__(in_features, out_features, bias, device, dtype)
|
||||||
|
|
||||||
|
if not is_triton_available:
|
||||||
|
raise ImportError('''Could not import triton. Please install triton to use SwitchBackLinear.
|
||||||
|
Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''')
|
||||||
|
|
||||||
# By default, we use the global quantization.
|
# By default, we use the global quantization.
|
||||||
self.vectorize = vectorize
|
self.vectorize = vectorize
|
||||||
if self.vectorize:
|
if self.vectorize:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user