diff --git a/include/reactphysics3d/collision/shapes/AABB.h b/include/reactphysics3d/collision/shapes/AABB.h
index 11dac195..a4d9f921 100644
--- a/include/reactphysics3d/collision/shapes/AABB.h
+++ b/include/reactphysics3d/collision/shapes/AABB.h
@@ -112,6 +112,9 @@ class AABB {
         /// Return true if the ray intersects the AABB
         bool testRayIntersect(const Vector3& rayOrigin, const Vector3& rayDirectionInv, decimal rayMaxFraction) const;
 
+        /// Compute the intersection of a ray and the AABB
+        bool raycast(const Ray& ray, Vector3& hitPoint) const;
+
         /// Apply a scale factor to the AABB
         void applyScale(const Vector3& scale);
 
@@ -311,6 +314,53 @@ RP3D_FORCE_INLINE bool AABB::testRayIntersect(const Vector3& rayOrigin, const Ve
     return tMax >= std::max(tMin, decimal(0.0));
 }
 
+// Compute the intersection of a ray and the AABB
+RP3D_FORCE_INLINE bool AABB::raycast(const Ray& ray, Vector3& hitPoint) const {
+
+    decimal tMin = decimal(0.0);
+    decimal tMax = DECIMAL_LARGEST;
+
+    const decimal epsilon = 0.00001;
+
+    const Vector3 rayDirection = ray.point2 - ray.point1;
+
+    // For all three slabs
+    for (int i=0; i < 3; i++) {
+
+        // If the ray is parallel to the slab
+        if (std::abs(rayDirection[i]) < epsilon) {
+
+            // If origin of the ray is not inside the slab, no hit
+            if (ray.point1[i] < mMinCoordinates[i] || ray.point1[i] > mMaxCoordinates[i]) return false;
+        }
+        else {
+
+            decimal rayDirectionInverse = decimal(1.0) / rayDirection[i];
+            decimal t1 = (mMinCoordinates[i] - ray.point1[i]) * rayDirectionInverse;
+            decimal t2 = (mMaxCoordinates[i] - ray.point1[i]) * rayDirectionInverse;
+
+            if (t1 > t2) {
+
+                // Swap t1 and t2
+                decimal tTemp = t2;
+                t2 = t1;
+                t1 = tTemp;
+            }
+            
+            tMin = std::max(tMin, t1);
+            tMax = std::min(tMax, t2);
+            
+            // Exit with no collision 
+            if (tMin > tMax) return false;
+        }
+    }
+
+    // Compute the hit point
+    hitPoint = ray.point1 + tMin * rayDirection;
+
+    return true;
+}
+
 }
 
 #endif
diff --git a/include/reactphysics3d/collision/shapes/HeightFieldShape.h b/include/reactphysics3d/collision/shapes/HeightFieldShape.h
index efa2625e..f706fc8d 100644
--- a/include/reactphysics3d/collision/shapes/HeightFieldShape.h
+++ b/include/reactphysics3d/collision/shapes/HeightFieldShape.h
@@ -102,6 +102,10 @@ class HeightFieldShape : public ConcaveShape {
                          HalfEdgeStructure& triangleHalfEdgeStructure, int upAxis = 1, decimal integerHeightScale = 1.0f,
                          const Vector3& scaling = Vector3(1,1,1));
 
+        /// Raycast a single triangle of the height-field
+        bool raycastTriangle(const Ray& ray, const Vector3& p1, const Vector3& p2, const Vector3& p3, uint shapeId,
+                             Collider *collider, RaycastInfo& raycastInfo, decimal &smallestHitFraction, MemoryAllocator& allocator) const;
+
         /// Raycast method with feedback information
         virtual bool raycast(const Ray& ray, RaycastInfo& raycastInfo, Collider* collider, MemoryAllocator& allocator) const override;
 
