diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 376fb8a..4db9a92 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -221,6 +221,17 @@ bmm_cublas = MatMul8bit.apply
 matmul_cublas = MatMul8bit.apply
 
 
+def supports_igemmlt(device: torch.device) -> bool:
+    """check if this device supports the optimized int8 kernel"""
+    if torch.cuda.get_device_capability(device=device) < (7, 5):
+        return False
+    device_name = torch.cuda.get_device_name(device=device)
+    nvidia16_models = ('GTX 1630', 'GTX 1650', 'GTX 1660')  # https://en.wikipedia.org/wiki/GeForce_16_series
+    if any(model_name in device_name for model_name in nvidia16_models):
+        return False  # these devices are technically cuda 7.5-capable, but they lack tensor cores
+    return True
+
+
 @dataclass
 class MatmulLtState:
     tile_indices: Optional[torch.Tensor] = None
@@ -270,7 +281,7 @@ class MatMul8bitLt(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
-        using_igemmlt = torch.cuda.get_device_capability(device=A.device) >= (7, 5) and not state.force_no_igemmlt
+        using_igemmlt = supports_igemmlt(A.device) and not state.force_no_igemmlt
         # default of pytorch behavior if inputs are empty
         ctx.is_empty = False
         if prod(A.shape) == 0:
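
For a quick local sanity check, something along these lines should show which path a given GPU would take once the patch is applied. This is only a sketch: `supports_igemmlt` is the helper added in the diff above, and `cuda:0` is just an example device index.

```python
import torch
from bitsandbytes.autograd._functions import supports_igemmlt  # added by this diff

if torch.cuda.is_available():
    device = torch.device("cuda:0")  # example device; adjust as needed
    name = torch.cuda.get_device_name(device)
    cc = torch.cuda.get_device_capability(device)
    if supports_igemmlt(device):
        print(f"{name} (sm_{cc[0]}{cc[1]}): optimized int8 igemmlt kernel will be used")
    else:
        print(f"{name} (sm_{cc[0]}{cc[1]}): falling back to the non-igemmlt path")
else:
    print("No CUDA device available")
```

The point of the change is that GeForce 16-series cards report compute capability 7.5 but ship without tensor cores, so the previous capability-only check let them select the igemmlt kernel they cannot actually run; the new helper sends them down the fallback path instead.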