Refactor _tile_indices into a cached property, fix device bug
parent cc608c04c2
commit d15822a54b
@@ -223,7 +223,7 @@ matmul_cublas = MatMul8bit.apply
 
 @dataclass
 class MatmulLtState:
-    tile_indices: Optional[torch.Tensor] = None
+    _tile_indices: Optional[torch.Tensor] = None
     force_no_igemmlt: bool = False
     CB = None
     CxB = None
@@ -263,6 +263,15 @@ class MatmulLtState:
         ), f"please find this assert and manually enter tile size for {self.formatB}"
         return (8, 32) if self.formatB == "col_turing" else (32, 32)
 
+    @property
+    def tile_indices(self):
+        if self._tile_indices is None:
+            device = self.CxB.device
+            transform = lambda x: F.transform(x.to(device), from_order="row", to_order=self.formatB)[0].to(x.device)
+            with torch.no_grad():
+                self._tile_indices = get_inverse_transform_indices(transform, self.get_tile_size()).to(device)
+        return self._tile_indices
+
 
 class MatMul8bitLt(torch.autograd.Function):
     # forward is the same, but we added the fallback for pre-turing GPUs
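The property above is what turns the indices into a cached, lazily computed attribute: the first access builds them on `self.CxB`'s device, every later access returns the stored tensor. A minimal standalone sketch of the same pattern, using made-up names (`LazyState`, `buffer`, `indices`) rather than the bitsandbytes classes:

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class LazyState:
    buffer: Optional[torch.Tensor] = None    # stands in for CxB
    _indices: Optional[torch.Tensor] = None  # backing field for the cached property

    @property
    def indices(self) -> torch.Tensor:
        # Build the indices once, on the device that already holds the buffer,
        # then return the cached tensor on every later access.
        if self._indices is None:
            with torch.no_grad():
                self._indices = torch.argsort(self.buffer.flatten()).to(self.buffer.device)
        return self._indices


state = LazyState(buffer=torch.randn(4, 4))
assert state.indices is state.indices  # second access hits the cache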
@@ -455,13 +464,6 @@ class MatMul8bitLt(torch.autograd.Function):
                 CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
                 grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
             elif state.CxB is not None:
-
-                if state.tile_indices is None:
-                    order, tile_size = state.formatB, state.get_tile_size()
-                    transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device)
-                    with torch.no_grad():
-                        state.tile_indices = get_inverse_transform_indices(transform, tile_size).to(state.CxB.device)
-
                 CB = (
                     undo_layout(state.CxB, state.tile_indices)
                     .to(ctx.dtype_A)
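The deleted block above also appears to be where the device bug from the commit title lived: the probe tensor was moved with `x.cuda()`, which copies to the current CUDA device, while the resulting indices were then sent to `state.CxB.device`; the property instead keeps the whole computation on the device that holds `CxB`. A small illustration of that difference, assuming a machine with at least two GPUs (`weights` and `probe` are made-up stand-ins):

import torch

if torch.cuda.device_count() >= 2:
    weights = torch.zeros(4, 4, device="cuda:1")  # stand-in for state.CxB
    probe = torch.zeros(4, 4)                     # CPU tensor handed to the transform

    # Old inline code: .cuda() with no argument copies to the current device.
    print(probe.cuda().device)                    # cuda:0
    # New property: move to the device that already holds the weights.
    print(probe.to(weights.device).device)        # cuda:1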
@@ -236,20 +236,7 @@ class Linear8bitLt(nn.Linear):
 
         try:
             if reorder_layout:
-                if self.state.tile_indices is None:
-                    order, tile_size = self.state.formatB, self.state.get_tile_size()
-                    transform = lambda x: \
-                        bitsandbytes.functional.transform(x.to(self.weight.data.device), from_order="row",
-                                                          to_order=order)[0].to(x.device)
-                    with torch.no_grad():
-                        self.state.tile_indices = get_inverse_transform_indices(transform, tile_size).to(
-                            self.state.CxB.device)
-
-                CB = (
-                    undo_layout(self.state.CxB, self.state.tile_indices)
-                )
-
-                self.weight.data = CB
+                self.weight.data = undo_layout(self.state.CxB, self.state.tile_indices)
 
             super()._save_to_state_dict(destination, prefix, keep_vars)
 
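With the property handling index creation, the save hook shrinks to a single `undo_layout` call. For orientation, here is a generic sketch of the surrounding pattern, temporarily swapping a converted copy of the weight in while `super()._save_to_state_dict()` runs and restoring the packed form afterwards, written against plain `torch.nn` with a placeholder conversion (the restore step is assumed from the `try:` visible in the hunk; it is not part of this diff):

import torch
import torch.nn as nn


class PackedLinear(nn.Linear):
    # Toy module standing in for Linear8bitLt: the weight is kept in a "packed"
    # layout and only converted for serialization.

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        packed = self.weight.data
        try:
            # Placeholder for undo_layout(self.state.CxB, self.state.tile_indices).
            self.weight.data = packed.clone()
            super()._save_to_state_dict(destination, prefix, keep_vars)
        finally:
            self.weight.data = packed  # always restore the packed layout


sd = PackedLinear(8, 4).state_dict()  # state_dict() routes through the hook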