Refactor _tile_indices into a cached property, fix device bug

This commit is contained in:
Max Ryabinin 2023-02-25 06:23:07 +01:00
parent cc608c04c2
commit d15822a54b
2 changed files with 11 additions and 22 deletions

View File

@ -223,7 +223,7 @@ matmul_cublas = MatMul8bit.apply
@dataclass
class MatmulLtState:
tile_indices: Optional[torch.Tensor] = None
_tile_indices: Optional[torch.Tensor] = None
force_no_igemmlt: bool = False
CB = None
CxB = None
@ -263,6 +263,15 @@ class MatmulLtState:
), f"please find this assert and manually enter tile size for {self.formatB}"
return (8, 32) if self.formatB == "col_turing" else (32, 32)
@property
def tile_indices(self):
    """Return the tile indices for this state, building them on first access.

    The value is stored in ``self._tile_indices`` and handed back unchanged
    on every later access. It is produced by ``get_inverse_transform_indices``
    from a row -> ``self.formatB`` layout transform and the tile size reported
    by ``self.get_tile_size()``, then moved to the device of ``self.CxB``.

    Note: the first access reads ``self.CxB.device``, so ``CxB`` has to be
    set before this property is used.
    """
    if self._tile_indices is not None:
        return self._tile_indices

    target_device = self.CxB.device

    def to_tiled_layout(probe):
        # Run the layout transform on the same device as CxB, then give the
        # result back on whatever device the probe tensor came from.
        tiled = F.transform(probe.to(target_device), from_order="row", to_order=self.formatB)[0]
        return tiled.to(probe.device)

    # Building the index tensor is done with gradient tracking disabled.
    with torch.no_grad():
        indices = get_inverse_transform_indices(to_tiled_layout, self.get_tile_size())
        self._tile_indices = indices.to(target_device)
    return self._tile_indices
class MatMul8bitLt(torch.autograd.Function):
# forward is the same, but we added the fallback for pre-turing GPUs
@ -455,13 +464,6 @@ class MatMul8bitLt(torch.autograd.Function):
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
elif state.CxB is not None:
if state.tile_indices is None:
order, tile_size = state.formatB, state.get_tile_size()
transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device)
with torch.no_grad():
state.tile_indices = get_inverse_transform_indices(transform, tile_size).to(state.CxB.device)
CB = (
undo_layout(state.CxB, state.tile_indices)
.to(ctx.dtype_A)

View File

@ -236,20 +236,7 @@ class Linear8bitLt(nn.Linear):
try:
if reorder_layout:
if self.state.tile_indices is None:
order, tile_size = self.state.formatB, self.state.get_tile_size()
transform = lambda x: \
bitsandbytes.functional.transform(x.to(self.weight.data.device), from_order="row",
to_order=order)[0].to(x.device)
with torch.no_grad():
self.state.tile_indices = get_inverse_transform_indices(transform, tile_size).to(
self.state.CxB.device)
CB = (
undo_layout(self.state.CxB, self.state.tile_indices)
)
self.weight.data = CB
self.weight.data = undo_layout(self.state.CxB, self.state.tile_indices)
super()._save_to_state_dict(destination, prefix, keep_vars)