Extract get_tile_inds to a separate function
This commit is contained in:
parent
ac5550a023
commit
4fb37d45c1
|
@ -232,6 +232,19 @@ def supports_igemmlt(device: torch.device) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _get_tile_size(format):
|
||||||
|
assert format in (
|
||||||
|
"col_turing",
|
||||||
|
"col_ampere",
|
||||||
|
), f"please find this assert and manually enter tile size for {format}"
|
||||||
|
return (8, 32) if format == "col_turing" else (32, 32)
|
||||||
|
|
||||||
|
|
||||||
|
def get_tile_inds(format, device):
|
||||||
|
transform = lambda x: F.transform(x.to(device), from_order="row", to_order=format)[0].to(x.device)
|
||||||
|
with torch.no_grad():
|
||||||
|
return get_inverse_transform_indices(transform, _get_tile_size(format)).to(device)
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MatmulLtState:
|
class MatmulLtState:
|
||||||
_tile_indices: Optional[torch.Tensor] = None
|
_tile_indices: Optional[torch.Tensor] = None
|
||||||
|
@ -267,20 +280,10 @@ class MatmulLtState:
|
||||||
self.SBt = None
|
self.SBt = None
|
||||||
self.CBt = None
|
self.CBt = None
|
||||||
|
|
||||||
def get_tile_size(self):
|
|
||||||
assert self.formatB in (
|
|
||||||
"col_turing",
|
|
||||||
"col_ampere",
|
|
||||||
), f"please find this assert and manually enter tile size for {self.formatB}"
|
|
||||||
return (8, 32) if self.formatB == "col_turing" else (32, 32)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tile_indices(self):
|
def tile_indices(self):
|
||||||
if self._tile_indices is None:
|
if self._tile_indices is None:
|
||||||
device = self.CxB.device
|
self._tile_indices = get_tile_inds(self.formatB, self.CxB.device)
|
||||||
transform = lambda x: F.transform(x.to(device), from_order="row", to_order=self.formatB)[0].to(x.device)
|
|
||||||
with torch.no_grad():
|
|
||||||
self._tile_indices = get_inverse_transform_indices(transform, self.get_tile_size()).to(device)
|
|
||||||
return self._tile_indices
|
return self._tile_indices
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user