Merge pull request #469 from shadeMe/linear-layer-device

Add `device` parameter to `Linear` subclasses and `Embedding`
Tim Dettmers 2023-07-10 06:17:13 -07:00 committed by GitHub
commit 196d6f5dc1

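In short, the bitsandbytes layers now accept the same `device` argument as their `torch.nn` counterparts and forward it to the base class. A minimal usage sketch of what this enables (assuming a bitsandbytes build that includes this commit):

```python
import torch
import bitsandbytes as bnb

# Previously the parameters were always allocated on the default (CPU)
# device and had to be moved afterwards; now the target device can be
# named at construction time, mirroring torch.nn.Linear:
ref = torch.nn.Linear(64, 128, device="cpu")
q4 = bnb.nn.Linear4bit(64, 128, device="cpu")
print(q4.weight.device)  # cpu
```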

@@ -92,6 +92,7 @@ class Embedding(torch.nn.Embedding):
         scale_grad_by_freq: bool = False,
         sparse: bool = False,
         _weight: Optional[Tensor] = None,
+        device: Optional[device] = None,
     ) -> None:
         super().__init__(
             num_embeddings,
@@ -102,6 +103,7 @@ class Embedding(torch.nn.Embedding):
             scale_grad_by_freq,
             sparse,
             _weight,
+            device=device
         )
         GlobalOptimManager.get_instance().register_module_override(
             self, "weight", {"optim_bits": 32}
@@ -199,8 +201,8 @@ class Params4bit(torch.nn.Parameter):
             return new_param


 class Linear4bit(nn.Linear):
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4'):
-        super().__init__(input_features, output_features, bias)
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', device=None):
+        super().__init__(input_features, output_features, bias, device)
         self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
         self.compute_dtype = compute_dtype
@@ -223,12 +225,12 @@ class Linear4bit(nn.Linear):
         return out


 class LinearFP4(Linear4bit):
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True):
-        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4')
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
+        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', device)


 class LinearNF4(Linear4bit):
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True):
-        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4')
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
+        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', device)

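`LinearFP4` and `LinearNF4` only pass `device` through to `Linear4bit`, which in turn passes it to `nn.Linear`. A hedged sketch of the resulting flow; note that the 4-bit quantization itself still happens when `Params4bit` is moved to CUDA, not at construction:

```python
import torch
import bitsandbytes as bnb

# device flows LinearNF4 -> Linear4bit -> nn.Linear
m = bnb.nn.LinearNF4(768, 3072, compute_dtype=torch.bfloat16, device="cpu")
if torch.cuda.is_available():
    m = m.cuda()  # Params4bit quantizes the fp weights during this transfer
```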
@@ -320,8 +322,8 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_k

 class Linear8bitLt(nn.Linear):
     def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
-                 memory_efficient_backward=False, threshold=0.0, index=None):
-        super().__init__(input_features, output_features, bias)
+                 memory_efficient_backward=False, threshold=0.0, index=None, device=None):
+        super().__init__(input_features, output_features, bias, device)
         assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
         self.state = bnb.MatmulLtState()
         self.index = index
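The same pattern applies to `Linear8bitLt`. A sketch, again assuming this commit: `Int8Params` still quantizes during the CPU-to-CUDA transfer, so `device` here only controls where the initial full-precision weight is allocated:

```python
import torch
import bitsandbytes as bnb

layer = bnb.nn.Linear8bitLt(768, 3072, has_fp16_weights=False,
                            threshold=6.0, device="cpu")
if torch.cuda.is_available():
    layer = layer.to("cuda")  # int8 quantization happens on this move
```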
@@ -411,8 +413,8 @@ class Linear8bitLt(nn.Linear):


 class OutlierAwareLinear(nn.Linear):
-    def __init__(self, input_features, output_features, bias=True):
-        super().__init__(input_features, output_features, bias)
+    def __init__(self, input_features, output_features, bias=True, device=None):
+        super().__init__(input_features, output_features, bias, device)
         self.outlier_dim = None
         self.is_quantized = False
@@ -446,9 +448,10 @@ class SwitchBackLinearBnb(nn.Linear):
         memory_efficient_backward=False,
         threshold=0.0,
         index=None,
+        device=None
     ):
         super().__init__(
-            input_features, output_features, bias
+            input_features, output_features, bias, device
         )
         self.state = bnb.MatmulLtState()
         self.index = index
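One practical payoff, an assumption about downstream use rather than something stated in this PR, is deferred initialization on PyTorch's "meta" device, which lets large models be described without allocating real storage:

```python
import torch
import bitsandbytes as bnb

# Hypothetical deferred-init sketch: the parameter has shape and dtype but
# no backing memory; real weights would be loaded from a checkpoint later.
shell = bnb.nn.Linear8bitLt(4096, 11008, has_fp16_weights=False, device="meta")
print(shell.weight.shape, shell.weight.device)  # torch.Size([11008, 4096]) meta
```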