Merge pull request #292 from justheuristic/patch-2
Support nvidia16 GPUs
commit 72efa32962
@@ -221,6 +221,17 @@ bmm_cublas = MatMul8bit.apply
 matmul_cublas = MatMul8bit.apply
 
 
+def supports_igemmlt(device: torch.device) -> bool:
+    """check if this device supports the optimized int8 kernel"""
+    if torch.cuda.get_device_capability(device=device) < (7, 5):
+        return False
+    device_name = torch.cuda.get_device_name(device=device)
+    nvidia16_models = ('GTX 1630', 'GTX 1650', 'GTX 1660')  # https://en.wikipedia.org/wiki/GeForce_16_series
+    if any(model_name in device_name for model_name in nvidia16_models):
+        return False  # these devices are technically cuda 7.5-capable, but they lack tensor cores
+    return True
+
+
 @dataclass
 class MatmulLtState:
     tile_indices: Optional[torch.Tensor] = None
@@ -270,7 +281,7 @@ class MatMul8bitLt(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
-        using_igemmlt = torch.cuda.get_device_capability(device=A.device) >= (7, 5) and not state.force_no_igemmlt
+        using_igemmlt = supports_igemmlt(A.device) and not state.force_no_igemmlt
         # default of pytorch behavior if inputs are empty
         ctx.is_empty = False
        if prod(A.shape) == 0:
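
For illustration, a minimal sketch of how the new helper can be called at runtime; treat the import path and the device index 0 as assumptions, not something this patch prescribes:

import torch
from bitsandbytes.autograd._functions import supports_igemmlt  # assumed import path for the helper added above

device = torch.device("cuda", 0)  # assumes at least one visible CUDA device
if supports_igemmlt(device):
    print(f"{torch.cuda.get_device_name(device)}: optimized int8 igemmlt kernel available")
else:
    print(f"{torch.cuda.get_device_name(device)}: using the fallback int8 path")

On a GTX 1660, for example, torch.cuda.get_device_capability reports (7, 5), so the old capability-only check would have enabled igemmlt even though the card has no tensor cores; the name check added here is what routes such devices to the fallback.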
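
The name-based branch can also be exercised without a 16-series card by stubbing the two torch.cuda queries. This is a hypothetical test sketch, not part of the PR:

import torch
from unittest import mock
from bitsandbytes.autograd._functions import supports_igemmlt  # assumed import path, as above

# Pretend the current device is a GTX 1660 Ti: capability 7.5, but no tensor cores.
with mock.patch.object(torch.cuda, "get_device_capability", return_value=(7, 5)), \
     mock.patch.object(torch.cuda, "get_device_name", return_value="NVIDIA GeForce GTX 1660 Ti"):
    assert not supports_igemmlt(torch.device("cuda", 0))

# A pre-Turing card is rejected by the capability check alone, before the name is ever queried.
with mock.patch.object(torch.cuda, "get_device_capability", return_value=(6, 1)):
    assert not supports_igemmlt(torch.device("cuda", 0))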