Merge pull request #469 from shadeMe/linear-layer-device
Add `device` parameter to `Linear` subclasses and `Embedding`
This commit is contained in:
commit
196d6f5dc1
|
@ -92,6 +92,7 @@ class Embedding(torch.nn.Embedding):
|
|||
scale_grad_by_freq: bool = False,
|
||||
sparse: bool = False,
|
||||
_weight: Optional[Tensor] = None,
|
||||
device: Optional[device] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_embeddings,
|
||||
|
@ -102,6 +103,7 @@ class Embedding(torch.nn.Embedding):
|
|||
scale_grad_by_freq,
|
||||
sparse,
|
||||
_weight,
|
||||
device=device
|
||||
)
|
||||
GlobalOptimManager.get_instance().register_module_override(
|
||||
self, "weight", {"optim_bits": 32}
|
||||
|
@ -199,8 +201,8 @@ class Params4bit(torch.nn.Parameter):
|
|||
return new_param
|
||||
|
||||
class Linear4bit(nn.Linear):
|
||||
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4'):
|
||||
super().__init__(input_features, output_features, bias)
|
||||
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4',device=None):
|
||||
super().__init__(input_features, output_features, bias, device)
|
||||
self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
|
||||
self.compute_dtype = compute_dtype
|
||||
|
||||
|
@ -223,12 +225,12 @@ class Linear4bit(nn.Linear):
|
|||
return out
|
||||
|
||||
class LinearFP4(Linear4bit):
|
||||
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True):
|
||||
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4')
|
||||
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True,device=None):
|
||||
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', device)
|
||||
|
||||
class LinearNF4(Linear4bit):
|
||||
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True):
|
||||
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4')
|
||||
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True,device=None):
|
||||
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', device)
|
||||
|
||||
|
||||
|
||||
|
@ -320,8 +322,8 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_k
|
|||
|
||||
class Linear8bitLt(nn.Linear):
|
||||
def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
|
||||
memory_efficient_backward=False, threshold=0.0, index=None):
|
||||
super().__init__(input_features, output_features, bias)
|
||||
memory_efficient_backward=False, threshold=0.0, index=None, device=None):
|
||||
super().__init__(input_features, output_features, bias, device)
|
||||
assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
|
||||
self.state = bnb.MatmulLtState()
|
||||
self.index = index
|
||||
|
@ -411,8 +413,8 @@ class Linear8bitLt(nn.Linear):
|
|||
|
||||
|
||||
class OutlierAwareLinear(nn.Linear):
|
||||
def __init__(self, input_features, output_features, bias=True):
|
||||
super().__init__(input_features, output_features, bias)
|
||||
def __init__(self, input_features, output_features, bias=True, device=None):
|
||||
super().__init__(input_features, output_features, bias, device)
|
||||
self.outlier_dim = None
|
||||
self.is_quantized = False
|
||||
|
||||
|
@ -446,9 +448,10 @@ class SwitchBackLinearBnb(nn.Linear):
|
|||
memory_efficient_backward=False,
|
||||
threshold=0.0,
|
||||
index=None,
|
||||
device=None
|
||||
):
|
||||
super().__init__(
|
||||
input_features, output_features, bias
|
||||
input_features, output_features, bias, device
|
||||
)
|
||||
self.state = bnb.MatmulLtState()
|
||||
self.index = index
|
||||
|
|
Loading…
Reference in New Issue
Block a user