Add a compute_dtype option to LinearFP4, add a replace_linear utility, enable blocksize-64 FP4 quantization kernels, and rework the matmul benchmarks.
parent 6c31a5fe99
commit 69810521d3

@@ -173,10 +173,11 @@ class FP4Params(torch.nn.Parameter):


 class LinearFP4(nn.Linear):
-    def __init__(self, input_features, output_features, bias=True):
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None):
         super().__init__(input_features, output_features, bias)
         self.state = bnb.MatmulLtState()
         self.weight = FP4Params(self.weight.data, requires_grad=False)
+        self.compute_dtype = compute_dtype

     def init_8bit_state(self):
         pass
@@ -191,9 +192,12 @@ class LinearFP4(nn.Linear):
         if getattr(self.weight, 'quant_state', None) is None:
             print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
-        x = x.to(torch.float16)
+        inp_dtype = x.dtype
+        if self.compute_dtype is not None:
+            x = x.to(self.compute_dtype)

         bias = None if self.bias is None else self.bias.half()
         out = bnb.matmul_fp4(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)

+        out = out.to(inp_dtype)

         return out
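
For context, a minimal usage sketch of the new compute_dtype path (hedged: it assumes a CUDA build of this branch and that LinearFP4 is exported from bitsandbytes.nn; the dtype round-trip mirrors the forward() above):

    import torch
    import bitsandbytes as bnb

    # Sketch: an FP4 linear layer whose matmul runs in fp16 regardless of input dtype.
    layer = bnb.nn.LinearFP4(768, 3072, bias=True, compute_dtype=torch.float16)
    layer = layer.cuda()  # .cuda()/.to(device) quantizes the FP4Params weight

    x = torch.randn(4, 768, dtype=torch.float32, device="cuda")
    out = layer(x)
    assert out.dtype == torch.float32  # forward() casts the result back to inp_dtype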
@@ -21,3 +21,43 @@ def execute_and_return(command_string: str) -> Tuple[str, str]:

     std_out, std_err = execute_and_return_decoded_std_streams(command_string)
     return std_out, std_err
+
+
+def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_weights=False, post_processing_function=None):
+    """
+    Replace linear modules with a new Linear module.
+
+    Parameters:
+        model (`torch.nn.Module`):
+            Input model or `torch.nn.Module` as the function is run recursively.
+        linear_replacement (`torch.nn.Module`):
+            The linear module that replaces the old one. Only expects standard arguments.
+            If other arguments need to be passed, use a lambda.
+        skip_modules (`List[str]`, *optional*, defaults to `["lm_head"]`):
+            List of module names not to convert.
+        copy_weights (`bool`):
+            Copy the weights from the old linear module to the new one.
+        post_processing_function (`str`, *optional*):
+            The name of a function of the replacement linear class that is called
+            after processing.
+    """
+    for name, module in model.named_children():
+        if len(list(module.children())) > 0:
+            replace_linear(module, linear_replacement, skip_modules, copy_weights, post_processing_function)
+
+        if isinstance(module, torch.nn.Linear) and name not in skip_modules:
+            old_module = model._modules[name]
+            model._modules[name] = linear_replacement(
+                module.in_features,
+                module.out_features,
+                module.bias is not None,
+            )
+            if copy_weights:
+                model._modules[name].weight = old_module.weight
+                model._modules[name].bias = old_module.bias
+
+            if post_processing_function is not None:
+                func = getattr(module, post_processing_function, None)
+                if func is not None: func(module)
+    return model
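
A hypothetical end-to-end use of replace_linear, composing it with LinearFP4 from the first hunk; the lambda is the pattern the docstring suggests for passing non-standard arguments (the model shape here is illustrative):

    import torch
    import bitsandbytes as bnb

    model = torch.nn.Sequential(
        torch.nn.Linear(768, 3072),
        torch.nn.ReLU(),
        torch.nn.Linear(3072, 768),
    )
    # The replacement callable receives (in_features, out_features, bias).
    model = replace_linear(
        model,
        lambda in_f, out_f, bias: bnb.nn.LinearFP4(in_f, out_f, bias, compute_dtype=torch.float16),
    )
    model = model.cuda()  # device transfer triggers FP4 quantization of the new weights

Note that with the default copy_weights=False the replacement layers start from freshly initialized weights; pass copy_weights=True to carry the old parameters over.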
@@ -2968,6 +2968,8 @@ template __global__ void kQuantizeBlockwise<half, 128, 2, 0, 1>(float * code, ha
 template __global__ void kQuantizeBlockwise<float, 128, 2, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<half, 64, 2, 0, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<float, 64, 2, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 //template __global__ void kQuantizeBlockwise<half, 64, 1, 0, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 //template __global__ void kQuantizeBlockwise<float, 64, 1, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);

 template __global__ void kDequantizeBlockwise<half, 512, 64, 8, 1>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
 template __global__ void kDequantizeBlockwise<float, 512, 64, 8, 1>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
@@ -71,6 +71,8 @@ template <typename T, int STOCHASTIC, int FP4> void quantizeBlockwise(float * co
     kQuantizeBlockwise<T, 128, 2, 0, FP4><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
+  else if(blocksize == 64)
+    kQuantizeBlockwise<T, 64, 2, 0, FP4><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
   //else if(blocksize == 32)
     //kQuantizeBlockwise<T, 32, 1, 0, FP4><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);


   CUDA_CHECK_RETURN(cudaPeekAtLastError());
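
The new blocksize == 64 path is reachable from Python; a round-trip sketch, assuming this branch exposes quantize_fp4/dequantize_fp4 in bitsandbytes.functional:

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
    qA, quant_state = F.quantize_fp4(A, blocksize=64)  # dispatches to the 64-element kernel above
    A2 = F.dequantize_fp4(qA, quant_state)
    print((A - A2).abs().mean())  # FP4 is lossy; expect a small nonzero error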
@@ -1784,17 +1784,17 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
     print("partial matmul", time.time() - t0)


-batch_size = 1
-seqdim = 1
+batch_size = 4
+seqdim = 256
 values = []
 values.append((batch_size, seqdim, 768, 4 * 768))
-#values.append((batch_size, seqdim, 1024, 4*1024))
-#values.append((batch_size, seqdim, 1536, 4*1536))
-#values.append((batch_size, seqdim, 2048, 4*2048))
-#values.append((batch_size, seqdim, 2560, 4*2560))
-#values.append((batch_size, seqdim, 4096, 4*4096))
-#values.append((batch_size, seqdim, 5140, 4*5140))
-#values.append((batch_size, seqdim, 12288, 4*12288))
+values.append((batch_size, seqdim, 1024, 4*1024))
+values.append((batch_size, seqdim, 1536, 4*1536))
+values.append((batch_size, seqdim, 2048, 4*2048))
+values.append((batch_size, seqdim, 2560, 4*2560))
+values.append((batch_size, seqdim, 4096, 4*4096))
+values.append((batch_size, seqdim, 5140, 4*5140))
+values.append((batch_size, seqdim, 12288, 4*12288))
 names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
 def test_bench_matmul(batch, seq, model, hidden):
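
With batch_size = 4, seqdim = 256, and all model sizes enabled, the sweep runs one benchmark per (batch, seq, model, hidden) tuple. Invocation sketch (the test-file path is an assumption about the repository layout; -s surfaces the timing prints):

    pytest tests/test_functional.py -k test_bench_matmul -s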
@@ -1839,90 +1839,90 @@ def test_bench_matmul(batch, seq, model, hidden):
     torch.cuda.synchronize()
     print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        bnb.matmul(A, B)
-    torch.cuda.synchronize()
-    print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    bnb.matmul(A, B)
+    #torch.cuda.synchronize()
+    #print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        bnb.matmul(A, B, threshold=6.0)
-    torch.cuda.synchronize()
-    print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    bnb.matmul(A, B, threshold=6.0)
+    #torch.cuda.synchronize()
+    #print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

-    CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
-    C32A, SA = F.transform(CA, "col32")
-    CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
-    CxB, SB = F.transform(CB, to_order=formatB)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
-    torch.cuda.synchronize()
-    print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
+    #C32A, SA = F.transform(CA, "col32")
+    #CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
+    #CxB, SB = F.transform(CB, to_order=formatB)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
+    #torch.cuda.synchronize()
+    #print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

-    BA, statsB = F.vectorwise_quant(B, dim=1)
-    CxB, SB = F.nvidia_transform(CB, to_order=formatB)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        A2 = A.view(-1, A.shape[-1]).contiguous()
-        CA, statsA = F.vectorwise_quant(A2, dim=1)
-        C32A, SA = F.nvidia_transform(CA, "col32")
-        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
-        Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
-        F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
-    torch.cuda.synchronize()
-    print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #BA, statsB = F.vectorwise_quant(B, dim=1)
+    #CxB, SB = F.nvidia_transform(CB, to_order=formatB)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    A2 = A.view(-1, A.shape[-1]).contiguous()
+    #    CA, statsA = F.vectorwise_quant(A2, dim=1)
+    #    C32A, SA = F.nvidia_transform(CA, "col32")
+    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
+    #    Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
+    #    F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
+    #torch.cuda.synchronize()
+    #print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

-    BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
-    CxB, SB = F.nvidia_transform(CB, to_order=formatB)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        A2 = A.view(-1, A.shape[-1]).contiguous()
-        CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
-        C32A, SA = F.nvidia_transform(CA, "col32")
-        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
-        Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
-        out = Cout * statsB * statsA * (1.0 / (127 * 127))
-    torch.cuda.synchronize()
-    print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
+    #CxB, SB = F.nvidia_transform(CB, to_order=formatB)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    A2 = A.view(-1, A.shape[-1]).contiguous()
+    #    CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
+    #    C32A, SA = F.nvidia_transform(CA, "col32")
+    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
+    #    Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
+    #    out = Cout * statsB * statsA * (1.0 / (127 * 127))
+    #torch.cuda.synchronize()
+    #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

-    linear8bit(A)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        linear8bit(A)
-    torch.cuda.synchronize()
-    print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #linear8bit(A)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    linear8bit(A)
+    #torch.cuda.synchronize()
+    #print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

-    linearMixedBit(A)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        linearMixedBit(A)
-    torch.cuda.synchronize()
-    print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #linearMixedBit(A)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    linearMixedBit(A)
+    #torch.cuda.synchronize()
+    #print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

-    linear8bit_train(A)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        linear8bit_train(A)
-    torch.cuda.synchronize()
-    print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #linear8bit_train(A)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    linear8bit_train(A)
+    #torch.cuda.synchronize()
+    #print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

-    linear8bit_train_thresh(A)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        linear8bit_train_thresh(A)
-    torch.cuda.synchronize()
-    print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #linear8bit_train_thresh(A)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    linear8bit_train_thresh(A)
+    #torch.cuda.synchronize()
+    #print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

 def test_zeropoint():
     def quant_zp(x):