Add device parameter to Linear subclasses

shadeMe 2023-06-01 17:43:30 +02:00
parent ac5550a023
commit 9cac5dd1b6


@@ -199,8 +199,8 @@ class Params4bit(torch.nn.Parameter):
         return new_param
 
 class Linear4bit(nn.Linear):
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4'):
-        super().__init__(input_features, output_features, bias)
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', device=None):
+        super().__init__(input_features, output_features, bias, device)
         self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
         self.compute_dtype = compute_dtype
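
The new device argument is passed positionally to the base constructor; this lines up because torch.nn.Linear's signature is (in_features, out_features, bias=True, device=None, dtype=None), so the fourth positional slot is device. A minimal usage sketch, not part of the commit, assuming a bitsandbytes build that contains this change and a CUDA-capable machine:

    import torch
    import bitsandbytes as bnb

    # Allocate the 4-bit layer directly on the target device instead of
    # constructing it on the CPU and moving it afterwards.
    layer = bnb.nn.Linear4bit(
        768, 3072,
        bias=True,
        compute_dtype=torch.float16,
        quant_type='fp4',
        device='cuda',
    )
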
@@ -223,12 +223,12 @@ class Linear4bit(nn.Linear):
         return out
 
 class LinearFP4(Linear4bit):
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True):
-        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4')
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
+        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', device)
 
 class LinearNF4(Linear4bit):
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True):
-        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4')
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
+        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', device)
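
LinearFP4 and LinearNF4 stay thin wrappers that only pin quant_type, so the new argument simply rides through to Linear4bit. Under the same assumptions as above, these two constructions are equivalent:

    import bitsandbytes as bnb

    # Both lines build an NF4 linear layer whose weight is allocated on the GPU.
    a = bnb.nn.LinearNF4(1024, 1024, device='cuda')
    b = bnb.nn.Linear4bit(1024, 1024, quant_type='nf4', device='cuda')
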
@@ -309,8 +309,8 @@ class Int8Params(torch.nn.Parameter):
 
 class Linear8bitLt(nn.Linear):
     def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
-                 memory_efficient_backward=False, threshold=0.0, index=None):
-        super().__init__(input_features, output_features, bias)
+                 memory_efficient_backward=False, threshold=0.0, index=None, device=None):
+        super().__init__(input_features, output_features, bias, device)
         assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
         self.state = bnb.MatmulLtState()
         self.index = index
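
A common reason to accept device at construction time is deferred initialization: building the module on PyTorch's meta device allocates no weight storage, which helps when checkpoint weights will be loaded straight into quantized form. A hedged sketch of that pattern; the meta-device mechanics are stock PyTorch rather than something this commit adds, and the loading step is elided:

    import bitsandbytes as bnb

    # Weights created on the meta device carry shape and dtype but no storage.
    layer = bnb.nn.Linear8bitLt(
        4096, 4096,
        has_fp16_weights=False,
        threshold=6.0,
        device='meta',
    )
    assert layer.weight.device.type == 'meta'
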
@@ -397,8 +397,8 @@ class Linear8bitLt(nn.Linear):
 
 class OutlierAwareLinear(nn.Linear):
-    def __init__(self, input_features, output_features, bias=True):
-        super().__init__(input_features, output_features, bias)
+    def __init__(self, input_features, output_features, bias=True, device=None):
+        super().__init__(input_features, output_features, bias, device)
         self.outlier_dim = None
         self.is_quantized = False
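
A quick way to verify the plumbing is to check where the parameters land. A hypothetical test, assuming OutlierAwareLinear is exported from bitsandbytes.nn:

    import torch
    from bitsandbytes.nn import OutlierAwareLinear  # assumption: exported from this module

    lin = OutlierAwareLinear(16, 32, device=torch.device('cpu'))
    assert lin.weight.device.type == 'cpu'
    assert lin.bias.device.type == 'cpu'
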
@@ -432,9 +432,10 @@ class SwitchBackLinearBnb(nn.Linear):
             memory_efficient_backward=False,
             threshold=0.0,
             index=None,
+            device=None,
     ):
         super().__init__(
-            input_features, output_features, bias
+            input_features, output_features, bias, device
         )
         self.state = bnb.MatmulLtState()
         self.index = index
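
SwitchBackLinearBnb gets the same treatment through its multi-line constructor. One more sketch, assuming the class is exported from bitsandbytes.nn like its siblings; note that device could equally be forwarded to nn.Linear as a keyword, which would also survive any future reordering of the base signature:

    import bitsandbytes as bnb

    layer = bnb.nn.SwitchBackLinearBnb(
        2048, 2048,
        threshold=6.0,
        device='cuda',
    )
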