Added more blocksizes for stochastic rounding; fixed dequant blocksize.
This commit is contained in:
parent
2dfa3ce16d
commit
2489d819c5
|
@ -413,10 +413,10 @@ class MatMulFP8(torch.autograd.Function):
|
||||||
# 2. Matmul
|
# 2. Matmul
|
||||||
|
|
||||||
cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=1024)
|
cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=1024)
|
||||||
fp8A = F.dequantize_blockwise(cA, state).to(A.dtype)
|
fp8A = F.dequantize_blockwise(cA, state, blocksize=1024).to(A.dtype)
|
||||||
|
|
||||||
cB, state = F.quantize_blockwise(B, code=fw_code, blocksize=1024)
|
cB, state = F.quantize_blockwise(B, code=fw_code, blocksize=1024)
|
||||||
fp8B = F.dequantize_blockwise(cB, state).to(B.dtype)
|
fp8B = F.dequantize_blockwise(cB, state, blocksize=1024).to(B.dtype)
|
||||||
|
|
||||||
output = torch.matmul(fp8A, fp8B)
|
output = torch.matmul(fp8A, fp8B)
|
||||||
|
|
||||||
|
@ -443,7 +443,7 @@ class MatMulFP8(torch.autograd.Function):
|
||||||
grad_A, grad_B = None, None
|
grad_A, grad_B = None, None
|
||||||
|
|
||||||
cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=1024)
|
cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=1024)
|
||||||
fp8out = F.dequantize_blockwise(cgrad_out, state).to(grad_output.dtype)
|
fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=1024).to(grad_output.dtype)
|
||||||
|
|
||||||
# Cast grad_output to fp16
|
# Cast grad_output to fp16
|
||||||
if len(grad_output.shape) == 3:
|
if len(grad_output.shape) == 3:
|
||||||
|
|
|
@ -508,13 +508,12 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
|
||||||
code = code.to(A.device)
|
code = code.to(A.device)
|
||||||
if rand is not None:
|
if rand is not None:
|
||||||
is_on_gpu([code, A, out, absmax, rand])
|
is_on_gpu([code, A, out, absmax, rand])
|
||||||
assert blocksize==4096
|
|
||||||
assert rand.numel() >= 1024
|
assert rand.numel() >= 1024
|
||||||
rand_offset = random.randint(0, 1023)
|
rand_offset = random.randint(0, 1023)
|
||||||
if A.dtype == torch.float32:
|
if A.dtype == torch.float32:
|
||||||
lib.cquantize_blockwise_stochastic_fp32(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
|
lib.cquantize_blockwise_stochastic_fp32(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), cblocksize, ct.c_int(A.numel()))
|
||||||
elif A.dtype == torch.float16:
|
elif A.dtype == torch.float16:
|
||||||
lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
|
lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), cblocksize, ct.c_int(A.numel()))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -2797,16 +2797,28 @@ template __global__ void kQuantizeBlockwise<half, 4096, 4, 1>(float * code, half
|
||||||
template __global__ void kQuantizeBlockwise<float, 4096, 4, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
template __global__ void kQuantizeBlockwise<float, 4096, 4, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||||
template __global__ void kQuantizeBlockwise<half, 2048, 4, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
template __global__ void kQuantizeBlockwise<half, 2048, 4, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||||
template __global__ void kQuantizeBlockwise<float, 2048, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
template __global__ void kQuantizeBlockwise<float, 2048, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||||
|
template __global__ void kQuantizeBlockwise<half, 2048, 4, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||||
|
template __global__ void kQuantizeBlockwise<float, 2048, 4, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||||
template __global__ void kQuantizeBlockwise<half, 1024, 4, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
template __global__ void kQuantizeBlockwise<half, 1024, 4, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||||
template __global__ void kQuantizeBlockwise<float, 1024, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
template __global__ void kQuantizeBlockwise<float, 1024, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||||
|
template __global__ void kQuantizeBlockwise<half, 1024, 4, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||||
|
template __global__ void kQuantizeBlockwise<float, 1024, 4, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||||
template __global__ void kQuantizeBlockwise<half, 512, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
template __global__ void kQuantizeBlockwise<half, 512, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||||
template __global__ void kQuantizeBlockwise<float, 512, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
template __global__ void kQuantizeBlockwise<float, 512, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||||
|
template __global__ void kQuantizeBlockwise<half, 512, 2, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||||
|
template __global__ void kQuantizeBlockwise<float, 512, 2, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||||
template __global__ void kQuantizeBlockwise<half, 256, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
template __global__ void kQuantizeBlockwise<half, 256, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||||
template __global__ void kQuantizeBlockwise<float, 256, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
template __global__ void kQuantizeBlockwise<float, 256, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||||
|
template __global__ void kQuantizeBlockwise<half, 256, 2, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||||
|
template __global__ void kQuantizeBlockwise<float, 256, 2, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||||
template __global__ void kQuantizeBlockwise<half, 128, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
template __global__ void kQuantizeBlockwise<half, 128, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||||
template __global__ void kQuantizeBlockwise<float, 128, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
template __global__ void kQuantizeBlockwise<float, 128, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||||
|
template __global__ void kQuantizeBlockwise<half, 128, 2, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||||
|
template __global__ void kQuantizeBlockwise<float, 128, 2, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||||
template __global__ void kQuantizeBlockwise<half, 64, 1, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
template __global__ void kQuantizeBlockwise<half, 64, 1, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||||
template __global__ void kQuantizeBlockwise<float, 64, 1, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
template __global__ void kQuantizeBlockwise<float, 64, 1, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||||
|
template __global__ void kQuantizeBlockwise<half, 64, 1, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||||
|
template __global__ void kQuantizeBlockwise<float, 64, 1, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||||
|
|
||||||
template __global__ void kDequantizeBlockwise<half, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
|
template __global__ void kDequantizeBlockwise<half, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
|
||||||
template __global__ void kDequantizeBlockwise<float, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
|
template __global__ void kDequantizeBlockwise<float, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
|
||||||
|
|
14
csrc/ops.cu
14
csrc/ops.cu
|
@ -54,23 +54,21 @@ template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A,
|
||||||
{
|
{
|
||||||
int num_blocks = n/blocksize;
|
int num_blocks = n/blocksize;
|
||||||
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
|
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
|
||||||
if(STOCHASTIC == 1)
|
|
||||||
assert(blocksize == 4096);
|
|
||||||
|
|
||||||
if(blocksize == 4096)
|
if(blocksize == 4096)
|
||||||
kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
|
kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||||
else if(blocksize == 2048)
|
else if(blocksize == 2048)
|
||||||
kQuantizeBlockwise<T, 2048, 4, 0><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
|
kQuantizeBlockwise<T, 2048, 4, STOCHASTIC><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||||
else if(blocksize == 1024)
|
else if(blocksize == 1024)
|
||||||
kQuantizeBlockwise<T, 1024, 4, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
|
kQuantizeBlockwise<T, 1024, 4, STOCHASTIC><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||||
else if(blocksize == 512)
|
else if(blocksize == 512)
|
||||||
kQuantizeBlockwise<T, 512, 2, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
|
kQuantizeBlockwise<T, 512, 2, STOCHASTIC><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||||
else if(blocksize == 256)
|
else if(blocksize == 256)
|
||||||
kQuantizeBlockwise<T, 256, 2, 0><<<num_blocks, 128>>>(code, A, absmax, out, rand, rand_offset, n);
|
kQuantizeBlockwise<T, 256, 2, STOCHASTIC><<<num_blocks, 128>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||||
else if(blocksize == 128)
|
else if(blocksize == 128)
|
||||||
kQuantizeBlockwise<T, 128, 2, 0><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
|
kQuantizeBlockwise<T, 128, 2, STOCHASTIC><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||||
else if(blocksize == 64)
|
else if(blocksize == 64)
|
||||||
kQuantizeBlockwise<T, 64, 1, 0><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
|
kQuantizeBlockwise<T, 64, 1, STOCHASTIC><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||||
|
|
||||||
|
|
||||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||||
|
|
|
@ -77,8 +77,8 @@ void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){
|
||||||
|
|
||||||
void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
||||||
void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
||||||
void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<half, 1>(code, A, absmax, out, rand, rand_offset, 4096, n); }
|
void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n){ quantizeBlockwise<half, 1>(code, A, absmax, out, rand, rand_offset, blocksize, n); }
|
||||||
void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<float, 1>(code, A, absmax, out, rand, rand_offset, 4096, n); }
|
void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n){ quantizeBlockwise<float, 1>(code, A, absmax, out, rand, rand_offset, blocksize, n); }
|
||||||
|
|
||||||
void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half>(code, A, absmax, out, blocksize, n); } \
|
void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half>(code, A, absmax, out, blocksize, n); } \
|
||||||
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); }
|
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); }
|
||||||
|
@ -142,8 +142,8 @@ extern "C"
|
||||||
void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); }
|
void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); }
|
||||||
void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
|
void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
|
||||||
void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
|
void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
|
||||||
void cquantize_blockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp16(code, A, absmax, out, rand, rand_offset, n); }
|
void cquantize_blockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n){ quantizeBlockwise_stochastic_fp16(code, A, absmax, out, rand, rand_offset, blocksize, n); }
|
||||||
void cquantize_blockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp32(code, A, absmax, out, rand, rand_offset, n); }
|
void cquantize_blockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n){ quantizeBlockwise_stochastic_fp32(code, A, absmax, out, rand, rand_offset, blocksize, n); }
|
||||||
|
|
||||||
void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
|
void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
|
||||||
void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
|
void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
|
||||||
|
|
|
@ -188,21 +188,25 @@ def test_dynamic_blockwise_quantization():
|
||||||
#print('rand', blocksize, sum(reldiffs)/len(reldiffs))
|
#print('rand', blocksize, sum(reldiffs)/len(reldiffs))
|
||||||
|
|
||||||
|
|
||||||
def test_dynamic_blockwise_stochastic_quantization():
|
|
||||||
|
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
|
||||||
|
def test_dynamic_blockwise_stochastic_quantization(blocksize):
|
||||||
diffs = []
|
diffs = []
|
||||||
reldiffs = []
|
reldiffs = []
|
||||||
rand = torch.rand(1024).cuda()
|
rand = torch.rand(1024).cuda()
|
||||||
|
err = 0
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
A1 = torch.randn(1024, 1024, device="cuda")
|
A1 = torch.randn(1024, 1024, device="cuda")
|
||||||
C1, S1 = F.quantize_blockwise(A1, rand=rand)
|
C1, S1 = F.quantize_blockwise(A1, rand=rand, blocksize=blocksize)
|
||||||
C2, S2 = F.quantize_blockwise(A1)
|
C2, S2 = F.quantize_blockwise(A1, blocksize=blocksize)
|
||||||
|
A2 = F.dequantize_blockwise(C1, S1, blocksize=blocksize)
|
||||||
|
err += (A1-A2).abs().mean().item()/100
|
||||||
# the quantized values may differ by at most 1
|
# the quantized values may differ by at most 1
|
||||||
torch.testing.assert_allclose(C1, C2, atol=1, rtol=0)
|
torch.testing.assert_allclose(C1, C2, atol=1, rtol=0)
|
||||||
fraction_smaller = (C1 < C2).float().sum() / C1.numel()
|
fraction_smaller = (C1 < C2).float().sum() / C1.numel()
|
||||||
fraction_larger = (C1 > C2).float().sum() / C1.numel()
|
fraction_larger = (C1 > C2).float().sum() / C1.numel()
|
||||||
torch.testing.assert_allclose(
|
torch.testing.assert_allclose(fraction_larger, fraction_smaller, atol=0.01, rtol=0)
|
||||||
fraction_larger, fraction_smaller, atol=0.01, rtol=0
|
assert err < 0.019
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
Loading…
Reference in New Issue
Block a user