Added is_triton_available guard.
This commit is contained in:
parent
7140c01405
commit
c3d87e4435
|
@ -184,7 +184,7 @@ class MatMulFP8Global(torch.autograd.Function):
|
|||
return grad_A, grad_B, None, None, None, None, None
|
||||
|
||||
|
||||
class MatMul8bitMixed(torch.autograd.Function):
|
||||
class SwitchBackBnb(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
|
||||
# default to pytorch behavior if inputs are empty
|
||||
|
@ -408,4 +408,4 @@ def switchback_bnb(
|
|||
state = state or MatmulLtState()
|
||||
if threshold > 0.0:
|
||||
state.threshold = threshold
|
||||
return MatMul8bitMixed.apply(A, B, out, bias, state)
|
||||
return SwitchBackBnb.apply(A, B, out, bias, state)
|
||||
|
|
|
@ -1,6 +1,12 @@
|
|||
import math
|
||||
import torch
|
||||
import time
|
||||
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||
|
||||
if not is_triton_available():
|
||||
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
    """Placeholder defined when ``is_triton_available()`` is false.

    Both arguments are accepted for signature compatibility and ignored;
    the call always evaluates to ``None``.
    """
    return None
|
||||
else:
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
|
|
@ -1,4 +1,9 @@
|
|||
import torch
|
||||
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||
|
||||
if not is_triton_available():
|
||||
def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias):
    """Placeholder defined when ``is_triton_available()`` is false.

    All five arguments are accepted for signature compatibility and
    ignored; the call always evaluates to ``None``.
    """
    return None
|
||||
else:
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
|
|
@ -1,5 +1,10 @@
|
|||
import torch
|
||||
|
||||
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||
|
||||
if not is_triton_available():
|
||||
def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
    """Placeholder defined when ``is_triton_available()`` is false.

    All five arguments are accepted for signature compatibility and
    ignored; the call always evaluates to ``None``.
    """
    return None
|
||||
else:
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
|
|
@ -1,6 +1,12 @@
|
|||
import math
|
||||
import torch
|
||||
import time
|
||||
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||
|
||||
if not is_triton_available():
|
||||
def quantize_columnwise_and_transpose(x: torch.Tensor):
    """Placeholder defined when ``is_triton_available()`` is false.

    The argument is accepted for signature compatibility and ignored;
    the call always evaluates to ``None``.
    """
    return None
|
||||
else:
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
|
|
@ -1,6 +1,13 @@
|
|||
import math
|
||||
import torch
|
||||
import time
|
||||
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||
|
||||
if not is_triton_available():
|
||||
def quantize_global_transpose(input):
    """Placeholder defined when ``is_triton_available()`` is false.

    The argument is accepted for signature compatibility and ignored;
    the call always evaluates to ``None``.
    """
    return None
|
||||
def quantize_global(x: torch.Tensor):
    """Placeholder defined when ``is_triton_available()`` is false.

    The argument is accepted for signature compatibility and ignored;
    the call always evaluates to ``None``.
    """
    return None
|
||||
else:
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
|
|
@ -1,6 +1,13 @@
|
|||
import math
|
||||
import torch
|
||||
import time
|
||||
|
||||
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||
|
||||
if not is_triton_available():
|
||||
def quantize_rowwise(x: torch.Tensor):
    """Placeholder defined when ``is_triton_available()`` is false.

    The argument is accepted for signature compatibility and ignored;
    the call always evaluates to ``None``.
    """
    return None
|
||||
else:
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
|
Loading…
Reference in New Issue
Block a user