diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp
index 303e8ed..2081e68 100644
--- a/csrc/cpu_ops.cpp
+++ b/csrc/cpu_ops.cpp
@@ -30,11 +30,12 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
     // between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size)
     for(long long offset = 0; offset < num_blocks; offset+=thread_wave_size)
     {
-      pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * thread_wave_size);
+      long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset;
+      pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * valid_chunks);
 
-      struct quantize_block_args **args = (quantize_block_args **) malloc(thread_wave_size * sizeof(quantize_block_args *));
+      struct quantize_block_args **args = (quantize_block_args **) malloc(valid_chunks * sizeof(quantize_block_args *));
 
-      for(long long i = 0; i < thread_wave_size; i++)
+      for(long long i = 0; i < valid_chunks; i++)
           args[i] = (quantize_block_args *) malloc(sizeof(quantize_block_args));
 
       int chunks_processed = 0;
@@ -56,14 +57,14 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
 
           pthread_create(&threads[chunks_processed], NULL, &quantize_block, (void *) arg);
           chunks_processed += 1;
-          if(chunks_processed == thread_wave_size){ break; }
+          if(chunks_processed == valid_chunks){ break; }
       }
 
-      for (int i = 0; i < thread_wave_size; i++)
+      for (int i = 0; i < valid_chunks; i++)
           int err = pthread_join(threads[i], NULL);
       
       free(threads);
-      for (int i = 0; i < thread_wave_size; i++)
+      for (int i = 0; i < valid_chunks; i++)
           free(args[i]);
       free(args);
 
diff --git a/tests/test_functional.py b/tests/test_functional.py
index d07affe..fcfdc72 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -2133,18 +2133,18 @@ def test_blockwise_cpu_large():
     reldiffs = []
     batch = 128
     seq = 128
-    hidden = 14336
-    for blocksize in [4096, 16384]:
-        for i in range(2):
-            A1 = torch.randn(batch, seq, hidden, device='cpu')
-            t0 = time.time()
-            C, S = F.quantize_blockwise(A1, blocksize=blocksize)
-            A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
-            print(time.time() - t0)
-            diff = torch.abs(A1 - A2)
-            reldiff = diff / torch.abs(A1 + 1e-8)
-            diffs.append(diff.mean().item())
-            reldiffs.append(reldiff.mean().item())
-            assert diffs[-1] < 0.011
-        # print(sum(diffs)/len(diffs))
-        # print(sum(reldiffs)/len(reldiffs))
+    for hidden in [128, 14336]:
+        for blocksize in [4096, 16384]:
+            for i in range(2):
+                A1 = torch.randn(batch, seq, hidden, device='cpu')
+                t0 = time.time()
+                C, S = F.quantize_blockwise(A1, blocksize=blocksize)
+                A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
+                print(time.time() - t0)
+                diff = torch.abs(A1 - A2)
+                reldiff = diff / torch.abs(A1 + 1e-8)
+                diffs.append(diff.mean().item())
+                reldiffs.append(reldiff.mean().item())
+                assert diffs[-1] < 0.011
+            # print(sum(diffs)/len(diffs))
+            # print(sum(reldiffs)/len(reldiffs))