Merge pull request #469 from shadeMe/linear-layer-device
Add `device` parameter to `Linear` subclasses and `Embedding`
commit 196d6f5dc1
@@ -92,6 +92,7 @@ class Embedding(torch.nn.Embedding):
         scale_grad_by_freq: bool = False,
         sparse: bool = False,
         _weight: Optional[Tensor] = None,
+        device: Optional[device] = None,
     ) -> None:
         super().__init__(
             num_embeddings,
@@ -102,6 +103,7 @@ class Embedding(torch.nn.Embedding):
             scale_grad_by_freq,
             sparse,
             _weight,
+            device=device
         )
         GlobalOptimManager.get_instance().register_module_override(
             self, "weight", {"optim_bits": 32}
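A minimal usage sketch, not part of the commit: with the new keyword, the embedding is materialized directly on the target device instead of being built on CPU and moved afterwards. It assumes the class is exported as `bitsandbytes.nn.Embedding` and that a CUDA device is available.

```python
import torch
import bitsandbytes as bnb

# `device` is forwarded to torch.nn.Embedding, so the weight is allocated
# on the requested device from the start.
emb = bnb.nn.Embedding(num_embeddings=30_000, embedding_dim=768,
                       device=torch.device("cuda"))
# The 32-bit optimizer-state override on `weight` is registered as before.
```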
@@ -199,8 +201,8 @@ class Params4bit(torch.nn.Parameter):
         return new_param


 class Linear4bit(nn.Linear):
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4'):
-        super().__init__(input_features, output_features, bias)
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', device=None):
+        super().__init__(input_features, output_features, bias, device)
         self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
         self.compute_dtype = compute_dtype
@@ -223,12 +225,12 @@ class Linear4bit(nn.Linear):
         return out


 class LinearFP4(Linear4bit):
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True):
-        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4')
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
+        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', device)


 class LinearNF4(Linear4bit):
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True):
-        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4')
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
+        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', device)
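For reference, `torch.nn.Linear.__init__` is `(in_features, out_features, bias=True, device=None, dtype=None)`, so passing `device` positionally as the fourth argument in the `super().__init__` calls above lines up with the base signature. A short sketch of the resulting API, assuming the classes are exported under `bitsandbytes.nn`:

```python
import torch
import bitsandbytes as bnb

# LinearFP4 and LinearNF4 differ only in the quant_type they hand to
# Linear4bit; both now accept and forward `device`.
fp4 = bnb.nn.LinearFP4(768, 3072, device="cuda")
nf4 = bnb.nn.LinearNF4(768, 3072, compute_dtype=torch.float16, device="cuda")
```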
@@ -320,8 +322,8 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_k

 class Linear8bitLt(nn.Linear):
     def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
-                 memory_efficient_backward=False, threshold=0.0, index=None):
-        super().__init__(input_features, output_features, bias)
+                 memory_efficient_backward=False, threshold=0.0, index=None, device=None):
+        super().__init__(input_features, output_features, bias, device)
         assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
         self.state = bnb.MatmulLtState()
         self.index = index
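The same pattern applied to the int8 layer; a hedged sketch (again assuming the usual `bitsandbytes.nn` export) of a typical inference configuration created directly on the GPU:

```python
import bitsandbytes as bnb

# has_fp16_weights=False with an outlier threshold is the common int8
# inference setup; `device` is the new argument and goes to torch.nn.Linear.
layer = bnb.nn.Linear8bitLt(768, 3072, has_fp16_weights=False,
                            threshold=6.0, device="cuda")
```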
@@ -411,8 +413,8 @@ class Linear8bitLt(nn.Linear):


 class OutlierAwareLinear(nn.Linear):
-    def __init__(self, input_features, output_features, bias=True):
-        super().__init__(input_features, output_features, bias)
+    def __init__(self, input_features, output_features, bias=True, device=None):
+        super().__init__(input_features, output_features, bias, device)
         self.outlier_dim = None
         self.is_quantized = False
@@ -446,9 +448,10 @@ class SwitchBackLinearBnb(nn.Linear):
         memory_efficient_backward=False,
         threshold=0.0,
         index=None,
+        device=None
     ):
         super().__init__(
-            input_features, output_features, bias
+            input_features, output_features, bias, device
         )
         self.state = bnb.MatmulLtState()
         self.index = index
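Taken together, `Embedding`, `Linear4bit`/`LinearFP4`/`LinearNF4`, `Linear8bitLt`, `OutlierAwareLinear`, and `SwitchBackLinearBnb` all forward `device` to their `torch.nn` base class. One plausible use, an assumption on my part since the PR gives no motivation, is shape-only construction on PyTorch's meta device:

```python
import bitsandbytes as bnb

# Meta-device construction allocates no real storage; whether later
# quantization steps tolerate meta tensors depends on the module.
skeleton = bnb.nn.Linear8bitLt(4096, 4096, device="meta")
print(skeleton.weight.device)  # -> meta
```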