Refactor _tile_indices into a cached property, fix device bug
parent cc608c04c2
commit d15822a54b
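In short, `tile_indices` moves from a plain field on `MatmulLtState` to a lazily computed, cached property backed by the private `_tile_indices` field, and the inverse-transform indices are now built on the device of `CxB` rather than whatever a bare `.cuda()` call happens to target. Below is a minimal sketch of that cached-property pattern, using simplified, hypothetical names (`TileState`, and a placeholder `torch.arange` standing in for the real index computation); it is not the library's actual code.

    # Minimal sketch of a lazily computed, cached tensor property
    # (hypothetical names; arange is a stand-in for the real index computation).
    from typing import Optional

    import torch


    class TileState:
        def __init__(self, CxB: torch.Tensor):
            self.CxB = CxB
            self._tile_indices: Optional[torch.Tensor] = None  # private cache slot

        @property
        def tile_indices(self) -> torch.Tensor:
            if self._tile_indices is None:
                # Build the indices on the device CxB actually lives on,
                # instead of assuming the default CUDA device.
                device = self.CxB.device
                self._tile_indices = torch.arange(self.CxB.numel(), device=device)
            return self._tile_indices


    state = TileState(torch.zeros(2, 3))
    print(state.tile_indices is state.tile_indices)  # True: computed once, then reused

With the property in place, callers simply read `state.tile_indices`; the hand-rolled `if ... is None:` initialization blocks in the two later hunks become unnecessary and are deleted.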
@@ -223,7 +223,7 @@ matmul_cublas = MatMul8bit.apply
 
 @dataclass
 class MatmulLtState:
-    tile_indices: Optional[torch.Tensor] = None
+    _tile_indices: Optional[torch.Tensor] = None
     force_no_igemmlt: bool = False
     CB = None
     CxB = None
@@ -263,6 +263,15 @@ class MatmulLtState:
         ), f"please find this assert and manually enter tile size for {self.formatB}"
         return (8, 32) if self.formatB == "col_turing" else (32, 32)
 
+    @property
+    def tile_indices(self):
+        if self._tile_indices is None:
+            device = self.CxB.device
+            transform = lambda x: F.transform(x.to(device), from_order="row", to_order=self.formatB)[0].to(x.device)
+            with torch.no_grad():
+                self._tile_indices = get_inverse_transform_indices(transform, self.get_tile_size()).to(device)
+        return self._tile_indices
+
 
 class MatMul8bitLt(torch.autograd.Function):
     # forward is the same, but we added the fallback for pre-turing GPUs
@@ -455,13 +464,6 @@ class MatMul8bitLt(torch.autograd.Function):
             CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
             grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
         elif state.CxB is not None:
-
-            if state.tile_indices is None:
-                order, tile_size = state.formatB, state.get_tile_size()
-                transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device)
-                with torch.no_grad():
-                    state.tile_indices = get_inverse_transform_indices(transform, tile_size).to(state.CxB.device)
-
             CB = (
                 undo_layout(state.CxB, state.tile_indices)
                 .to(ctx.dtype_A)
@@ -236,20 +236,7 @@ class Linear8bitLt(nn.Linear):
 
         try:
             if reorder_layout:
-                if self.state.tile_indices is None:
-                    order, tile_size = self.state.formatB, self.state.get_tile_size()
-                    transform = lambda x: \
-                        bitsandbytes.functional.transform(x.to(self.weight.data.device), from_order="row",
-                                                          to_order=order)[0].to(x.device)
-                    with torch.no_grad():
-                        self.state.tile_indices = get_inverse_transform_indices(transform, tile_size).to(
-                            self.state.CxB.device)
-
-                CB = (
-                    undo_layout(self.state.CxB, self.state.tile_indices)
-                )
-
-                self.weight.data = CB
+                self.weight.data = undo_layout(self.state.CxB, self.state.tile_indices)
 
             super()._save_to_state_dict(destination, prefix, keep_vars)
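
On the "fix device bug" part of the title: the old backward path built its transform with `x.cuda()`, which copies to the current (default) CUDA device rather than the device `CxB` lives on, whereas the new property uses `self.CxB.device` throughout. A small hypothetical repro of that pitfall, assuming a machine with at least two GPUs (illustrative only, not library code):

    import torch

    x = torch.randn(4, 4, device="cuda:1")

    y = x.cuda()        # no device given: copies to the current CUDA device (cuda:0 unless changed)
    z = x.to(x.device)  # stays on cuda:1, the device the tensor already lives on

    print(y.device, z.device)  # cuda:0 cuda:1 -> mixing y with cuda:1 tensors raises a device-mismatch error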