Extract get_tile_inds to a separate function

This commit is contained in:
Max Ryabinin 2023-06-09 21:39:37 +02:00
parent ac5550a023
commit 4fb37d45c1

View File

@ -232,6 +232,19 @@ def supports_igemmlt(device: torch.device) -> bool:
return True
def _get_tile_size(format):
assert format in (
"col_turing",
"col_ampere",
), f"please find this assert and manually enter tile size for {format}"
return (8, 32) if format == "col_turing" else (32, 32)
def get_tile_inds(format, device):
transform = lambda x: F.transform(x.to(device), from_order="row", to_order=format)[0].to(x.device)
with torch.no_grad():
return get_inverse_transform_indices(transform, _get_tile_size(format)).to(device)
@dataclass
class MatmulLtState:
_tile_indices: Optional[torch.Tensor] = None
@ -267,20 +280,10 @@ class MatmulLtState:
self.SBt = None
self.CBt = None
def get_tile_size(self):
assert self.formatB in (
"col_turing",
"col_ampere",
), f"please find this assert and manually enter tile size for {self.formatB}"
return (8, 32) if self.formatB == "col_turing" else (32, 32)
@property
def tile_indices(self):
if self._tile_indices is None:
device = self.CxB.device
transform = lambda x: F.transform(x.to(device), from_order="row", to_order=self.formatB)[0].to(x.device)
with torch.no_grad():
self._tile_indices = get_inverse_transform_indices(transform, self.get_tile_size()).to(device)
self._tile_indices = get_tile_inds(self.formatB, self.CxB.device)
return self._tile_indices