Extract get_tile_inds to a separate function
This commit is contained in:
parent
ac5550a023
commit
4fb37d45c1
|
@ -232,6 +232,19 @@ def supports_igemmlt(device: torch.device) -> bool:
|
|||
return True
|
||||
|
||||
|
||||
def _get_tile_size(format):
|
||||
assert format in (
|
||||
"col_turing",
|
||||
"col_ampere",
|
||||
), f"please find this assert and manually enter tile size for {format}"
|
||||
return (8, 32) if format == "col_turing" else (32, 32)
|
||||
|
||||
|
||||
def get_tile_inds(format, device):
|
||||
transform = lambda x: F.transform(x.to(device), from_order="row", to_order=format)[0].to(x.device)
|
||||
with torch.no_grad():
|
||||
return get_inverse_transform_indices(transform, _get_tile_size(format)).to(device)
|
||||
|
||||
@dataclass
|
||||
class MatmulLtState:
|
||||
_tile_indices: Optional[torch.Tensor] = None
|
||||
|
@ -267,20 +280,10 @@ class MatmulLtState:
|
|||
self.SBt = None
|
||||
self.CBt = None
|
||||
|
||||
def get_tile_size(self):
|
||||
assert self.formatB in (
|
||||
"col_turing",
|
||||
"col_ampere",
|
||||
), f"please find this assert and manually enter tile size for {self.formatB}"
|
||||
return (8, 32) if self.formatB == "col_turing" else (32, 32)
|
||||
|
||||
@property
|
||||
def tile_indices(self):
|
||||
if self._tile_indices is None:
|
||||
device = self.CxB.device
|
||||
transform = lambda x: F.transform(x.to(device), from_order="row", to_order=self.formatB)[0].to(x.device)
|
||||
with torch.no_grad():
|
||||
self._tile_indices = get_inverse_transform_indices(transform, self.get_tile_size()).to(device)
|
||||
self._tile_indices = get_tile_inds(self.formatB, self.CxB.device)
|
||||
return self._tile_indices
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user