Added is_triton_available guard to Triton SwitchBackLinear.

Tim Dettmers 2023-04-12 12:16:55 -07:00
parent c3d87e4435
commit 5b612bc6df


@@ -3,6 +3,8 @@ import torch.nn as nn
import time
from functools import partial
from bitsandbytes.triton.triton_utils import is_triton_available
from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
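
For reference, is_triton_available is the helper imported above from bitsandbytes.triton.triton_utils; its body is not part of this diff. A minimal sketch, assuming the helper simply tests whether the triton package can be imported:

import importlib.util

def is_triton_available() -> bool:
    # True only when the `triton` package can be found in this environment.
    # Sketch only; the actual helper in bitsandbytes.triton.triton_utils may differ.
    return importlib.util.find_spec("triton") is not None
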
@@ -160,6 +162,10 @@ class SwitchBackLinear(nn.Linear):
    ):
        super().__init__(in_features, out_features, bias, device, dtype)
        if not is_triton_available():
            raise ImportError('''Could not import triton. Please install triton to use SwitchBackLinear.
                               Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''')
        # By default, we use the global quantization.
        self.vectorize = vectorize
        if self.vectorize:
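
A minimal usage sketch of the fallback suggested by the ImportError message above. The helper make_switchback_linear is hypothetical, and the constructor arguments are assumed from the nn.Linear-style signature shown in this diff, not a definitive API:

import bitsandbytes as bnb
from bitsandbytes.triton.triton_utils import is_triton_available

def make_switchback_linear(in_features: int, out_features: int, bias: bool = True):
    # Prefer the Triton-backed SwitchBackLinear when triton is installed;
    # otherwise fall back to the slower bnb.nn.SwitchBackLinearBnb, as the
    # guard's error message suggests. (Hypothetical helper for illustration.)
    if is_triton_available():
        return bnb.nn.SwitchBackLinear(in_features, out_features, bias=bias)
    return bnb.nn.SwitchBackLinearBnb(in_features, out_features, bias=bias)
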