Some small changes.

Author: Tim Dettmers
Date:   2023-03-27 09:12:57 -07:00
Parent: 6c31a5fe99
Commit: 69810521d3
5 changed files with 135 additions and 87 deletions


@@ -173,10 +173,11 @@ class FP4Params(torch.nn.Parameter):
 
 class LinearFP4(nn.Linear):
-    def __init__(self, input_features, output_features, bias=True):
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None):
         super().__init__(input_features, output_features, bias)
         self.state = bnb.MatmulLtState()
         self.weight = FP4Params(self.weight.data, requires_grad=False)
+        self.compute_dtype = compute_dtype
 
     def init_8bit_state(self):
         pass
@@ -191,9 +192,12 @@ class LinearFP4(nn.Linear):
         if getattr(self.weight, 'quant_state', None) is None:
             print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
         inp_dtype = x.dtype
-        x = x.to(torch.float16)
+        if self.compute_dtype is not None:
+            x = x.to(self.compute_dtype)
 
         bias = None if self.bias is None else self.bias.half()
         out = bnb.matmul_fp4(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
         out = out.to(inp_dtype)
         return out
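
For reference, a minimal usage sketch of the new compute_dtype argument. The layer and its constructor come from the hunk above; the bnb.nn export path, the shapes, and the device handling are illustrative assumptions, not part of this commit. Passing compute_dtype=torch.float16 reproduces the previously hard-coded float16 cast explicitly, while leaving it at None keeps activations in their input dtype.

import torch
import bitsandbytes as bnb

# Assumed export path; LinearFP4 is the class modified in the hunk above.
layer = bnb.nn.LinearFP4(1024, 4096, bias=True, compute_dtype=torch.float16).cuda()  # .cuda() quantizes the FP4 weights
x = torch.randn(8, 1024, device="cuda")   # float32 activations
out = layer(x)                            # matmul runs in float16, output is cast back to float32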


@@ -21,3 +21,43 @@ def execute_and_return(command_string: str) -> Tuple[str, str]:
     std_out, std_err = execute_and_return_decoded_std_streams(command_string)
     return std_out, std_err
+
+
+def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_weights=False, post_processing_function=None):
+    """
+    Replace linear modules with a new Linear module.
+
+    Parameters:
+        model (`torch.nn.Module`):
+            Input model or `torch.nn.Module` as the function is run recursively.
+        linear_replacement (`torch.nn.Module`):
+            The linear module that replaces the old one. Only expects standard arguments.
+            If other arguments need to be passed, use a lambda.
+        skip_modules (`List[str]`, *optional*, defaults to `["lm_head"]`):
+            List of module names not to convert. Defaults to `["lm_head"]`.
+        copy_weights (`bool`):
+            Copy the weights from the old linear module to the new one.
+        post_processing_function (`str`):
+            The name of a method on the replacement linear class that is called
+            after processing.
+    """
+    for name, module in model.named_children():
+        if len(list(module.children())) > 0:
+            replace_linear(module, linear_replacement, skip_modules, copy_weights, post_processing_function)
+
+        if isinstance(module, torch.nn.Linear) and name not in skip_modules:
+            old_module = model._modules[name]
+            model._modules[name] = linear_replacement(
+                module.in_features,
+                module.out_features,
+                module.bias is not None,
+            )
+            if copy_weights:
+                model._modules[name].weight = old_module.weight
+                model._modules[name].bias = old_module.bias
+
+            if post_processing_function is not None:
+                func = getattr(module, post_processing_function, None)
+                if func is not None:
+                    func(module)
+    return model
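
A hedged usage sketch of the new helper together with the LinearFP4 layer from the first hunk. As the docstring says, extra constructor arguments such as compute_dtype are bound with a lambda, because replace_linear only passes in_features, out_features, and the bias flag. The toy model, the shapes, and the bnb.nn path are assumptions for illustration.

import torch
import torch.nn as nn
import bitsandbytes as bnb

# Toy model; a real checkpoint would typically also contain an lm_head, which is skipped by default.
model = nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512))

# replace_linear as defined above; the lambda forwards the three standard arguments
# and fixes compute_dtype for every replacement layer.
model = replace_linear(
    model,
    lambda in_f, out_f, bias: bnb.nn.LinearFP4(in_f, out_f, bias, compute_dtype=torch.float16),
    skip_modules=["lm_head"],
)
model = model.cuda()  # quantizes the FP4 weights of the replaced layers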


@@ -2968,6 +2968,8 @@ template __global__ void kQuantizeBlockwise<half, 128, 2, 0, 1>(float * code, ha
 template __global__ void kQuantizeBlockwise<float, 128, 2, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 template __global__ void kQuantizeBlockwise<half, 64, 2, 0, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 template __global__ void kQuantizeBlockwise<float, 64, 2, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+//template __global__ void kQuantizeBlockwise<half, 64, 1, 0, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+//template __global__ void kQuantizeBlockwise<float, 64, 1, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 template __global__ void kDequantizeBlockwise<half, 512, 64, 8, 1>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
 template __global__ void kDequantizeBlockwise<float, 512, 64, 8, 1>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);


@@ -71,6 +71,8 @@ template <typename T, int STOCHASTIC, int FP4> void quantizeBlockwise(float * co
     kQuantizeBlockwise<T, 128, 2, 0, FP4><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
   else if(blocksize == 64)
     kQuantizeBlockwise<T, 64, 2, 0, FP4><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
+  //else if(blocksize == 32)
+    //kQuantizeBlockwise<T, 32, 1, 0, FP4><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
 
   CUDA_CHECK_RETURN(cudaPeekAtLastError());


@@ -1784,17 +1784,17 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
     print("partial matmul", time.time() - t0)
 
-batch_size = 1
-seqdim = 1
+batch_size = 4
+seqdim = 256
 values = []
 values.append((batch_size, seqdim, 768, 4 * 768))
-#values.append((batch_size, seqdim, 1024, 4*1024))
-#values.append((batch_size, seqdim, 1536, 4*1536))
-#values.append((batch_size, seqdim, 2048, 4*2048))
-#values.append((batch_size, seqdim, 2560, 4*2560))
-#values.append((batch_size, seqdim, 4096, 4*4096))
-#values.append((batch_size, seqdim, 5140, 4*5140))
-#values.append((batch_size, seqdim, 12288, 4*12288))
+values.append((batch_size, seqdim, 1024, 4*1024))
+values.append((batch_size, seqdim, 1536, 4*1536))
+values.append((batch_size, seqdim, 2048, 4*2048))
+values.append((batch_size, seqdim, 2560, 4*2560))
+values.append((batch_size, seqdim, 4096, 4*4096))
+values.append((batch_size, seqdim, 5140, 4*5140))
+values.append((batch_size, seqdim, 12288, 4*12288))
 names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
 
 @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
 def test_bench_matmul(batch, seq, model, hidden):
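
As a quick check of the parametrization above, each tuple is formatted into a human-readable test id; the first configuration (batch 4, seq 256, model 768, hidden 4 * 768) produces:

>>> "batch_{}_seq_{}_model_{}_hidden_{}".format(4, 256, 768, 4 * 768)
'batch_4_seq_256_model_768_hidden_3072'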
@@ -1839,90 +1839,90 @@ def test_bench_matmul(batch, seq, model, hidden):
     torch.cuda.synchronize()
     print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
 
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        bnb.matmul(A, B)
-    torch.cuda.synchronize()
-    print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        bnb.matmul(A, B, threshold=6.0)
-    torch.cuda.synchronize()
-    print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
-    CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
-    C32A, SA = F.transform(CA, "col32")
-    CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
-    CxB, SB = F.transform(CB, to_order=formatB)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
-    torch.cuda.synchronize()
-    print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
-    BA, statsB = F.vectorwise_quant(B, dim=1)
-    CxB, SB = F.nvidia_transform(CB, to_order=formatB)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        A2 = A.view(-1, A.shape[-1]).contiguous()
-        CA, statsA = F.vectorwise_quant(A2, dim=1)
-        C32A, SA = F.nvidia_transform(CA, "col32")
-        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
-        Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
-        F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
-    torch.cuda.synchronize()
-    print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
-    BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
-    CxB, SB = F.nvidia_transform(CB, to_order=formatB)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        A2 = A.view(-1, A.shape[-1]).contiguous()
-        CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
-        C32A, SA = F.nvidia_transform(CA, "col32")
-        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
-        Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
-        out = Cout * statsB * statsA * (1.0 / (127 * 127))
-    torch.cuda.synchronize()
-    print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
-    linear8bit(A)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        linear8bit(A)
-    torch.cuda.synchronize()
-    print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
-    linearMixedBit(A)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        linearMixedBit(A)
-    torch.cuda.synchronize()
-    print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
-    linear8bit_train(A)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        linear8bit_train(A)
-    torch.cuda.synchronize()
-    print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
-    linear8bit_train_thresh(A)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        linear8bit_train(A)
-    torch.cuda.synchronize()
-    print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    # bnb.matmul(A, B)
+    #torch.cuda.synchronize()
+    #print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    # bnb.matmul(A, B, threshold=6.0)
+    #torch.cuda.synchronize()
+    #print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
+    #C32A, SA = F.transform(CA, "col32")
+    #CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
+    #CxB, SB = F.transform(CB, to_order=formatB)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
+    #torch.cuda.synchronize()
+    #print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #BA, statsB = F.vectorwise_quant(B, dim=1)
+    #CxB, SB = F.nvidia_transform(CB, to_order=formatB)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    # A2 = A.view(-1, A.shape[-1]).contiguous()
+    # CA, statsA = F.vectorwise_quant(A2, dim=1)
+    # C32A, SA = F.nvidia_transform(CA, "col32")
+    # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
+    # Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
+    # F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
+    #torch.cuda.synchronize()
+    #print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
+    #CxB, SB = F.nvidia_transform(CB, to_order=formatB)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    # A2 = A.view(-1, A.shape[-1]).contiguous()
+    # CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
+    # C32A, SA = F.nvidia_transform(CA, "col32")
+    # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
+    # Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
+    # out = Cout * statsB * statsA * (1.0 / (127 * 127))
+    #torch.cuda.synchronize()
+    #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #linear8bit(A)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    # linear8bit(A)
+    #torch.cuda.synchronize()
+    #print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #linearMixedBit(A)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    # linearMixedBit(A)
+    #torch.cuda.synchronize()
+    #print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #linear8bit_train(A)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    # linear8bit_train(A)
+    #torch.cuda.synchronize()
+    #print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #linear8bit_train_thresh(A)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    # linear8bit_train(A)
+    #torch.cuda.synchronize()
+    #print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
 
 def test_zeropoint():
     def quant_zp(x):
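
All the timing blocks in this benchmark, the live fp4 path as well as the now commented-out 8-bit variants, follow one pattern: warm up, synchronize, time a fixed number of iterations, synchronize again before reading the clock. A minimal self-contained sketch of that pattern, with made-up shapes and a plain PyTorch matmul standing in for the bnb kernels:

import time
import torch

iters = 100
A = torch.randn(4, 256, 1024, dtype=torch.float16, device="cuda")
B = torch.randn(4096, 1024, dtype=torch.float16, device="cuda")

torch.matmul(A, B.t())        # warmup so CUDA/kernel initialization is not measured
torch.cuda.synchronize()      # drain queued GPU work before starting the clock
t0 = time.time()
for i in range(iters):
    torch.matmul(A, B.t())
torch.cuda.synchronize()      # wait for all launched kernels before stopping the clock
print(f"pytorch fp16 matmul: [4,256,1024], [1024,4096]->[4,256,4096]: {time.time()-t0:.4f}s")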