@@ -125,6 +129,9 @@ class HeightFieldShape : public ConcaveShape {
         /// Compute the shape Id for a given triangle
         uint computeTriangleShapeId(uint iIndex, uint jIndex, uint secondTriangleIncrement) const;
 
+        /// Compute the first grid cell of the heightfield intersected by a ray
+        bool computeEnteringRayGridCoordinates(const Ray& ray, int& i, int& j, Vector3& outHitPoint) const;
+        
         /// Destructor
         virtual ~HeightFieldShape() override = default;
 
diff --git a/src/collision/shapes/HeightFieldShape.cpp b/src/collision/shapes/HeightFieldShape.cpp
index f3693fc1..9ac5f03b 100644
--- a/src/collision/shapes/HeightFieldShape.cpp
+++ b/src/collision/shapes/HeightFieldShape.cpp
@@ -27,6 +27,7 @@
 #include <reactphysics3d/collision/shapes/HeightFieldShape.h>
 #include <reactphysics3d/collision/RaycastInfo.h>
 #include <reactphysics3d/utils/Profiler.h>
+#include <iostream>
 
 using namespace reactphysics3d;
 
@@ -232,24 +233,98 @@ bool HeightFieldShape::raycast(const Ray& ray, RaycastInfo& raycastInfo, Collide
 
     RP3D_PROFILE("HeightFieldShape::raycast()", mProfiler);
 
-    // Compute the AABB for the ray
-    const Vector3 rayEnd = ray.point1 + ray.maxFraction * (ray.point2 - ray.point1);
-    const AABB rayAABB(Vector3::min(ray.point1, rayEnd), Vector3::max(ray.point1, rayEnd));
+    // Apply the concave mesh inverse scale factor because the mesh is stored without scaling
+    // inside the dynamic AABB tree
+    const Vector3 inverseScale(decimal(1.0) / mScale.x, decimal(1.0) / mScale.y, decimal(1.0) / mScale.z);
+    Ray scaledRay(ray.point1 * inverseScale, ray.point2 * inverseScale, ray.maxFraction);
 
-    // Compute the triangles overlapping with the ray AABB
-    Array<Vector3> triangleVertices(allocator, 64);
-    Array<Vector3> triangleVerticesNormals(allocator, 64);
-    Array<uint> shapeIds(allocator, 64);
-    computeOverlappingTriangles(rayAABB, triangleVertices, triangleVerticesNormals, shapeIds, allocator);
+    // Compute the grid coordinates where the ray is entering the AABB of the height field
+    int i, j;
+    Vector3 outHitGridPoint;
+    bool isIntersecting = computeEnteringRayGridCoordinates(scaledRay, i, j, outHitGridPoint);
+    assert(isIntersecting);
 
-    assert(triangleVertices.size() == triangleVerticesNormals.size());
-    assert(shapeIds.size() == triangleVertices.size() / 3);
-    assert(triangleVertices.size() % 3 == 0);
-    assert(triangleVerticesNormals.size() % 3 == 0);
+    const int nbCellsI = mNbColumns - 1;
+    const int nbCellsJ = mNbRows - 1;
+
+    const Vector3 aabbSize = mAABB.getExtent();
+
+    const Vector3 rayDirection = scaledRay.point2 - scaledRay.point1;
+
+    int stepI, stepJ;
+    decimal tMaxI, tMaxJ, nextI, nextJ, tDeltaI, tDeltaJ, sizeI, sizeJ;
+
+    switch(mUpAxis) {
+        case 0 : stepI = rayDirection.y > 0 ? 1 : (rayDirection.y < 0 ? -1 : 0);
+                 stepJ = rayDirection.z > 0 ? 1 : (rayDirection.z < 0 ? -1 : 0);
+                 nextI = stepI >= 0 ? i + 1 : i;
+                 nextJ = stepJ >= 0 ? j + 1 : j;
+                 sizeI = aabbSize.y / nbCellsI;
+                 sizeJ = aabbSize.z / nbCellsJ;
+                 tMaxI = ((nextI * sizeI) - outHitGridPoint.y) / rayDirection.y;
+                 tMaxJ = ((nextJ * sizeJ) - outHitGridPoint.z) / rayDirection.z;
+                 tDeltaI = sizeI / std::abs(rayDirection.y);
+                 tDeltaJ = sizeJ / std::abs(rayDirection.z);
+                 break;
+        case 1 : stepI = rayDirection.x > 0 ? 1 : (rayDirection.x < 0 ? -1 : 0);
+                 stepJ = rayDirection.z > 0 ? 1 : (rayDirection.z < 0 ? -1 : 0);
+                 nextI = stepI >= 0 ? i + 1 : i;
+                 nextJ = stepJ >= 0 ? j + 1 : j;
+                 sizeI = aabbSize.x / nbCellsI;
+                 sizeJ = aabbSize.z / nbCellsJ;
+                 tMaxI = ((nextI * sizeI) - outHitGridPoint.x) / rayDirection.x;
+                 tMaxJ = ((nextJ * sizeJ) - outHitGridPoint.z) / rayDirection.z;
+                 tDeltaI = sizeI / std::abs(rayDirection.x);
+                 tDeltaJ = sizeJ / std::abs(rayDirection.z);
+                 break;
+        case 2 : stepI = rayDirection.x > 0 ? 1 : (rayDirection.x < 0 ? -1 : 0);
+                 stepJ = rayDirection.y > 0 ? 1 : (rayDirection.y < 0 ? -1 : 0);
+                 nextI = stepI >= 0 ? i + 1 : i;
+                 nextJ = stepJ >= 0 ? j + 1 : j;
+                 sizeI = aabbSize.x / nbCellsI;
+                 sizeJ = aabbSize.y / nbCellsJ;
+                 tMaxI = ((nextI * sizeI) - outHitGridPoint.x) / rayDirection.x;
+                 tMaxJ = ((nextJ * sizeJ) - outHitGridPoint.y) / rayDirection.y;
+                 tDeltaI = sizeI / std::abs(rayDirection.x);
+                 tDeltaJ = sizeJ / std::abs(rayDirection.y);
+                 break;
+    }
 
     bool isHit = false;
     decimal smallestHitFraction = ray.maxFraction;
 
+    while (i >= 0 && i < nbCellsI && j >= 0 && j < nbCellsJ) {
+
+        // TODO : Remove this
+        //std::cout << "Cell " << i << ", " << j << std::endl;
+
+       // Compute the four point of the current quad
+       const Vector3 p1 = getVertexAt(i, j);
+       const Vector3 p2 = getVertexAt(i, j + 1);
+       const Vector3 p3 = getVertexAt(i + 1, j);
+       const Vector3 p4 = getVertexAt(i + 1, j + 1);
+
+       // Raycast against the first triangle of the cell
+       uint shapeId = computeTriangleShapeId(i, j, 0);
+       isHit |= raycastTriangle(ray, p1, p2, p3, shapeId, collider, raycastInfo, smallestHitFraction, allocator);
+
+       // Raycast against the second triangle of the cell
+       shapeId = computeTriangleShapeId(i, j, 1);
+       isHit |= raycastTriangle(ray, p3, p2, p4, shapeId, collider, raycastInfo, smallestHitFraction, allocator);
+
+       if (stepI == 0 && stepJ == 0) break;
+
+       if (tMaxI < tMaxJ) {
+            tMaxI += tDeltaI;
+            i += stepI;
+        }
+        else {
+            tMaxJ += tDeltaJ;
+            j += stepJ;
+        }
+    }
+
+    /*
     // For each overlapping triangle
     const uint32 nbShapeIds = shapeIds.size();
     for (uint32 i=0; i < nbShapeIds; i++)
@@ -287,10 +362,108 @@ bool HeightFieldShape::raycast(const Ray& ray, RaycastInfo& raycastInfo, Collide
             isHit = true;
         }
     }
+    */
 
     return isHit;
 }
 
+// Raycast a single triangle of the height-field
+bool HeightFieldShape::raycastTriangle(const Ray& ray, const Vector3& p1, const Vector3& p2, const Vector3& p3, uint shapeId,
+                                       Collider* collider, RaycastInfo& raycastInfo, decimal& smallestHitFraction, MemoryAllocator& allocator) const {
+
+   // Generate the first triangle for the current grid rectangle
+   Vector3 triangleVertices[3] = {p1, p2, p3};
+
+   // Compute the triangle normal
+   Vector3 triangleNormal = (p2 - p1).cross(p3 - p1).getUnit();
+
+   // Use the triangle face normal as vertices normals (this is an aproximation. The correct
+   // solution would be to compute all the normals of the neighbor triangles and use their
+   // weighted average (with incident angle as weight) at the vertices. However, this solution
+   // seems too expensive (it requires to compute the normal of all neighbor triangles instead
+   // and compute the angle of incident edges with asin(). Maybe we could also precompute the
+   // vertices normal at the HeightFieldShape constructor but it will require extra memory to
+   // store them.
+   Vector3 triangleVerticesNormals[3] = {triangleNormal, triangleNormal, triangleNormal};
+
+    // Create a triangle collision shape
+    TriangleShape triangleShape(triangleVertices, triangleVerticesNormals, shapeId, mTriangleHalfEdgeStructure, allocator);
+    triangleShape.setRaycastTestType(getRaycastTestType());
+
+#ifdef IS_RP3D_PROFILING_ENABLED
+
+
+    // Set the profiler to the triangle shape
+    triangleShape.setProfiler(mProfiler);
+
+#endif
+
+    // Ray casting test against the collision shape
+    RaycastInfo triangleRaycastInfo;
+    bool isTriangleHit = triangleShape.raycast(ray, triangleRaycastInfo, collider, allocator);
+
+    // If the ray hit the collision shape
+    if (isTriangleHit && triangleRaycastInfo.hitFraction <= smallestHitFraction) {
+
+        assert(triangleRaycastInfo.hitFraction >= decimal(0.0));
+
+        raycastInfo.body = triangleRaycastInfo.body;
+        raycastInfo.collider = triangleRaycastInfo.collider;
+        raycastInfo.hitFraction = triangleRaycastInfo.hitFraction;
+        raycastInfo.worldPoint = triangleRaycastInfo.worldPoint;
+        raycastInfo.worldNormal = triangleRaycastInfo.worldNormal;
+        raycastInfo.meshSubpart = -1;
+        raycastInfo.triangleIndex = -1;
+
+        smallestHitFraction = triangleRaycastInfo.hitFraction;
+
+        return true;
+    }
+
+    return false;
+}
+
+// Compute the first grid cell of the heightfield intersected by a ray.
+/// This method returns true if the ray hit the AABB of the height field and false otherwise
+bool HeightFieldShape::computeEnteringRayGridCoordinates(const Ray& ray, int& i, int& j, Vector3& outHitGridPoint) const {
+    
+    decimal stepI, stepJ;
+    const Vector3 aabbSize = mAABB.getExtent();
+
+    const uint32 nbCellsI = mNbColumns - 1;
+    const uint32 nbCellsJ = mNbRows - 1;
+
+    if (mAABB.raycast(ray, outHitGridPoint)) {
+
+        // Map the hit point into the grid range [0, mNbColumns - 1], [0, mNbRows - 1]
+        outHitGridPoint -= mAABB.getMin();
+
+        switch(mUpAxis) {
+            case 0 : stepI = aabbSize.y / nbCellsI;
+                     stepJ = aabbSize.z / nbCellsJ;
+                     i = clamp(int(outHitGridPoint.y / stepI), 0, nbCellsI - 1);
+                     j = clamp(int(outHitGridPoint.z / stepJ), 0, nbCellsJ - 1);
+                     break;
+            case 1 : stepI = aabbSize.x / nbCellsI;
+                     stepJ = aabbSize.z / nbCellsJ;
+                     i = clamp(int(outHitGridPoint.x / stepI), 0, nbCellsI - 1);
+                     j = clamp(int(outHitGridPoint.z / stepJ), 0, nbCellsJ - 1);
+                     break;
+            case 2 : stepI = aabbSize.x / nbCellsI;
+                     stepJ = aabbSize.y / nbCellsJ;
+                     i = clamp(int(outHitGridPoint.x / stepI), 0, nbCellsI - 1);
+                     j = clamp(int(outHitGridPoint.y / stepJ), 0, nbCellsJ - 1);
+                     break;
+        }
+
+        assert(i >= 0 && i < nbCellsI);
+        assert(j >= 0 && j < nbCellsJ);
+        return true;
+    }
+    
+    return false;
+}
+
 // Return the vertex (local-coordinates) of the height field at a given (x,y) position
 Vector3 HeightFieldShape::getVertexAt(int x, int y) const {