Added is_available_triton guard.

This commit is contained in:
Tim Dettmers 2023-04-12 12:10:34 -07:00
parent 7140c01405
commit c3d87e4435
7 changed files with 572 additions and 536 deletions

View File

@@ -184,7 +184,7 @@ class MatMulFP8Global(torch.autograd.Function):
 return grad_A, grad_B, None, None, None, None, None

-class MatMul8bitMixed(torch.autograd.Function):
+class SwitchBackBnb(torch.autograd.Function):
 @staticmethod
 def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
 # default to pytorch behavior if inputs are empty
@@ -408,4 +408,4 @@ def switchback_bnb(
 state = state or MatmulLtState()
 if threshold > 0.0:
 state.threshold = threshold
-return MatMul8bitMixed.apply(A, B, out, bias, state)
+return SwitchBackBnb.apply(A, B, out, bias, state)

View File

@@ -1,6 +1,12 @@
 import math
 import torch
 import time
+from bitsandbytes.triton.triton_utils import is_triton_available
+
+if not is_triton_available():
+    def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): return None
+else:
 import triton
 import triton.language as tl
 from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

View File

@@ -1,4 +1,9 @@
 import torch
+from bitsandbytes.triton.triton_utils import is_triton_available
+
+if not is_triton_available():
+    def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias): return None
+else:
 import triton
 import triton.language as tl

View File

@@ -1,5 +1,10 @@
 import torch
+from bitsandbytes.triton.triton_utils import is_triton_available
+
+if not is_triton_available():
+    def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): return None
+else:
 import triton
 import triton.language as tl
 from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

View File

@@ -1,6 +1,12 @@
 import math
 import torch
 import time
+from bitsandbytes.triton.triton_utils import is_triton_available
+
+if not is_triton_available():
+    def quantize_columnwise_and_transpose(x: torch.Tensor): return None
+else:
 import triton
 import triton.language as tl
 from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

View File

@@ -1,6 +1,13 @@
 import math
 import torch
 import time
+from bitsandbytes.triton.triton_utils import is_triton_available
+
+if not is_triton_available():
+    def quantize_global_transpose(input): return None
+    def quantize_global(x: torch.Tensor): return None
+else:
 import triton
 import triton.language as tl
 from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

View File

@@ -1,6 +1,13 @@
 import math
 import torch
 import time
+from bitsandbytes.triton.triton_utils import is_triton_available
+
+if not is_triton_available():
+    def quantize_rowwise(x: torch.Tensor): return None
+else:
 import triton
 import triton.language as tl
 from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time