Also disable igemmlt for AMD GPUs

This commit is contained in:
0cc4m 2023-02-16 22:18:52 +01:00
parent 403557388d
commit aa49b0a6cd

View File

@ -270,7 +270,8 @@ class MatMul8bitLt(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
using_igemmlt = torch.cuda.get_device_capability(device=A.device) >= (7, 5) and not state.force_no_igemmlt # Also disable igemmlt for AMD
using_igemmlt = "AMD" not in torch.cuda.get_device_name() and torch.cuda.get_device_capability(device=A.device) >= (7, 5) and not state.force_no_igemmlt
# default of pytorch behavior if inputs are empty # default of pytorch behavior if inputs are empty
ctx.is_empty = False ctx.is_empty = False
if prod(A.shape) == 0: if prod(A.shape) == 0: