Add device parameter to Linear subclasses
This commit is contained in:
parent
ac5550a023
commit
9cac5dd1b6
|
@ -199,8 +199,8 @@ class Params4bit(torch.nn.Parameter):
|
|||
return new_param
|
||||
|
||||
class Linear4bit(nn.Linear):
    """nn.Linear variant whose weight is re-wrapped as a 4-bit `Params4bit` tensor.

    NOTE(review): reconstructed from diff residue — the hunk showed both the
    pre- and post-change ``__init__``; only the device-aware version is kept.
    The rest of the class (e.g. its forward pass) lies outside this hunk.
    """

    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', device=None):
        """Build the linear layer and wrap its weight for 4-bit quantization.

        Args:
            input_features: size of each input sample (passed to nn.Linear).
            output_features: size of each output sample (passed to nn.Linear).
            bias: whether to allocate a bias vector (passed to nn.Linear).
            compute_dtype: stored on the instance; presumably the dtype used
                during the matmul — confirm against the forward pass.
            compress_statistics: forwarded to Params4bit.
            quant_type: 4-bit format tag forwarded to Params4bit ('fp4'/'nf4').
            device: forwarded to nn.Linear so the parameters are created on
                the requested device directly (the point of this change).
        """
        super().__init__(input_features, output_features, bias, device)
        # Re-wrap the freshly allocated weight as a quantized parameter;
        # requires_grad=False because the 4-bit weight is not trained directly.
        self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
        self.compute_dtype = compute_dtype
|
@ -223,12 +223,12 @@ class Linear4bit(nn.Linear):
|
|||
return out
|
||||
|
||||
class LinearFP4(Linear4bit):
    """Convenience subclass of Linear4bit with quant_type pinned to 'fp4'.

    NOTE(review): reconstructed from diff residue — the hunk showed both the
    pre- and post-change ``__init__``; only the device-aware version is kept.
    """

    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
        # Same signature as Linear4bit minus quant_type, which is fixed here;
        # device is forwarded so parameters are created where the caller asks.
        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', device)
class LinearNF4(Linear4bit):
    """Convenience subclass of Linear4bit with quant_type pinned to 'nf4'.

    NOTE(review): reconstructed from diff residue — the hunk showed both the
    pre- and post-change ``__init__``; only the device-aware version is kept.
    """

    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
        # Same signature as Linear4bit minus quant_type, which is fixed here;
        # device is forwarded so parameters are created where the caller asks.
        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', device)
@ -309,8 +309,8 @@ class Int8Params(torch.nn.Parameter):
|
|||
|
||||
class Linear8bitLt(nn.Linear):
|
||||
def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
|
||||
memory_efficient_backward=False, threshold=0.0, index=None):
|
||||
super().__init__(input_features, output_features, bias)
|
||||
memory_efficient_backward=False, threshold=0.0, index=None, device=None):
|
||||
super().__init__(input_features, output_features, bias, device)
|
||||
assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
|
||||
self.state = bnb.MatmulLtState()
|
||||
self.index = index
|
||||
|
@ -397,8 +397,8 @@ class Linear8bitLt(nn.Linear):
|
|||
|
||||
|
||||
class OutlierAwareLinear(nn.Linear):
    """nn.Linear variant that tracks an outlier dimension and a quantized flag.

    NOTE(review): reconstructed from diff residue — the hunk showed both the
    pre- and post-change ``__init__``; only the device-aware version is kept.
    Only the constructor is visible in this hunk; the code that later fills in
    ``outlier_dim`` / flips ``is_quantized`` lies elsewhere in the class.
    """

    def __init__(self, input_features, output_features, bias=True, device=None):
        """Build the linear layer and initialize outlier-tracking state.

        Args:
            input_features: size of each input sample (passed to nn.Linear).
            output_features: size of each output sample (passed to nn.Linear).
            bias: whether to allocate a bias vector (passed to nn.Linear).
            device: forwarded to nn.Linear so the parameters are created on
                the requested device directly (the point of this change).
        """
        super().__init__(input_features, output_features, bias, device)
        # No outlier dimension known yet — presumably computed lazily; confirm
        # against the rest of the class.
        self.outlier_dim = None
        # Weight starts unquantized; presumably set True once quantization has
        # run — confirm against the quantization path.
        self.is_quantized = False
@ -432,9 +432,10 @@ class SwitchBackLinearBnb(nn.Linear):
|
|||
memory_efficient_backward=False,
|
||||
threshold=0.0,
|
||||
index=None,
|
||||
device=None
|
||||
):
|
||||
super().__init__(
|
||||
input_features, output_features, bias
|
||||
input_features, output_features, bias, device
|
||||
)
|
||||
self.state = bnb.MatmulLtState()
|
||||
self.index = index
|
||||
|
|
Loading…
Reference in New Issue
Block a user