forked from mrq/bitsandbytes-rocm

commit 439f2b0c10: Merge pull request #33 from dbaranchuk/memory-efficient-backward

Memory efficient backward
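This merge adds a memory_efficient_backward option to MatmulLtState and bnb.nn.Linear8bitLt. When the flag is set, Linear8bitLt.forward() keeps the row-major int8 weight CB (deleting the Turing/Ampere-format CxB instead of CB), and MatMul8bitLt.backward() gains a fallback that computes grad_A by dequantizing CB and doing a regular matmul, so the old "Backprop only supported for fp16 weights" assertion is removed. The forward pass now also accepts non-fp16 inputs (they are cast to float16 for quantization, with a warning) and returns outputs in the input dtype, and the tests are extended to bfloat16/float32 and to the new flag. A short sketch of the fallback gradient path follows the backward() hunk below.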
@@ -1,4 +1,6 @@
 import operator
+import warnings
+
 import torch
 import bitsandbytes.functional as F
 
@@ -184,6 +186,7 @@ class MatmulLtState:
     idx = None
     is_training = True
     has_fp16_weights = True
+    memory_efficient_backward = False
     use_pool = False
     formatB = F.get_special_format_str()
 
@@ -209,31 +212,29 @@ class MatMul8bitLt(torch.autograd.Function):
             ctx.B = B
             ctx.bias = bias
             if A.shape[-1] == B.shape[0]:
-                return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device)
+                return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device)
             else:
-                return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)
+                return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device)
 
         # 1. Quantize A
         # 2. Quantize B
         # 3. Matmul
         # 4. Mixed-precision decomposition matmul
         # 5. Save state
-        requires_gradA = A.requires_grad
-        requires_gradB = B.requires_grad
-        requires_gradBias = bias is not None and bias.requires_grad
         formatB = state.formatB
         input_shape = A.shape
         if state.outlier_pool is None:
             state.outlier_pool = GlobalOutlierPooler.get_instance()
-        assert (
-            A.dtype == torch.float16
-        ), f"The input data type needs to be fp16 but {A.dtype} was found!"
+
+        # Cast A to fp16
+        if A.dtype != torch.float16:
+            warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
 
         # 1. Quantize A
         if len(A.shape) == 3:
             A = A.view(-1, A.shape[-1]).contiguous()
         CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
-            A, threshold=state.threshold
+            A.to(torch.float16), threshold=state.threshold
         )
 
         if state.threshold > 0.0 and coo_tensorA is not None:
@@ -269,7 +270,7 @@ class MatMul8bitLt(torch.autograd.Function):
                     state.SCB,
                     state.SCBt,
                     coo_tensorB,
-                ) = F.double_quant(B)
+                ) = F.double_quant(B.to(torch.float16))
                 state.CxB, state.SB = F.transform(CB, to_order=formatB)
         else:
             has_grad = False
@@ -290,7 +291,7 @@ class MatMul8bitLt(torch.autograd.Function):
                     (outliers * state.SCB.view(-1, 1) / 127.0)
                     .t()
                     .contiguous()
-                    .half()
+                    .to(A.dtype)
                 )
                 CA[:, state.idx.long()] = 0
                 CAt[:, state.idx.long()] = 0
@@ -307,7 +308,13 @@ class MatMul8bitLt(torch.autograd.Function):
         C32A, SA = F.transform(CA, "col32")
         out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
         # we apply the fused bias here
-        output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+
+        if bias is None or bias.dtype == torch.float16:
+            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+            output = output.to(A.dtype)
+        else:  # apply bias separately
+            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
+            output = output.to(A.dtype).add_(bias)
 
         # 4. Mixed-precision decomposition matmul
         if coo_tensorA is not None and subA is not None:
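(The bias split above is presumably because the fused epilogue in F.mm_dequant expects an fp16 bias; for bf16 or fp32 biases the dequantized output is first cast back to the input dtype and the bias is added separately.)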
@@ -318,9 +325,9 @@ class MatMul8bitLt(torch.autograd.Function):
 
         ctx.formatB = formatB
         ctx.grad_shape = input_shape
-        ctx.req_grads = [requires_gradA, requires_gradB, requires_gradBias]
+        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
 
-        if requires_gradA or requires_gradB:
+        if any(ctx.needs_input_grad[:2]):
             ctx.tensors = (CAt, subA)
             ctx.tensor_states = (SCAt, state.idx)
         else:
@@ -328,8 +335,8 @@ class MatMul8bitLt(torch.autograd.Function):
             ctx.tensor_states = (None, None)
             ctx.save_for_backward(None, None)
 
         clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
-        #clone_func = torch.clone
         return clone_func(output.view(output_shape))
 
     @staticmethod
@@ -337,23 +344,24 @@ class MatMul8bitLt(torch.autograd.Function):
         if ctx.is_empty:
             bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
             return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
-        req_gradA, req_gradB, req_gradBias = ctx.req_grads
+        req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
         CAt, subA = ctx.tensors
         SCAt, idx = ctx.tensor_states
         formatB = ctx.formatB
         state = ctx.state
-        assert (
-            state.has_fp16_weights
-        ), "Backprop only supported for fp16 weights."
+        grad_A = grad_B = grad_bias = None
 
+        if req_gradBias:
+            # compute grad_bias first before changing grad_output dtype
+            grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)
+
+        # Cast grad_output to fp16
         if len(grad_output.shape) == 3:
-            grad_output = grad_output.view(
+            grad_output = grad_output.reshape(
                 -1, grad_output.shape[-1]
             ).contiguous()
 
-        grad_A = grad_B = grad_bias = None
+        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
 
-        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
         if req_gradB:
             CxAt, SAt = F.transform(CAt, formatB, transpose=True)
             C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
@@ -363,16 +371,20 @@ class MatMul8bitLt(torch.autograd.Function):
                 grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
 
         if req_gradA:
+            if state.CBt is not None:
                 C32grad, Sgrad = F.transform(Cgrad, "col32")
                 if state.CxBt is None:
                     state.CxBt, state.SBt = F.transform(
                         state.CBt, to_order=formatB, transpose=True
                     )
                 gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
-                grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
-
-        if req_gradBias:
-            grad_bias = grad_output.sum(0)
+                grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
+            elif state.CB is not None:
+                CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0))
+                grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
+            else:
+                raise Exception('State must contain either CBt or CB matrix for backward')
 
         return grad_A, grad_B, None, grad_bias, None
 
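The elif branch above is the memory-efficient path: when the transposed CBt/CxBt copies are not kept, grad_A is computed from the row-major int8 weight CB and its row-wise scales SCB. Below is a minimal sketch in plain PyTorch (independent of bitsandbytes; the quantization helper is illustrative, not the library's double_quant):

import torch

def rowwise_int8_quant(W):
    # illustrative row-wise absmax quantization (not the bitsandbytes kernel):
    # each row of W is scaled so its absolute maximum maps to 127
    SCB = W.abs().amax(dim=1)
    CB = torch.round(W * (127.0 / SCB.unsqueeze(1))).to(torch.int8)
    return CB, SCB

# toy shapes: W is (out_features, in_features), grad_output is (batch, out_features)
W = torch.randn(64, 32)
CB, SCB = rowwise_int8_quant(W)
grad_output = torch.randn(4, 64)

# fallback path from the backward() hunk above:
# dequantize CB with its row-wise scales, then grad_A is a plain matmul
CB_deq = CB.to(grad_output.dtype).mul_(SCB.unsqueeze(1).mul(1.0 / 127.0))
grad_A = torch.matmul(grad_output, CB_deq)       # approximates grad_output @ W

print((grad_A - grad_output @ W).abs().max())    # small quantization error

This is also why the new code raises when the state holds neither CBt nor CB: one of the two representations is needed to form grad_A.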
@@ -221,6 +221,7 @@ class Linear8bitLt(nn.Linear):
         output_features,
         bias=True,
         has_fp16_weights=True,
+        memory_efficient_backward=False,
         threshold=0.0,
         index=None,
     ):
@@ -232,10 +233,13 @@ class Linear8bitLt(nn.Linear):
 
         self.state.threshold = threshold
         self.state.has_fp16_weights = has_fp16_weights
+        self.state.memory_efficient_backward = memory_efficient_backward
         if threshold > 0.0 and not has_fp16_weights:
             self.state.use_pool = True
 
-        self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights)
+        self.weight = Int8Params(
+            self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights
+        )
 
     def init_8bit_state(self):
         self.state.CB = self.weight.CB
@@ -255,11 +259,16 @@ class Linear8bitLt(nn.Linear):
 
         out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
 
-        if not self.state.has_fp16_weights and self.state.CB is not None:
+        if not self.state.has_fp16_weights:
+            if not self.state.memory_efficient_backward and self.state.CB is not None:
                 # we converted 8-bit row major to turing/ampere format in the first inference pass
                 # we no longer need the row-major weight
                 del self.state.CB
                 self.weight.data = self.state.CxB
+            elif self.state.memory_efficient_backward and self.state.CxB is not None:
+                # For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass.
+                # Thus, we delete CxB from the state.
+                del self.state.CxB
 
         return out
 
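For context, a minimal usage sketch of the new flag, modelled on the tests further down in this diff (assumes a CUDA build of bitsandbytes; layer sizes, threshold, and the bias setting are arbitrary):

import torch
import bitsandbytes as bnb

# frozen int8 weights, but gradients can still flow to earlier layers
layer = bnb.nn.Linear8bitLt(
    32, 64, bias=False, has_fp16_weights=False, threshold=6.0,
    memory_efficient_backward=True,
).cuda().half()

x = torch.randn(8, 32, device="cuda", dtype=torch.half, requires_grad=True)
out = layer(x)            # fp16 output produced from int8 weights
out.sum().backward()      # grad_A computed from the dequantized CB
print(x.grad.shape)       # torch.Size([8, 32])

With has_fp16_weights=False the Int8Params weight itself is created with requires_grad=False, so only the activations receive gradients; memory_efficient_backward=True is what allows that gradient to flow through the int8 layer.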
@@ -253,7 +253,7 @@ for c in req_grad:
 
 transpose = [(False, True), (False, False)]
 str_transpose = ["NT", "NN"]
-dtype = [torch.float16]
+dtype = [torch.float16, torch.bfloat16, torch.float32]
 has_fp16_weights = [True, False]
 has_bias = [True, False]
 values = list(
@@ -354,7 +354,7 @@ def test_matmullt(
                 state.SCB,
                 SCBt,
                 coo_tensorB,
-            ) = bnb.functional.double_quant(B2)
+            ) = bnb.functional.double_quant(B2.to(torch.float16))
             B2 = state.CB
 
         if not transpose[0] and transpose[1]:
@@ -367,11 +367,14 @@ def test_matmullt(
         if has_bias:
             out_torch += bias
 
+        assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"
+
         n = out_bnb.numel()
         err = torch.abs(out_bnb - out_torch).mean().item()
         # print(f'abs error {err:.4f}')
 
         idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
-        assert (idx == 0).sum().item() <= n * 0.0175
+        assert (idx == 0).sum().item() <= n * (0.0175 if dtype == torch.float16 else 0.021)
         idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
         assert (idx == 0).sum().item() <= n * 0.001
 
@@ -14,13 +14,15 @@ class MockArgs(object):
 
 
 class MLP8bit(torch.nn.Module):
-    def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
+    def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
         super(MLP8bit, self).__init__()
         self.fc1 = bnb.nn.Linear8bitLt(
-            dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold
+            dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
+            threshold=threshold
         )
         self.fc2 = bnb.nn.Linear8bitLt(
-            dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold
+            dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
+            threshold=threshold
         )
 
     def forward(self, x):
@@ -451,9 +453,12 @@ names = ["threshold_{0}".format(vals) for vals in values]
 
 
 @pytest.mark.parametrize("threshold", values, ids=names)
-def test_linear8bitlt_no_fp16_weights(threshold):
+@pytest.mark.parametrize("memory_efficient_backward", [True, False])
+def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
     l1 = (
-        bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False)
+        bnb.nn.Linear8bitLt(
+            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+        )
         .cuda()
         .half()
     )
@@ -513,7 +518,9 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc2.weight.dtype == torch.int8
 
     mlp = (
-        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+        MLP8bit(
+            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+        )
         .half()
         .to("cuda")
     )
@@ -531,11 +538,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc1.weight.device.type == "cuda"
     assert mlp.fc2.weight.device.type == "cuda"
 
-    mlp = (
-        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
-        .to(torch.float16)
-        .to("cuda")
+    mlp = MLP8bit(
+        32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
     )
+    w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda()  # grab weights before quantization,
+    mlp = mlp.cuda().half()  # and this line triggers quantization
 
     for i in range(100):
         b1 = torch.randn(16, 8, 32, device="cuda").half()
@@ -545,11 +552,28 @@ def test_linear8bitlt_no_fp16_weights(threshold):
         assert mlp.fc1.state.idx is not None
     if threshold > 0:
         assert mlp.fc2.state.idx is not None
 
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
     assert mlp.fc1.weight.device.type == "cuda"
     assert mlp.fc2.weight.device.type == "cuda"
 
+    if memory_efficient_backward:
+        b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
+        o1 = mlp(b1)
+        assert o1.dtype == torch.float16
+        assert o1.requires_grad
+        grad_proj = torch.randn_like(o1)
+
+        mlp.zero_grad()
+        (o1 * grad_proj).sum().backward()
+        grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
+        scale = grad_ref.abs().mean()
+
+        torch.testing.assert_allclose(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
+        idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
+        assert (idx == 0).sum().item() <= b1.numel() * 0.005
+
 
 def test_linear8bitlt_fp32_bias():
     # casts model to fp16 -> int8 automatically