diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 376fb8a..f2c62da 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -270,7 +270,8 @@ class MatMul8bitLt(torch.autograd.Function): @staticmethod def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): - using_igemmlt = torch.cuda.get_device_capability(device=A.device) >= (7, 5) and not state.force_no_igemmlt + # Also disable igemmlt for AMD + using_igemmlt = "AMD" not in torch.cuda.get_device_name() and torch.cuda.get_device_capability(device=A.device) >= (7, 5) and not state.force_no_igemmlt # default of pytorch behavior if inputs are empty ctx.is_empty = False if prod(A.shape) == 0: