Reduce diff

This commit is contained in:
Max Ryabinin 2023-02-25 06:24:58 +01:00
parent d15822a54b
commit 24609b66af

View File

@ -212,7 +212,7 @@ class Int8Params(torch.nn.Parameter):
class Linear8bitLt(nn.Linear):
def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
memory_efficient_backward=False, threshold=0.0, index=None):
memory_efficient_backward=False, threshold=0.0, index=None):
super().__init__(input_features, output_features, bias)
assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
self.state = bnb.MatmulLtState()