forked from mrq/bitsandbytes-rocm
Also disable igemmlt for AMD GPUs
This commit is contained in:
parent
403557388d
commit
aa49b0a6cd
@@ -270,7 +270,8 @@ class MatMul8bitLt(torch.autograd.Function):
     @staticmethod
     def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
-        using_igemmlt = torch.cuda.get_device_capability(device=A.device) >= (7, 5) and not state.force_no_igemmlt
+        # Also disable igemmlt for AMD
+        using_igemmlt = "AMD" not in torch.cuda.get_device_name() and torch.cuda.get_device_capability(device=A.device) >= (7, 5) and not state.force_no_igemmlt
         # default of pytorch behavior if inputs are empty
         ctx.is_empty = False
         if prod(A.shape) == 0:
Loading…
Reference in New Issue
Block a user