Added is_available_triton guard.
parent 7140c01405
commit c3d87e4435
@@ -184,7 +184,7 @@ class MatMulFP8Global(torch.autograd.Function):
         return grad_A, grad_B, None, None, None, None, None


-class MatMul8bitMixed(torch.autograd.Function):
+class SwitchBackBnb(torch.autograd.Function):
     @staticmethod
     def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
         # default to pytorch behavior if inputs are empty

@@ -408,4 +408,4 @@ def switchback_bnb(
     state = state or MatmulLtState()
     if threshold > 0.0:
         state.threshold = threshold
-    return MatMul8bitMixed.apply(A, B, out, bias, state)
+    return SwitchBackBnb.apply(A, B, out, bias, state)
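
Note: the two hunks above only rename MatMul8bitMixed to SwitchBackBnb and repoint the switchback_bnb dispatcher; no alias for the old class name is added anywhere in this commit. If out-of-tree code imports the old name directly, a shim along these lines (hypothetical, not part of this change) would keep it importable:

    # Hypothetical backward-compatibility shim, NOT included in this commit.
    # Assumes the autograd functions live in bitsandbytes.autograd._functions
    # (the module path is not shown in this diff).
    from bitsandbytes.autograd._functions import SwitchBackBnb

    MatMul8bitMixed = SwitchBackBnb  # old name resolves to the renamed autograd.Function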

@@ -1,6 +1,12 @@
 import math
 import torch
 import time
-import triton
-import triton.language as tl
-from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
+from bitsandbytes.triton.triton_utils import is_triton_available
+
+if not is_triton_available():
+    def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): return None
+else:
+
+    import triton
+    import triton.language as tl
+    from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
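
Each of the remaining hunks applies the same pattern to one of the bitsandbytes.triton modules: import is_triton_available, define a None-returning stub when Triton is missing, and move the Triton imports under the else branch so the module can always be imported. The guard itself lives in bitsandbytes.triton.triton_utils, which is only imported here; a minimal sketch of what such a helper could look like, assuming a simple find_spec probe (an assumption, since the file's contents are not part of this diff):

    # Assumed sketch of a triton_utils helper; the real file is not shown in this diff.
    import importlib.util

    def is_triton_available() -> bool:
        # Probe for the `triton` package without importing it, so that importing
        # the guarded bitsandbytes.triton modules never raises ImportError.
        return importlib.util.find_spec("triton") is not None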

@@ -1,4 +1,9 @@
 import torch
+from bitsandbytes.triton.triton_utils import is_triton_available

-import triton
-import triton.language as tl
+if not is_triton_available():
+    def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias): return None
+else:
+
+    import triton
+    import triton.language as tl

@@ -1,5 +1,10 @@
 import torch

-import triton
-import triton.language as tl
-from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
+from bitsandbytes.triton.triton_utils import is_triton_available
+
+if not is_triton_available():
+    def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): return None
+else:
+    import triton
+    import triton.language as tl
+    from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
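
Because the stub keeps the function's five-argument signature and just returns None, callers and tests can branch on the same predicate instead of wrapping the import in try/except. A hedged pytest-style sketch (the module path is inferred from the function name and is an assumption):

    # Hypothetical test of the fallback path; the module path is an assumption.
    import pytest

    from bitsandbytes.triton.triton_utils import is_triton_available
    from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize

    @pytest.mark.skipif(is_triton_available(), reason="covers the no-Triton stub only")
    def test_stub_returns_none_without_triton():
        # With Triton absent, the guard above defines a stub that accepts the same
        # arguments and returns None instead of failing at import time.
        assert int8_matmul_rowwise_dequantize(None, None, None, None, None) is None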

@@ -1,6 +1,12 @@
 import math
 import torch
 import time
-import triton
-import triton.language as tl
-from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
+from bitsandbytes.triton.triton_utils import is_triton_available
+
+if not is_triton_available():
+    def quantize_columnwise_and_transpose(x: torch.Tensor): return None
+else:
+
+    import triton
+    import triton.language as tl
+    from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

@@ -1,6 +1,13 @@
 import math
 import torch
 import time
-import triton
-import triton.language as tl
-from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
+from bitsandbytes.triton.triton_utils import is_triton_available
+
+if not is_triton_available():
+    def quantize_global_transpose(input): return None
+    def quantize_global(x: torch.Tensor): return None
+else:
+
+    import triton
+    import triton.language as tl
+    from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

@@ -1,6 +1,13 @@
 import math
 import torch
 import time
-import triton
-import triton.language as tl
-from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
+
+from bitsandbytes.triton.triton_utils import is_triton_available
+
+if not is_triton_available():
+    def quantize_rowwise(x: torch.Tensor): return None
+else:
+
+    import triton
+    import triton.language as tl
+    from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
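
Taken together, the guards make every bitsandbytes.triton module importable without Triton, with the public wrappers degrading to no-ops. A small caller-side sketch under that assumption (module path inferred from the function name; the Triton path's device and return conventions are not shown in this diff):

    # Hedged usage sketch; only the stub behavior is guaranteed by the hunks above.
    import torch

    from bitsandbytes.triton.triton_utils import is_triton_available
    from bitsandbytes.triton.quantize_rowwise import quantize_rowwise  # assumed module path

    x = torch.randn(4, 8, dtype=torch.float16)

    if not is_triton_available():
        # The guard replaced quantize_rowwise with a stub: importing works and
        # calling it is a harmless no-op that returns None.
        assert quantize_rowwise(x) is None
    else:
        # With Triton installed the same name is the real kernel wrapper, which
        # presumably expects CUDA tensors since Triton kernels run on the GPU.
        out = quantize_rowwise(x.cuda())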