diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 5d6d19c..a550ec1 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -173,10 +173,11 @@ class FP4Params(torch.nn.Parameter):
 
 
 class LinearFP4(nn.Linear):
-    def __init__(self, input_features, output_features, bias=True):
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None):
         super().__init__(input_features, output_features, bias)
         self.state = bnb.MatmulLtState()
         self.weight = FP4Params(self.weight.data, requires_grad=False)
+        self.compute_dtype = compute_dtype
 
     def init_8bit_state(self):
         pass
@@ -191,9 +192,12 @@ class LinearFP4(nn.Linear):
         if getattr(self.weight, 'quant_state', None) is None:
             print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
 
         inp_dtype = x.dtype
-        x = x.to(torch.float16)
+        if self.compute_dtype is not None:
+            x = x.to(self.compute_dtype)
+
         bias = None if self.bias is None else self.bias.half()
         out = bnb.matmul_fp4(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
+        out = out.to(inp_dtype)
 
         return out
diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py
index 1cd90e3..d6cc966 100644
--- a/bitsandbytes/utils.py
+++ b/bitsandbytes/utils.py
@@ -21,3 +21,43 @@ def execute_and_return(command_string: str) -> Tuple[str, str]:
     std_out, std_err = execute_and_return_decoded_std_streams(command_string)
 
     return std_out, std_err
+
+
+
+def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_weights=False, post_processing_function=None):
+    """
+    Replace linear modules with a new Linear module.
+    Parameters:
+        model (`torch.nn.Module`):
+            Input model or `torch.nn.Module` as the function is run recursively.
+        linear_replacement (`torch.nn.Module`):
+            The linear module that replaces the old one. Only expects standard arguments.
+            If other arguments need to be passed, use a lambda.
+        skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
+            List of module names not to convert. Defaults to `lm_head`.
+        copy_weights (`bool`):
+            Copy the weights from the old linear module to the new one.
+        post_processing_function (`str`):
+            The name of a function of the replacement linear class that is called
+            after processing.
+    """
+    for name, module in model.named_children():
+        if len(list(module.children())) > 0:
+            replace_linear(module, linear_replacement, skip_modules, copy_weights, post_processing_function)
+
+        if isinstance(module, torch.nn.Linear) and name not in skip_modules:
+            old_module = model._modules[name]
+            model._modules[name] = linear_replacement(
+                module.in_features,
+                module.out_features,
+                module.bias is not None,
+            )
+            if copy_weights:
+                model._modules[name].weight = old_module.weight
+                model._modules[name].bias = old_module.bias
+
+            if post_processing_function is not None:
+                func = getattr(module, post_processing_function, None)
+                if func is not None: func(module)
+    return model
+
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index a1eec68..a2691be 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -2968,6 +2968,8 @@ template __global__ void kQuantizeBlockwise(float * code, ha
 template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+//template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+//template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 
 template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
 template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
diff --git a/csrc/ops.cu b/csrc/ops.cu
index 483d915..07ef850 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -71,6 +71,8 @@ template void quantizeBlockwise(float * co
     kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n);
   else if(blocksize == 64)
     kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n);
+  //else if(blocksize == 32)
+    //kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n);
 
 
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 23b7558..54cecca 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -1784,17 +1784,17 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
     print("partial matmul", time.time() - t0)
 
 
-batch_size = 1
-seqdim = 1
+batch_size = 4
+seqdim = 256
 values = []
 values.append((batch_size, seqdim, 768, 4 * 768))
-#values.append((batch_size, seqdim, 1024, 4*1024))
-#values.append((batch_size, seqdim, 1536, 4*1536))
-#values.append((batch_size, seqdim, 2048, 4*2048))
-#values.append((batch_size, seqdim, 2560, 4*2560))
-#values.append((batch_size, seqdim, 4096, 4*4096))
-#values.append((batch_size, seqdim, 5140, 4*5140))
-#values.append((batch_size, seqdim, 12288, 4*12288))
+values.append((batch_size, seqdim, 1024, 4*1024))
+values.append((batch_size, seqdim, 1536, 4*1536))
+values.append((batch_size, seqdim, 2048, 4*2048))
+values.append((batch_size, seqdim, 2560, 4*2560))
+values.append((batch_size, seqdim, 4096, 4*4096))
+values.append((batch_size, seqdim, 5140, 4*5140))
+values.append((batch_size, seqdim, 12288, 4*12288))
 names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
 def test_bench_matmul(batch, seq, model, hidden):
@@ -1839,90 +1839,90 @@ def test_bench_matmul(batch, seq, model, hidden):
     torch.cuda.synchronize()
     print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
 
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        bnb.matmul(A, B)
-    torch.cuda.synchronize()
-    print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    bnb.matmul(A, B)
+    #torch.cuda.synchronize()
+    #print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
 
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        bnb.matmul(A, B, threshold=6.0)
-    torch.cuda.synchronize()
-    print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    bnb.matmul(A, B, threshold=6.0)
+    #torch.cuda.synchronize()
+    #print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
 
-    CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
-    C32A, SA = F.transform(CA, "col32")
-    CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
-    CxB, SB = F.transform(CB, to_order=formatB)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
-    torch.cuda.synchronize()
-    print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
+    #C32A, SA = F.transform(CA, "col32")
+    #CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
+    #CxB, SB = F.transform(CB, to_order=formatB)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
+    #torch.cuda.synchronize()
+    #print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
 
-    BA, statsB = F.vectorwise_quant(B, dim=1)
-    CxB, SB = F.nvidia_transform(CB, to_order=formatB)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        A2 = A.view(-1, A.shape[-1]).contiguous()
-        CA, statsA = F.vectorwise_quant(A2, dim=1)
-        C32A, SA = F.nvidia_transform(CA, "col32")
-        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
-        Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
-        F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
-    torch.cuda.synchronize()
-    print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #BA, statsB = F.vectorwise_quant(B, dim=1)
+    #CxB, SB = F.nvidia_transform(CB, to_order=formatB)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    A2 = A.view(-1, A.shape[-1]).contiguous()
+    #    CA, statsA = F.vectorwise_quant(A2, dim=1)
+    #    C32A, SA = F.nvidia_transform(CA, "col32")
+    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
+    #    Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
+    #    F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
+    #torch.cuda.synchronize()
+    #print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
 
-    BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
-    CxB, SB = F.nvidia_transform(CB, to_order=formatB)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        A2 = A.view(-1, A.shape[-1]).contiguous()
-        CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
-        C32A, SA = F.nvidia_transform(CA, "col32")
-        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
-        Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
-        out = Cout * statsB * statsA * (1.0 / (127 * 127))
-    torch.cuda.synchronize()
-    print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
+    #CxB, SB = F.nvidia_transform(CB, to_order=formatB)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    A2 = A.view(-1, A.shape[-1]).contiguous()
+    #    CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
+    #    C32A, SA = F.nvidia_transform(CA, "col32")
+    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
+    #    Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
+    #    out = Cout * statsB * statsA * (1.0 / (127 * 127))
+    #torch.cuda.synchronize()
+    #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
 
-    linear8bit(A)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        linear8bit(A)
-    torch.cuda.synchronize()
-    print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #linear8bit(A)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    linear8bit(A)
+    #torch.cuda.synchronize()
+    #print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
 
-    linearMixedBit(A)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        linearMixedBit(A)
-    torch.cuda.synchronize()
-    print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #linearMixedBit(A)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    linearMixedBit(A)
+    #torch.cuda.synchronize()
+    #print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
 
-    linear8bit_train(A)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        linear8bit_train(A)
-    torch.cuda.synchronize()
-    print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #linear8bit_train(A)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    linear8bit_train(A)
+    #torch.cuda.synchronize()
+    #print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
 
-    linear8bit_train_thresh(A)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        linear8bit_train(A)
-    torch.cuda.synchronize()
-    print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #linear8bit_train_thresh(A)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    linear8bit_train(A)
+    #torch.cuda.synchronize()
+    #print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
 
 def test_zeropoint():
     def quant_zp(x):
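
A minimal usage sketch of the new compute_dtype argument (editor's illustration, not part of the patch; it assumes LinearFP4 is exported as bnb.nn.LinearFP4):

    import torch
    import bitsandbytes as bnb

    # 4-bit linear layer whose matmul runs in fp16; the FP4 quantization of the
    # weight happens when the layer is moved to the GPU.
    layer = bnb.nn.LinearFP4(4096, 4096, bias=True, compute_dtype=torch.float16).cuda()

    x = torch.randn(2, 16, 4096, device="cuda", dtype=torch.float32)
    out = layer(x)                     # input is cast to compute_dtype for the matmul
    assert out.dtype == torch.float32  # result is cast back to the input dtype (out.to(inp_dtype))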
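A sketch of how replace_linear and LinearFP4 compose, passing the extra compute_dtype argument through a lambda as the docstring suggests (editor's illustration; the toy model and import paths are placeholders):

    import torch
    import torch.nn as nn
    import bitsandbytes as bnb
    from bitsandbytes.utils import replace_linear

    # Toy model standing in for a transformer: every nn.Linear whose name is not
    # in skip_modules is replaced via recursion over named_children().
    model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))
    model = replace_linear(
        model,
        lambda in_f, out_f, bias: bnb.nn.LinearFP4(in_f, out_f, bias, compute_dtype=torch.float16),
        skip_modules=["lm_head"],
    )
    model = model.cuda()  # replaced weights are quantized to FP4 on the move to the GPU

With copy_weights left at its default the replacement layers keep freshly initialized weights; pass copy_weights=True to carry the original parameters over.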