From 5b612bc6dfa131fb0cb27dcae5fd863c15694328 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Wed, 12 Apr 2023 12:16:55 -0700 Subject: [PATCH] Added is_available_triton guard to Triton SwitchBackLinear. --- bitsandbytes/nn/triton_based_modules.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/bitsandbytes/nn/triton_based_modules.py b/bitsandbytes/nn/triton_based_modules.py index 61e9053..7794fa0 100644 --- a/bitsandbytes/nn/triton_based_modules.py +++ b/bitsandbytes/nn/triton_based_modules.py @@ -3,6 +3,8 @@ import torch.nn as nn import time from functools import partial +from bitsandbytes.triton.triton_utils import is_triton_available + from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise from bitsandbytes.triton.quantize_rowwise import quantize_rowwise from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose @@ -160,6 +162,10 @@ class SwitchBackLinear(nn.Linear): ): super().__init__(in_features, out_features, bias, device, dtype) + if not is_triton_available: + raise ImportError('''Could not import triton. Please install triton to use SwitchBackLinear. + Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''') + # By default, we use the global quantization. self.vectorize = vectorize if self.vectorize: