delete CxB from state
This commit is contained in:
parent
876387dc0c
commit
ef2936a90d
|
@ -260,11 +260,10 @@ class Linear8bitLt(nn.Linear):
|
||||||
|
|
||||||
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
|
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
|
||||||
|
|
||||||
# if not self.state.has_fp16_weights and self.state.CB is not None:
|
if not self.state.has_fp16_weights and self.state.CxB is not None:
|
||||||
# we converted 8-bit row major to turing/ampere format in the first inference pass
|
# In this version, we convert 8-bit row major to turing/ampere format at each inference pass
|
||||||
# we no longer need the row-major weight
|
# Thus, we delete CxB from the state. TODO: do not store it in the state in the first place.
|
||||||
# del self.state.CB
|
del self.state.CxB
|
||||||
# self.weight.data = self.state.CxB
|
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user