diff --git a/src/components/BallAndSocketJointComponents.h b/src/components/BallAndSocketJointComponents.h index 13d58eab..3af8e569 100644 --- a/src/components/BallAndSocketJointComponents.h +++ b/src/components/BallAndSocketJointComponents.h @@ -128,16 +128,16 @@ class BallAndSocketJointComponents : public Components { void setJoint(Entity jointEntity, BallAndSocketJoint* joint) const; /// Return the local anchor point of body 1 for a given joint - const Vector3& getLocalAnchoirPointBody1(Entity jointEntity) const; + const Vector3& getLocalAnchorPointBody1(Entity jointEntity) const; /// Set the local anchor point of body 1 for a given joint - void setLocalAnchoirPointBody1(Entity jointEntity, const Vector3& localAnchoirPointBody1); + void setLocalAnchorPointBody1(Entity jointEntity, const Vector3& localAnchoirPointBody1); /// Return the local anchor point of body 2 for a given joint - const Vector3& getLocalAnchoirPointBody2(Entity jointEntity) const; + const Vector3& getLocalAnchorPointBody2(Entity jointEntity) const; /// Set the local anchor point of body 2 for a given joint - void setLocalAnchoirPointBody2(Entity jointEntity, const Vector3& localAnchoirPointBody2); + void setLocalAnchorPointBody2(Entity jointEntity, const Vector3& localAnchoirPointBody2); /// Return the vector from center of body 1 to anchor point in world-space const Vector3& getR1World(Entity jointEntity) const; @@ -184,6 +184,7 @@ class BallAndSocketJointComponents : public Components { // -------------------- Friendship -------------------- // friend class BroadPhaseSystem; + friend class SolveBallAndSocketJointSystem; }; // Return a pointer to a given joint @@ -201,28 +202,28 @@ inline void BallAndSocketJointComponents::setJoint(Entity jointEntity, BallAndSo } // Return the local anchor point of body 1 for a given joint -inline const Vector3& BallAndSocketJointComponents::getLocalAnchoirPointBody1(Entity jointEntity) const { +inline const Vector3& BallAndSocketJointComponents::getLocalAnchorPointBody1(Entity jointEntity) const { assert(mMapEntityToComponentIndex.containsKey(jointEntity)); return mLocalAnchorPointBody1[mMapEntityToComponentIndex[jointEntity]]; } // Set the local anchor point of body 1 for a given joint -inline void BallAndSocketJointComponents::setLocalAnchoirPointBody1(Entity jointEntity, const Vector3& localAnchoirPointBody1) { +inline void BallAndSocketJointComponents::setLocalAnchorPointBody1(Entity jointEntity, const Vector3& localAnchoirPointBody1) { assert(mMapEntityToComponentIndex.containsKey(jointEntity)); mLocalAnchorPointBody1[mMapEntityToComponentIndex[jointEntity]] = localAnchoirPointBody1; } // Return the local anchor point of body 2 for a given joint -inline const Vector3& BallAndSocketJointComponents::getLocalAnchoirPointBody2(Entity jointEntity) const { +inline const Vector3& BallAndSocketJointComponents::getLocalAnchorPointBody2(Entity jointEntity) const { assert(mMapEntityToComponentIndex.containsKey(jointEntity)); return mLocalAnchorPointBody2[mMapEntityToComponentIndex[jointEntity]]; } // Set the local anchor point of body 2 for a given joint -inline void BallAndSocketJointComponents::setLocalAnchoirPointBody2(Entity jointEntity, const Vector3& localAnchoirPointBody2) { +inline void BallAndSocketJointComponents::setLocalAnchorPointBody2(Entity jointEntity, const Vector3& localAnchoirPointBody2) { assert(mMapEntityToComponentIndex.containsKey(jointEntity)); mLocalAnchorPointBody2[mMapEntityToComponentIndex[jointEntity]] = localAnchoirPointBody2; diff --git a/src/components/RigidBodyComponents.h b/src/components/RigidBodyComponents.h index 23052176..8ed50a09 100644 --- a/src/components/RigidBodyComponents.h +++ b/src/components/RigidBodyComponents.h @@ -346,6 +346,7 @@ class RigidBodyComponents : public Components { friend class DynamicsWorld; friend class ContactSolverSystem; + friend class SolveBallAndSocketJointSystem; friend class DynamicsSystem; friend class BallAndSocketJoint; friend class FixedJoint; diff --git a/src/constraint/BallAndSocketJoint.cpp b/src/constraint/BallAndSocketJoint.cpp index f329cbdc..c098289d 100644 --- a/src/constraint/BallAndSocketJoint.cpp +++ b/src/constraint/BallAndSocketJoint.cpp @@ -43,264 +43,15 @@ BallAndSocketJoint::BallAndSocketJoint(Entity entity, DynamicsWorld& world, cons Transform& body2Transform = mWorld.mTransformComponents.getTransform(jointInfo.body2->getEntity()); // Compute the local-space anchor point for each body - mWorld.mBallAndSocketJointsComponents.setLocalAnchoirPointBody1(entity, body1Transform.getInverse() * jointInfo.anchorPointWorldSpace); - mWorld.mBallAndSocketJointsComponents.setLocalAnchoirPointBody2(entity, body2Transform.getInverse() * jointInfo.anchorPointWorldSpace); + mWorld.mBallAndSocketJointsComponents.setLocalAnchorPointBody1(entity, body1Transform.getInverse() * jointInfo.anchorPointWorldSpace); + mWorld.mBallAndSocketJointsComponents.setLocalAnchorPointBody2(entity, body2Transform.getInverse() * jointInfo.anchorPointWorldSpace); } -// Initialize before solving the constraint -void BallAndSocketJoint::initBeforeSolve(const ConstraintSolverData& constraintSolverData) { - - // Get the bodies entities - Entity body1Entity = mWorld.mJointsComponents.getBody1Entity(mEntity); - Entity body2Entity = mWorld.mJointsComponents.getBody2Entity(mEntity); - - // TODO : Remove this and use compoents instead of pointers to bodies - RigidBody* body1 = static_cast(mWorld.mRigidBodyComponents.getRigidBody(body1Entity)); - RigidBody* body2 = static_cast(mWorld.mRigidBodyComponents.getRigidBody(body2Entity)); - - // Get the bodies center of mass and orientations - const Vector3& x1 = constraintSolverData.rigidBodyComponents.getCenterOfMassWorld(body1Entity); - const Vector3& x2 = constraintSolverData.rigidBodyComponents.getCenterOfMassWorld(body2Entity); - const Quaternion& orientationBody1 = body1->getTransform().getOrientation(); - const Quaternion& orientationBody2 = body2->getTransform().getOrientation(); - - // Get the inertia tensor of bodies - mWorld.mBallAndSocketJointsComponents.setI1(mEntity, body1->getInertiaTensorInverseWorld()); - mWorld.mBallAndSocketJointsComponents.setI2(mEntity, body2->getInertiaTensorInverseWorld()); - - // Compute the vector from body center to the anchor point in world-space - const Vector3 localAnchorPointBody1 = mWorld.mBallAndSocketJointsComponents.getLocalAnchoirPointBody1(mEntity); - const Vector3 localAnchorPointBody2 = mWorld.mBallAndSocketJointsComponents.getLocalAnchoirPointBody2(mEntity); - mWorld.mBallAndSocketJointsComponents.setR1World(mEntity, orientationBody1 * localAnchorPointBody1); - mWorld.mBallAndSocketJointsComponents.setR2World(mEntity, orientationBody2 * localAnchorPointBody2); - - // Compute the corresponding skew-symmetric matrices - const Vector3& r1World = mWorld.mBallAndSocketJointsComponents.getR1World(mEntity); - const Vector3& r2World = mWorld.mBallAndSocketJointsComponents.getR2World(mEntity); - Matrix3x3 skewSymmetricMatrixU1= Matrix3x3::computeSkewSymmetricMatrixForCrossProduct(r1World); - Matrix3x3 skewSymmetricMatrixU2= Matrix3x3::computeSkewSymmetricMatrixForCrossProduct(r2World); - - // Compute the matrix K=JM^-1J^t (3x3 matrix) - const decimal body1MassInverse = constraintSolverData.rigidBodyComponents.getMassInverse(body1->getEntity()); - const decimal body2MassInverse = constraintSolverData.rigidBodyComponents.getMassInverse(body2->getEntity()); - const decimal inverseMassBodies = body1MassInverse + body2MassInverse; - const Matrix3x3& i1 = mWorld.mBallAndSocketJointsComponents.getI1(mEntity); - const Matrix3x3& i2 = mWorld.mBallAndSocketJointsComponents.getI2(mEntity); - Matrix3x3 massMatrix = Matrix3x3(inverseMassBodies, 0, 0, - 0, inverseMassBodies, 0, - 0, 0, inverseMassBodies) + - skewSymmetricMatrixU1 * i1 * skewSymmetricMatrixU1.getTranspose() + - skewSymmetricMatrixU2 * i2 * skewSymmetricMatrixU2.getTranspose(); - - // Compute the inverse mass matrix K^-1 - Matrix3x3& inverseMassMatrix = mWorld.mBallAndSocketJointsComponents.getInverseMassMatrix(mEntity); - inverseMassMatrix.setToZero(); - if (mWorld.mRigidBodyComponents.getBodyType(body1Entity) == BodyType::DYNAMIC || - mWorld.mRigidBodyComponents.getBodyType(body2Entity) == BodyType::DYNAMIC) { - mWorld.mBallAndSocketJointsComponents.setInverseMassMatrix(mEntity, massMatrix.getInverse()); - } - - // Compute the bias "b" of the constraint - Vector3& biasVector = mWorld.mBallAndSocketJointsComponents.getBiasVector(mEntity); - biasVector.setToZero(); - if (mWorld.mJointsComponents.getPositionCorrectionTechnique(mEntity) == JointsPositionCorrectionTechnique::BAUMGARTE_JOINTS) { - decimal biasFactor = (BETA / constraintSolverData.timeStep); - mWorld.mBallAndSocketJointsComponents.setBiasVector(mEntity, biasFactor * (x2 + r2World - x1 - r1World)); - } - - // If warm-starting is not enabled - if (!constraintSolverData.isWarmStartingActive) { - - // Reset the accumulated impulse - Vector3& impulse = mWorld.mBallAndSocketJointsComponents.getImpulse(mEntity); - impulse.setToZero(); - } -} - -// Warm start the constraint (apply the previous impulse at the beginning of the step) -void BallAndSocketJoint::warmstart(const ConstraintSolverData& constraintSolverData) { - - Entity body1Entity = mWorld.mJointsComponents.getBody1Entity(mEntity); - Entity body2Entity = mWorld.mJointsComponents.getBody2Entity(mEntity); - - uint32 dynamicsComponentIndexBody1 = constraintSolverData.rigidBodyComponents.getEntityIndex(body1Entity); - uint32 dynamicsComponentIndexBody2 = constraintSolverData.rigidBodyComponents.getEntityIndex(body2Entity); - - // Get the velocities - Vector3& v1 = constraintSolverData.rigidBodyComponents.mConstrainedLinearVelocities[dynamicsComponentIndexBody1]; - Vector3& v2 = constraintSolverData.rigidBodyComponents.mConstrainedLinearVelocities[dynamicsComponentIndexBody2]; - Vector3& w1 = constraintSolverData.rigidBodyComponents.mConstrainedAngularVelocities[dynamicsComponentIndexBody1]; - Vector3& w2 = constraintSolverData.rigidBodyComponents.mConstrainedAngularVelocities[dynamicsComponentIndexBody2]; - - const Vector3& r1World = mWorld.mBallAndSocketJointsComponents.getR1World(mEntity); - const Vector3& r2World = mWorld.mBallAndSocketJointsComponents.getR2World(mEntity); - - const Matrix3x3& i1 = mWorld.mBallAndSocketJointsComponents.getI1(mEntity); - const Matrix3x3& i2 = mWorld.mBallAndSocketJointsComponents.getI2(mEntity); - - // Compute the impulse P=J^T * lambda for the body 1 - const Vector3& impulse = mWorld.mBallAndSocketJointsComponents.getImpulse(mEntity); - const Vector3 linearImpulseBody1 = -impulse; - const Vector3 angularImpulseBody1 = impulse.cross(r1World); - - // Apply the impulse to the body 1 - v1 += constraintSolverData.rigidBodyComponents.getMassInverse(body1Entity) * linearImpulseBody1; - w1 += i1 * angularImpulseBody1; - - // Compute the impulse P=J^T * lambda for the body 2 - const Vector3 angularImpulseBody2 = -impulse.cross(r2World); - - // Apply the impulse to the body to the body 2 - v2 += constraintSolverData.rigidBodyComponents.getMassInverse(body2Entity) * impulse; - w2 += i2 * angularImpulseBody2; -} - -// Solve the velocity constraint -void BallAndSocketJoint::solveVelocityConstraint(const ConstraintSolverData& constraintSolverData) { - - Entity body1Entity = mWorld.mJointsComponents.getBody1Entity(mEntity); - Entity body2Entity = mWorld.mJointsComponents.getBody2Entity(mEntity); - - uint32 dynamicsComponentIndexBody1 = constraintSolverData.rigidBodyComponents.getEntityIndex(body1Entity); - uint32 dynamicsComponentIndexBody2 = constraintSolverData.rigidBodyComponents.getEntityIndex(body2Entity); - - // Get the velocities - Vector3& v1 = constraintSolverData.rigidBodyComponents.mConstrainedLinearVelocities[dynamicsComponentIndexBody1]; - Vector3& v2 = constraintSolverData.rigidBodyComponents.mConstrainedLinearVelocities[dynamicsComponentIndexBody2]; - Vector3& w1 = constraintSolverData.rigidBodyComponents.mConstrainedAngularVelocities[dynamicsComponentIndexBody1]; - Vector3& w2 = constraintSolverData.rigidBodyComponents.mConstrainedAngularVelocities[dynamicsComponentIndexBody2]; - - const Vector3& r1World = mWorld.mBallAndSocketJointsComponents.getR1World(mEntity); - const Vector3& r2World = mWorld.mBallAndSocketJointsComponents.getR2World(mEntity); - - const Matrix3x3& i1 = mWorld.mBallAndSocketJointsComponents.getI1(mEntity); - const Matrix3x3& i2 = mWorld.mBallAndSocketJointsComponents.getI2(mEntity); - - const Matrix3x3& inverseMassMatrix = mWorld.mBallAndSocketJointsComponents.getInverseMassMatrix(mEntity); - const Vector3& biasVector = mWorld.mBallAndSocketJointsComponents.getBiasVector(mEntity); - - // Compute J*v - const Vector3 Jv = v2 + w2.cross(r2World) - v1 - w1.cross(r1World); - - // Compute the Lagrange multiplier lambda - const Vector3 deltaLambda = inverseMassMatrix * (-Jv - biasVector); - mWorld.mBallAndSocketJointsComponents.setImpulse(mEntity, mWorld.mBallAndSocketJointsComponents.getImpulse(mEntity) + deltaLambda); - - // Compute the impulse P=J^T * lambda for the body 1 - const Vector3 linearImpulseBody1 = -deltaLambda; - const Vector3 angularImpulseBody1 = deltaLambda.cross(r1World); - - // Apply the impulse to the body 1 - v1 += constraintSolverData.rigidBodyComponents.getMassInverse(body1Entity) * linearImpulseBody1; - w1 += i1 * angularImpulseBody1; - - // Compute the impulse P=J^T * lambda for the body 2 - const Vector3 angularImpulseBody2 = -deltaLambda.cross(r2World); - - // Apply the impulse to the body 2 - v2 += constraintSolverData.rigidBodyComponents.getMassInverse(body2Entity) * deltaLambda; - w2 += i2 * angularImpulseBody2; -} - -// Solve the position constraint (for position error correction) -void BallAndSocketJoint::solvePositionConstraint(const ConstraintSolverData& constraintSolverData) { - - Entity body1Entity = mWorld.mJointsComponents.getBody1Entity(mEntity); - Entity body2Entity = mWorld.mJointsComponents.getBody2Entity(mEntity); - - // TODO : Remove this and use compoents instead of pointers to bodies - RigidBody* body1 = static_cast(mWorld.mRigidBodyComponents.getRigidBody(body1Entity)); - RigidBody* body2 = static_cast(mWorld.mRigidBodyComponents.getRigidBody(body2Entity)); - - // If the error position correction technique is not the non-linear-gauss-seidel, we do - // do not execute this method - if (mWorld.mJointsComponents.getPositionCorrectionTechnique(mEntity) != JointsPositionCorrectionTechnique::NON_LINEAR_GAUSS_SEIDEL) return; - - // Get the bodies center of mass and orientations - Vector3 x1 = constraintSolverData.rigidBodyComponents.getConstrainedPosition(body1Entity); - Vector3 x2 = constraintSolverData.rigidBodyComponents.getConstrainedPosition(body2Entity); - Quaternion q1 = constraintSolverData.rigidBodyComponents.getConstrainedOrientation(body1Entity); - Quaternion q2 = constraintSolverData.rigidBodyComponents.getConstrainedOrientation(body2Entity); - - // Get the inverse mass and inverse inertia tensors of the bodies - const decimal inverseMassBody1 = constraintSolverData.rigidBodyComponents.getMassInverse(body1Entity); - const decimal inverseMassBody2 = constraintSolverData.rigidBodyComponents.getMassInverse(body2Entity); - - const Vector3& r1World = mWorld.mBallAndSocketJointsComponents.getR1World(mEntity); - const Vector3& r2World = mWorld.mBallAndSocketJointsComponents.getR2World(mEntity); - - const Matrix3x3& i1 = mWorld.mBallAndSocketJointsComponents.getI1(mEntity); - const Matrix3x3& i2 = mWorld.mBallAndSocketJointsComponents.getI2(mEntity); - - // Recompute the inverse inertia tensors - mWorld.mBallAndSocketJointsComponents.setI1(mEntity, body1->getInertiaTensorInverseWorld()); - mWorld.mBallAndSocketJointsComponents.setI2(mEntity, body2->getInertiaTensorInverseWorld()); - - // Compute the vector from body center to the anchor point in world-space - mWorld.mBallAndSocketJointsComponents.setR1World(mEntity, q1 * mWorld.mBallAndSocketJointsComponents.getLocalAnchoirPointBody1(mEntity)); - mWorld.mBallAndSocketJointsComponents.setR2World(mEntity, q2 * mWorld.mBallAndSocketJointsComponents.getLocalAnchoirPointBody2(mEntity)); - - // Compute the corresponding skew-symmetric matrices - Matrix3x3 skewSymmetricMatrixU1= Matrix3x3::computeSkewSymmetricMatrixForCrossProduct(r1World); - Matrix3x3 skewSymmetricMatrixU2= Matrix3x3::computeSkewSymmetricMatrixForCrossProduct(r2World); - - // Recompute the inverse mass matrix K=J^TM^-1J of of the 3 translation constraints - decimal inverseMassBodies = inverseMassBody1 + inverseMassBody2; - Matrix3x3 massMatrix = Matrix3x3(inverseMassBodies, 0, 0, - 0, inverseMassBodies, 0, - 0, 0, inverseMassBodies) + - skewSymmetricMatrixU1 * i1 * skewSymmetricMatrixU1.getTranspose() + - skewSymmetricMatrixU2 * i2 * skewSymmetricMatrixU2.getTranspose(); - Matrix3x3& inverseMassMatrix = mWorld.mBallAndSocketJointsComponents.getInverseMassMatrix(mEntity); - inverseMassMatrix.setToZero(); - if (mWorld.mRigidBodyComponents.getBodyType(body1Entity) == BodyType::DYNAMIC || - mWorld.mRigidBodyComponents.getBodyType(body2Entity) == BodyType::DYNAMIC) { - mWorld.mBallAndSocketJointsComponents.setInverseMassMatrix(mEntity, massMatrix.getInverse()); - } - - // Compute the constraint error (value of the C(x) function) - const Vector3 constraintError = (x2 + r2World - x1 - r1World); - - // Compute the Lagrange multiplier lambda - // TODO : Do not solve the system by computing the inverse each time and multiplying with the - // right-hand side vector but instead use a method to directly solve the linear system. - const Vector3 lambda = inverseMassMatrix * (-constraintError); - - // Compute the impulse of body 1 - const Vector3 linearImpulseBody1 = -lambda; - const Vector3 angularImpulseBody1 = lambda.cross(r1World); - - // Compute the pseudo velocity of body 1 - const Vector3 v1 = inverseMassBody1 * linearImpulseBody1; - const Vector3 w1 = i1 * angularImpulseBody1; - - // Update the body center of mass and orientation of body 1 - x1 += v1; - q1 += Quaternion(0, w1) * q1 * decimal(0.5); - q1.normalize(); - - // Compute the impulse of body 2 - const Vector3 angularImpulseBody2 = -lambda.cross(r2World); - - // Compute the pseudo velocity of body 2 - const Vector3 v2 = inverseMassBody2 * lambda; - const Vector3 w2 = i2 * angularImpulseBody2; - - // Update the body position/orientation of body 2 - x2 += v2; - q2 += Quaternion(0, w2) * q2 * decimal(0.5); - q2.normalize(); - - constraintSolverData.rigidBodyComponents.setConstrainedPosition(body1Entity, x1); - constraintSolverData.rigidBodyComponents.setConstrainedPosition(body2Entity, x2); - constraintSolverData.rigidBodyComponents.setConstrainedOrientation(body1Entity, q1); - constraintSolverData.rigidBodyComponents.setConstrainedOrientation(body2Entity, q2); -} // Return a string representation std::string BallAndSocketJoint::to_string() const { - return "BallAndSocketJoint{ localAnchorPointBody1=" + mWorld.mBallAndSocketJointsComponents.getLocalAnchoirPointBody1(mEntity).to_string() + - ", localAnchorPointBody2=" + mWorld.mBallAndSocketJointsComponents.getLocalAnchoirPointBody2(mEntity).to_string() + "}"; + return "BallAndSocketJoint{ localAnchorPointBody1=" + mWorld.mBallAndSocketJointsComponents.getLocalAnchorPointBody1(mEntity).to_string() + + ", localAnchorPointBody2=" + mWorld.mBallAndSocketJointsComponents.getLocalAnchorPointBody2(mEntity).to_string() + "}"; } diff --git a/src/constraint/BallAndSocketJoint.h b/src/constraint/BallAndSocketJoint.h index 533b57c2..4323ef1c 100644 --- a/src/constraint/BallAndSocketJoint.h +++ b/src/constraint/BallAndSocketJoint.h @@ -82,17 +82,6 @@ class BallAndSocketJoint : public Joint { /// Return the number of bytes used by the joint virtual size_t getSizeInBytes() const override; - /// Initialize before solving the constraint - virtual void initBeforeSolve(const ConstraintSolverData& constraintSolverData) override; - - /// Warm start the constraint (apply the previous impulse at the beginning of the step) - virtual void warmstart(const ConstraintSolverData& constraintSolverData) override; - - /// Solve the velocity constraint - virtual void solveVelocityConstraint(const ConstraintSolverData& constraintSolverData) override; - - /// Solve the position constraint (for position error correction) - virtual void solvePositionConstraint(const ConstraintSolverData& constraintSolverData) override; public : @@ -112,6 +101,31 @@ class BallAndSocketJoint : public Joint { /// Deleted assignment operator BallAndSocketJoint& operator=(const BallAndSocketJoint& constraint) = delete; + + /// Initialize before solving the constraint + // TODO : Delete this + virtual void initBeforeSolve(const ConstraintSolverData& constraintSolverData) override { + + } + + /// Warm start the constraint (apply the previous impulse at the beginning of the step) + // TODO : Delete this + virtual void warmstart(const ConstraintSolverData& constraintSolverData) override { + + } + + /// Solve the velocity constraint + // TODO : Delete this + virtual void solveVelocityConstraint(const ConstraintSolverData& constraintSolverData) override { + + } + + /// Solve the position constraint (for position error correction) + // TODO : Delete this + virtual void solvePositionConstraint(const ConstraintSolverData& constraintSolverData) override { + + } + }; // Return the number of bytes used by the joint diff --git a/src/constraint/Joint.h b/src/constraint/Joint.h index 20df2911..7c9d2ed5 100644 --- a/src/constraint/Joint.h +++ b/src/constraint/Joint.h @@ -107,15 +107,19 @@ class Joint { virtual size_t getSizeInBytes() const = 0; /// Initialize before solving the joint + // TODO : REMOVE THIS virtual void initBeforeSolve(const ConstraintSolverData& constraintSolverData) = 0; /// Warm start the joint (apply the previous impulse at the beginning of the step) + // TODO : REMOVE THIS virtual void warmstart(const ConstraintSolverData& constraintSolverData) = 0; /// Solve the velocity constraint + // TODO : REMOVE THIS virtual void solveVelocityConstraint(const ConstraintSolverData& constraintSolverData) = 0; /// Solve the position constraint + // TODO : REMOVE THIS virtual void solvePositionConstraint(const ConstraintSolverData& constraintSolverData) = 0; /// Awake the two bodies of the joint diff --git a/src/engine/DynamicsWorld.cpp b/src/engine/DynamicsWorld.cpp index 7daba967..55f8d539 100644 --- a/src/engine/DynamicsWorld.cpp +++ b/src/engine/DynamicsWorld.cpp @@ -52,7 +52,8 @@ DynamicsWorld::DynamicsWorld(const Vector3& gravity, const WorldSettings& worldS mIslands(mMemoryManager.getSingleFrameAllocator()), mContactSolverSystem(mMemoryManager, mIslands, mCollisionBodyComponents, mRigidBodyComponents, mProxyShapesComponents, mConfig), - mConstraintSolverSystem(mIslands, mRigidBodyComponents, mJointsComponents), + mConstraintSolverSystem(mIslands, mRigidBodyComponents, mTransformComponents, mJointsComponents, + mBallAndSocketJointsComponents), mDynamicsSystem(mRigidBodyComponents, mTransformComponents, mIsGravityEnabled, mGravity), mNbVelocitySolverIterations(mConfig.defaultVelocitySolverNbIterations), mNbPositionSolverIterations(mConfig.defaultPositionSolverNbIterations), diff --git a/src/systems/ConstraintSolverSystem.cpp b/src/systems/ConstraintSolverSystem.cpp index 55a0e9f1..8af41d2b 100644 --- a/src/systems/ConstraintSolverSystem.cpp +++ b/src/systems/ConstraintSolverSystem.cpp @@ -26,6 +26,7 @@ // Libraries #include "systems/ConstraintSolverSystem.h" #include "components/JointComponents.h" +#include "components/BallAndSocketJointComponents.h" #include "utils/Profiler.h" #include "engine/Island.h" @@ -33,10 +34,13 @@ using namespace reactphysics3d; // Constructor ConstraintSolverSystem::ConstraintSolverSystem(Islands& islands, RigidBodyComponents& rigidBodyComponents, - JointComponents& jointComponents) + TransformComponents& transformComponents, + JointComponents& jointComponents, + BallAndSocketJointComponents& ballAndSocketJointComponents) : mIsWarmStartingActive(true), mIslands(islands), mConstraintSolverData(rigidBodyComponents, jointComponents), - mSolveBallAndSocketJointSystem(rigidBodyComponents) { + mSolveBallAndSocketJointSystem(rigidBodyComponents, transformComponents, jointComponents, ballAndSocketJointComponents), + mJointComponents(jointComponents), mBallAndSocketJointComponents(ballAndSocketJointComponents){ #ifdef IS_PROFILING_ACTIVE @@ -58,16 +62,31 @@ void ConstraintSolverSystem::initialize(decimal dt) { mConstraintSolverData.timeStep = mTimeStep; mConstraintSolverData.isWarmStartingActive = mIsWarmStartingActive; + mSolveBallAndSocketJointSystem.setTimeStep(dt); + mSolveBallAndSocketJointSystem.setIsWarmStartingActive(mIsWarmStartingActive); + + mSolveBallAndSocketJointSystem.initBeforeSolve(); + + if (mIsWarmStartingActive) { + mSolveBallAndSocketJointSystem.warmstart(); + } + // For each joint for (uint i=0; iinitBeforeSolve(mConstraintSolverData); + mJointComponents.mJoints[i]->initBeforeSolve(mConstraintSolverData); // Warm-start the constraint if warm-starting is enabled if (mIsWarmStartingActive) { @@ -81,9 +100,17 @@ void ConstraintSolverSystem::solveVelocityConstraints() { RP3D_PROFILE("ConstraintSolverSystem::solveVelocityConstraints()", mProfiler); + mSolveBallAndSocketJointSystem.solveVelocityConstraint(); + // For each joint for (uint i=0; isolveVelocityConstraint(mConstraintSolverData); } @@ -94,9 +121,17 @@ void ConstraintSolverSystem::solvePositionConstraints() { RP3D_PROFILE("ConstraintSolverSystem::solvePositionConstraints()", mProfiler); + mSolveBallAndSocketJointSystem.solvePositionConstraint(); + // For each joint for (uint i=0; isolvePositionConstraint(mConstraintSolverData); } diff --git a/src/systems/ConstraintSolverSystem.h b/src/systems/ConstraintSolverSystem.h index 0d1dbfd9..b37b7c28 100644 --- a/src/systems/ConstraintSolverSystem.h +++ b/src/systems/ConstraintSolverSystem.h @@ -54,18 +54,18 @@ struct ConstraintSolverData { /// Current time step of the simulation decimal timeStep; + /// True if warm starting of the solver is active + bool isWarmStartingActive; + /// Reference to the rigid body components RigidBodyComponents& rigidBodyComponents; /// Reference to the joint components JointComponents& jointComponents; - /// True if warm starting of the solver is active - bool isWarmStartingActive; - /// Constructor ConstraintSolverData(RigidBodyComponents& rigidBodyComponents, JointComponents& jointComponents) - :rigidBodyComponents(rigidBodyComponents), jointComponents(jointComponents) { + :rigidBodyComponents(rigidBodyComponents), jointComponents(jointComponents) { } @@ -161,6 +161,12 @@ class ConstraintSolverSystem { /// Solver for the BallAndSocketJoint constraints SolveBallAndSocketJointSystem mSolveBallAndSocketJointSystem; + // TODO : Delete this + JointComponents& mJointComponents; + + // TODO : Delete this + BallAndSocketJointComponents& mBallAndSocketJointComponents; + #ifdef IS_PROFILING_ACTIVE /// Pointer to the profiler @@ -173,7 +179,9 @@ class ConstraintSolverSystem { /// Constructor ConstraintSolverSystem(Islands& islands, RigidBodyComponents& rigidBodyComponents, - JointComponents& jointComponents); + TransformComponents& transformComponents, + JointComponents& jointComponents, + BallAndSocketJointComponents& ballAndSocketJointComponents); /// Destructor ~ConstraintSolverSystem() = default; diff --git a/src/systems/SolveBallAndSocketJointSystem.cpp b/src/systems/SolveBallAndSocketJointSystem.cpp index 76166cf2..3930e07c 100644 --- a/src/systems/SolveBallAndSocketJointSystem.cpp +++ b/src/systems/SolveBallAndSocketJointSystem.cpp @@ -25,11 +25,367 @@ // Libraries #include "systems/SolveBallAndSocketJointSystem.h" +#include "body/RigidBody.h" using namespace reactphysics3d; +// Static variables definition +const decimal SolveBallAndSocketJointSystem::BETA = decimal(0.2); + // Constructor -SolveBallAndSocketJointSystem::SolveBallAndSocketJointSystem(RigidBodyComponents& rigidBodyComponents) - :mRigidBodyComponents(rigidBodyComponents) { +SolveBallAndSocketJointSystem::SolveBallAndSocketJointSystem(RigidBodyComponents& rigidBodyComponents, + TransformComponents& transformComponents, + JointComponents& jointComponents, + BallAndSocketJointComponents& ballAndSocketJointComponents) + :mRigidBodyComponents(rigidBodyComponents), mTransformComponents(transformComponents), + mJointComponents(jointComponents), mBallAndSocketJointComponents(ballAndSocketJointComponents), + mTimeStep(0), mIsWarmStartingActive(true) { } + +// Initialize before solving the constraint +void SolveBallAndSocketJointSystem::initBeforeSolve() { + + // For each joint + for (uint32 i=0; i < mBallAndSocketJointComponents.getNbEnabledComponents(); i++) { + + const Entity jointEntity = mBallAndSocketJointComponents.mJointEntities[i]; + + // Get the bodies entities + const Entity body1Entity = mJointComponents.getBody1Entity(jointEntity); + const Entity body2Entity = mJointComponents.getBody2Entity(jointEntity); + + // TODO : Remove this and use compoents instead of pointers to bodies + const RigidBody* body1 = static_cast(mRigidBodyComponents.getRigidBody(body1Entity)); + const RigidBody* body2 = static_cast(mRigidBodyComponents.getRigidBody(body2Entity)); + + // Get the inertia tensor of bodies + mBallAndSocketJointComponents.mI1[i] = body1->getInertiaTensorInverseWorld(); + mBallAndSocketJointComponents.mI2[i] = body2->getInertiaTensorInverseWorld(); + } + + // For each joint + for (uint32 i=0; i < mBallAndSocketJointComponents.getNbEnabledComponents(); i++) { + + const Entity jointEntity = mBallAndSocketJointComponents.mJointEntities[i]; + + // Get the bodies entities + const Entity body1Entity = mJointComponents.getBody1Entity(jointEntity); + const Entity body2Entity = mJointComponents.getBody2Entity(jointEntity); + + const Quaternion& orientationBody1 = mTransformComponents.getTransform(body1Entity).getOrientation(); + const Quaternion& orientationBody2 = mTransformComponents.getTransform(body2Entity).getOrientation(); + + // Compute the vector from body center to the anchor point in world-space + mBallAndSocketJointComponents.mR1World[i] = orientationBody1 * mBallAndSocketJointComponents.mLocalAnchorPointBody1[i]; + mBallAndSocketJointComponents.mR2World[i] = orientationBody2 * mBallAndSocketJointComponents.mLocalAnchorPointBody2[i]; + } + + // For each joint + for (uint32 i=0; i < mBallAndSocketJointComponents.getNbEnabledComponents(); i++) { + + const Entity jointEntity = mBallAndSocketJointComponents.mJointEntities[i]; + + // Compute the corresponding skew-symmetric matrices + const Vector3& r1World = mBallAndSocketJointComponents.mR1World[i]; + const Vector3& r2World = mBallAndSocketJointComponents.mR2World[i]; + Matrix3x3 skewSymmetricMatrixU1 = Matrix3x3::computeSkewSymmetricMatrixForCrossProduct(r1World); + Matrix3x3 skewSymmetricMatrixU2 = Matrix3x3::computeSkewSymmetricMatrixForCrossProduct(r2World); + + // Get the bodies entities + const Entity body1Entity = mJointComponents.getBody1Entity(jointEntity); + const Entity body2Entity = mJointComponents.getBody2Entity(jointEntity); + + const uint32 componentIndexBody1 = mRigidBodyComponents.getEntityIndex(body1Entity); + const uint32 componentIndexBody2 = mRigidBodyComponents.getEntityIndex(body2Entity); + + // Compute the matrix K=JM^-1J^t (3x3 matrix) + const decimal body1MassInverse = mRigidBodyComponents.mInverseMasses[componentIndexBody1]; + const decimal body2MassInverse = mRigidBodyComponents.mInverseMasses[componentIndexBody2]; + const decimal inverseMassBodies = body1MassInverse + body2MassInverse; + const Matrix3x3& i1 = mBallAndSocketJointComponents.mI1[i]; + const Matrix3x3& i2 = mBallAndSocketJointComponents.mI2[i]; + Matrix3x3 massMatrix = Matrix3x3(inverseMassBodies, 0, 0, + 0, inverseMassBodies, 0, + 0, 0, inverseMassBodies) + + skewSymmetricMatrixU1 * i1 * skewSymmetricMatrixU1.getTranspose() + + skewSymmetricMatrixU2 * i2 * skewSymmetricMatrixU2.getTranspose(); + + // Compute the inverse mass matrix K^-1 + mBallAndSocketJointComponents.mInverseMassMatrix[i].setToZero(); + if (mRigidBodyComponents.mBodyTypes[componentIndexBody1] == BodyType::DYNAMIC || + mRigidBodyComponents.mBodyTypes[componentIndexBody2] == BodyType::DYNAMIC) { + mBallAndSocketJointComponents.mInverseMassMatrix[i] = massMatrix.getInverse(); + } + } + + // For each joint + for (uint32 i=0; i < mBallAndSocketJointComponents.getNbEnabledComponents(); i++) { + + const Entity jointEntity = mBallAndSocketJointComponents.mJointEntities[i]; + + // Get the bodies entities + const Entity body1Entity = mJointComponents.getBody1Entity(jointEntity); + const Entity body2Entity = mJointComponents.getBody2Entity(jointEntity); + + const Vector3& r1World = mBallAndSocketJointComponents.mR1World[i]; + const Vector3& r2World = mBallAndSocketJointComponents.mR2World[i]; + + const Vector3& x1 = mRigidBodyComponents.getCenterOfMassWorld(body1Entity); + const Vector3& x2 = mRigidBodyComponents.getCenterOfMassWorld(body2Entity); + + // Compute the bias "b" of the constraint + mBallAndSocketJointComponents.mBiasVector[i].setToZero(); + if (mJointComponents.getPositionCorrectionTechnique(jointEntity) == JointsPositionCorrectionTechnique::BAUMGARTE_JOINTS) { + decimal biasFactor = (BETA / mTimeStep); + mBallAndSocketJointComponents.mBiasVector[i] = biasFactor * (x2 + r2World - x1 - r1World); + } + } + + // If warm-starting is not enabled + if (!mIsWarmStartingActive) { + + // For each joint + for (uint32 i=0; i < mBallAndSocketJointComponents.getNbEnabledComponents(); i++) { + + // Reset the accumulated impulse + mBallAndSocketJointComponents.mImpulse[i].setToZero(); + } + } +} + +// Warm start the constraint (apply the previous impulse at the beginning of the step) +void SolveBallAndSocketJointSystem::warmstart() { + + // For each joint component + for (uint32 i=0; i < mBallAndSocketJointComponents.getNbEnabledComponents(); i++) { + + const Entity jointEntity = mBallAndSocketJointComponents.mJointEntities[i]; + + const Entity body1Entity = mJointComponents.getBody1Entity(jointEntity); + const Entity body2Entity = mJointComponents.getBody2Entity(jointEntity); + + const uint32 componentIndexBody1 = mRigidBodyComponents.getEntityIndex(body1Entity); + const uint32 componentIndexBody2 = mRigidBodyComponents.getEntityIndex(body2Entity); + + // Get the velocities + Vector3& v1 = mRigidBodyComponents.mConstrainedLinearVelocities[componentIndexBody1]; + Vector3& v2 = mRigidBodyComponents.mConstrainedLinearVelocities[componentIndexBody2]; + Vector3& w1 = mRigidBodyComponents.mConstrainedAngularVelocities[componentIndexBody1]; + Vector3& w2 = mRigidBodyComponents.mConstrainedAngularVelocities[componentIndexBody2]; + + const Vector3& r1World = mBallAndSocketJointComponents.mR1World[i]; + const Vector3& r2World = mBallAndSocketJointComponents.mR2World[i]; + + const Matrix3x3& i1 = mBallAndSocketJointComponents.mI1[i]; + const Matrix3x3& i2 = mBallAndSocketJointComponents.mI2[i]; + + // Compute the impulse P=J^T * lambda for the body 1 + const Vector3 linearImpulseBody1 = -mBallAndSocketJointComponents.mImpulse[i]; + const Vector3 angularImpulseBody1 = mBallAndSocketJointComponents.mImpulse[i].cross(r1World); + + // Apply the impulse to the body 1 + v1 += mRigidBodyComponents.mInverseMasses[componentIndexBody1] * linearImpulseBody1; + w1 += i1 * angularImpulseBody1; + + // Compute the impulse P=J^T * lambda for the body 2 + const Vector3 angularImpulseBody2 = -mBallAndSocketJointComponents.mImpulse[i].cross(r2World); + + // Apply the impulse to the body to the body 2 + v2 += mRigidBodyComponents.mInverseMasses[componentIndexBody2] * mBallAndSocketJointComponents.mImpulse[i]; + w2 += i2 * angularImpulseBody2; + } +} + +// Solve the velocity constraint +void SolveBallAndSocketJointSystem::solveVelocityConstraint() { + + // For each joint component + for (uint32 i=0; i < mBallAndSocketJointComponents.getNbEnabledComponents(); i++) { + + const Entity jointEntity = mBallAndSocketJointComponents.mJointEntities[i]; + + const Entity body1Entity = mJointComponents.getBody1Entity(jointEntity); + const Entity body2Entity = mJointComponents.getBody2Entity(jointEntity); + + const uint32 componentIndexBody1 = mRigidBodyComponents.getEntityIndex(body1Entity); + const uint32 componentIndexBody2 = mRigidBodyComponents.getEntityIndex(body2Entity); + + // Get the velocities + Vector3& v1 = mRigidBodyComponents.mConstrainedLinearVelocities[componentIndexBody1]; + Vector3& v2 = mRigidBodyComponents.mConstrainedLinearVelocities[componentIndexBody2]; + Vector3& w1 = mRigidBodyComponents.mConstrainedAngularVelocities[componentIndexBody1]; + Vector3& w2 = mRigidBodyComponents.mConstrainedAngularVelocities[componentIndexBody2]; + + const Matrix3x3& i1 = mBallAndSocketJointComponents.mI1[i]; + const Matrix3x3& i2 = mBallAndSocketJointComponents.mI2[i]; + + // Compute J*v + const Vector3 Jv = v2 + w2.cross(mBallAndSocketJointComponents.mR2World[i]) - v1 - w1.cross(mBallAndSocketJointComponents.mR1World[i]); + + // Compute the Lagrange multiplier lambda + const Vector3 deltaLambda = mBallAndSocketJointComponents.mInverseMassMatrix[i] * (-Jv - mBallAndSocketJointComponents.mBiasVector[i]); + mBallAndSocketJointComponents.mImpulse[i] += deltaLambda; + + // Compute the impulse P=J^T * lambda for the body 1 + const Vector3 linearImpulseBody1 = -deltaLambda; + const Vector3 angularImpulseBody1 = deltaLambda.cross(mBallAndSocketJointComponents.mR1World[i]); + + // Apply the impulse to the body 1 + v1 += mRigidBodyComponents.mInverseMasses[componentIndexBody1] * linearImpulseBody1; + w1 += i1 * angularImpulseBody1; + + // Compute the impulse P=J^T * lambda for the body 2 + const Vector3 angularImpulseBody2 = -deltaLambda.cross(mBallAndSocketJointComponents.mR2World[i]); + + // Apply the impulse to the body 2 + v2 += mRigidBodyComponents.mInverseMasses[componentIndexBody2] * deltaLambda; + w2 += i2 * angularImpulseBody2; + } +} + +// Solve the position constraint (for position error correction) +void SolveBallAndSocketJointSystem::solvePositionConstraint() { + + // For each joint component + for (uint32 i=0; i < mBallAndSocketJointComponents.getNbEnabledComponents(); i++) { + + const Entity jointEntity = mBallAndSocketJointComponents.mJointEntities[i]; + + // If the error position correction technique is not the non-linear-gauss-seidel, we do + // do not execute this method + if (mJointComponents.getPositionCorrectionTechnique(jointEntity) != JointsPositionCorrectionTechnique::NON_LINEAR_GAUSS_SEIDEL) continue; + + const Entity body1Entity = mJointComponents.getBody1Entity(jointEntity); + const Entity body2Entity = mJointComponents.getBody2Entity(jointEntity); + + // TODO : Remove this and use compoents instead of pointers to bodies + const RigidBody* body1 = static_cast(mRigidBodyComponents.getRigidBody(body1Entity)); + const RigidBody* body2 = static_cast(mRigidBodyComponents.getRigidBody(body2Entity)); + + // Recompute the inverse inertia tensors + mBallAndSocketJointComponents.mI1[i] = body1->getInertiaTensorInverseWorld(); + mBallAndSocketJointComponents.mI2[i] = body2->getInertiaTensorInverseWorld(); + } + + // For each joint component + for (uint32 i=0; i < mBallAndSocketJointComponents.getNbEnabledComponents(); i++) { + + const Entity jointEntity = mBallAndSocketJointComponents.mJointEntities[i]; + + // If the error position correction technique is not the non-linear-gauss-seidel, we do + // do not execute this method + if (mJointComponents.getPositionCorrectionTechnique(jointEntity) != JointsPositionCorrectionTechnique::NON_LINEAR_GAUSS_SEIDEL) continue; + + const Entity body1Entity = mJointComponents.getBody1Entity(jointEntity); + const Entity body2Entity = mJointComponents.getBody2Entity(jointEntity); + + // Compute the vector from body center to the anchor point in world-space + mBallAndSocketJointComponents.mR1World[i] = mRigidBodyComponents.getConstrainedOrientation(body1Entity) * + mBallAndSocketJointComponents.mLocalAnchorPointBody1[i]; + mBallAndSocketJointComponents.mR2World[i] = mRigidBodyComponents.getConstrainedOrientation(body2Entity) * + mBallAndSocketJointComponents.mLocalAnchorPointBody2[i]; + } + + // For each joint component + for (uint32 i=0; i < mBallAndSocketJointComponents.getNbEnabledComponents(); i++) { + + const Entity jointEntity = mBallAndSocketJointComponents.mJointEntities[i]; + + // If the error position correction technique is not the non-linear-gauss-seidel, we do + // do not execute this method + if (mJointComponents.getPositionCorrectionTechnique(jointEntity) != JointsPositionCorrectionTechnique::NON_LINEAR_GAUSS_SEIDEL) continue; + + const Entity body1Entity = mJointComponents.getBody1Entity(jointEntity); + const Entity body2Entity = mJointComponents.getBody2Entity(jointEntity); + + const uint32 componentIndexBody1 = mRigidBodyComponents.getEntityIndex(body1Entity); + const uint32 componentIndexBody2 = mRigidBodyComponents.getEntityIndex(body2Entity); + + const Vector3& r1World = mBallAndSocketJointComponents.mR1World[i]; + const Vector3& r2World = mBallAndSocketJointComponents.mR2World[i]; + + // Compute the corresponding skew-symmetric matrices + Matrix3x3 skewSymmetricMatrixU1 = Matrix3x3::computeSkewSymmetricMatrixForCrossProduct(r1World); + Matrix3x3 skewSymmetricMatrixU2 = Matrix3x3::computeSkewSymmetricMatrixForCrossProduct(r2World); + + // Get the inverse mass and inverse inertia tensors of the bodies + const decimal inverseMassBody1 = mRigidBodyComponents.mInverseMasses[componentIndexBody1]; + const decimal inverseMassBody2 = mRigidBodyComponents.mInverseMasses[componentIndexBody2]; + + // Recompute the inverse mass matrix K=J^TM^-1J of of the 3 translation constraints + decimal inverseMassBodies = inverseMassBody1 + inverseMassBody2; + Matrix3x3 massMatrix = Matrix3x3(inverseMassBodies, 0, 0, + 0, inverseMassBodies, 0, + 0, 0, inverseMassBodies) + + skewSymmetricMatrixU1 * mBallAndSocketJointComponents.mI1[i] * skewSymmetricMatrixU1.getTranspose() + + skewSymmetricMatrixU2 * mBallAndSocketJointComponents.mI2[i] * skewSymmetricMatrixU2.getTranspose(); + mBallAndSocketJointComponents.mInverseMassMatrix[i].setToZero(); + if (mRigidBodyComponents.mBodyTypes[componentIndexBody1] == BodyType::DYNAMIC || + mRigidBodyComponents.mBodyTypes[componentIndexBody2] == BodyType::DYNAMIC) { + mBallAndSocketJointComponents.mInverseMassMatrix[i] = massMatrix.getInverse(); + } + } + + // For each joint component + for (uint32 i=0; i < mBallAndSocketJointComponents.getNbEnabledComponents(); i++) { + + const Entity jointEntity = mBallAndSocketJointComponents.mJointEntities[i]; + + // If the error position correction technique is not the non-linear-gauss-seidel, we do + // do not execute this method + if (mJointComponents.getPositionCorrectionTechnique(jointEntity) != JointsPositionCorrectionTechnique::NON_LINEAR_GAUSS_SEIDEL) continue; + + const Entity body1Entity = mJointComponents.getBody1Entity(jointEntity); + const Entity body2Entity = mJointComponents.getBody2Entity(jointEntity); + + const uint32 componentIndexBody1 = mRigidBodyComponents.getEntityIndex(body1Entity); + const uint32 componentIndexBody2 = mRigidBodyComponents.getEntityIndex(body2Entity); + + Vector3& x1 = mRigidBodyComponents.mConstrainedPositions[componentIndexBody1]; + Vector3& x2 = mRigidBodyComponents.mConstrainedPositions[componentIndexBody2]; + + const Vector3& r1World = mBallAndSocketJointComponents.mR1World[i]; + const Vector3& r2World = mBallAndSocketJointComponents.mR2World[i]; + + // Compute the constraint error (value of the C(x) function) + const Vector3 constraintError = (x2 + r2World - x1 - r1World); + + // Compute the Lagrange multiplier lambda + // TODO : Do not solve the system by computing the inverse each time and multiplying with the + // right-hand side vector but instead use a method to directly solve the linear system. + const Vector3 lambda = mBallAndSocketJointComponents.mInverseMassMatrix[i] * (-constraintError); + + // Compute the impulse of body 1 + const Vector3 linearImpulseBody1 = -lambda; + const Vector3 angularImpulseBody1 = lambda.cross(r1World); + + // Get the inverse mass and inverse inertia tensors of the bodies + const decimal inverseMassBody1 = mRigidBodyComponents.mInverseMasses[componentIndexBody1]; + const decimal inverseMassBody2 = mRigidBodyComponents.mInverseMasses[componentIndexBody2]; + + // Compute the pseudo velocity of body 1 + const Vector3 v1 = inverseMassBody1 * linearImpulseBody1; + const Vector3 w1 = mBallAndSocketJointComponents.mI1[i] * angularImpulseBody1; + + Quaternion& q1 = mRigidBodyComponents.mConstrainedOrientations[componentIndexBody1]; + Quaternion& q2 = mRigidBodyComponents.mConstrainedOrientations[componentIndexBody2]; + + // Update the body center of mass and orientation of body 1 + x1 += v1; + q1 += Quaternion(0, w1) * q1 * decimal(0.5); + q1.normalize(); + + // Compute the impulse of body 2 + const Vector3 angularImpulseBody2 = -lambda.cross(r2World); + + // Compute the pseudo velocity of body 2 + const Vector3 v2 = inverseMassBody2 * lambda; + const Vector3 w2 = mBallAndSocketJointComponents.mI2[i] * angularImpulseBody2; + + // Update the body position/orientation of body 2 + x2 += v2; + q2 += Quaternion(0, w2) * q2 * decimal(0.5); + q2.normalize(); + } +} diff --git a/src/systems/SolveBallAndSocketJointSystem.h b/src/systems/SolveBallAndSocketJointSystem.h index b77fee1a..7bab807c 100644 --- a/src/systems/SolveBallAndSocketJointSystem.h +++ b/src/systems/SolveBallAndSocketJointSystem.h @@ -29,6 +29,8 @@ // Libraries #include "utils/Profiler.h" #include "components/RigidBodyComponents.h" +#include "components/JointComponents.h" +#include "components/BallAndSocketJointComponents.h" #include "components/TransformComponents.h" namespace reactphysics3d { @@ -41,11 +43,31 @@ class SolveBallAndSocketJointSystem { private : + // -------------------- Constants -------------------- // + + // Beta value for the bias factor of position correction + static const decimal BETA; + // -------------------- Attributes -------------------- // /// Reference to the rigid body components RigidBodyComponents& mRigidBodyComponents; + /// Reference to transform components + TransformComponents& mTransformComponents; + + /// Reference to the joint components + JointComponents& mJointComponents; + + /// Reference to the ball-and-socket joint components + BallAndSocketJointComponents& mBallAndSocketJointComponents; + + /// Current time step of the simulation + decimal mTimeStep; + + /// True if warm starting of the solver is active + bool mIsWarmStartingActive; + #ifdef IS_PROFILING_ACTIVE /// Pointer to the profiler @@ -57,11 +79,32 @@ class SolveBallAndSocketJointSystem { // -------------------- Methods -------------------- // /// Constructor - SolveBallAndSocketJointSystem(RigidBodyComponents& rigidBodyComponents); + SolveBallAndSocketJointSystem(RigidBodyComponents& rigidBodyComponents, + TransformComponents& transformComponents, + JointComponents& jointComponents, + BallAndSocketJointComponents& ballAndSocketJointComponents); /// Destructor ~SolveBallAndSocketJointSystem() = default; + /// Initialize before solving the constraint + void initBeforeSolve(); + + /// Warm start the constraint (apply the previous impulse at the beginning of the step) + void warmstart(); + + /// Solve the velocity constraint + void solveVelocityConstraint(); + + /// Solve the position constraint (for position error correction) + void solvePositionConstraint(); + + /// Set the time step + void setTimeStep(decimal timeStep); + + /// Set to true to enable warm starting + void setIsWarmStartingActive(bool isWarmStartingActive); + #ifdef IS_PROFILING_ACTIVE /// Set the profiler @@ -78,6 +121,17 @@ inline void SolveBallAndSocketJointSystem::setProfiler(Profiler* profiler) { mProfiler = profiler; } +// Set the time step +inline void SolveBallAndSocketJointSystem::setTimeStep(decimal timeStep) { + assert(timeStep > decimal(0.0)); + mTimeStep = timeStep; +} + +// Set to true to enable warm starting +inline void SolveBallAndSocketJointSystem::setIsWarmStartingActive(bool isWarmStartingActive) { + mIsWarmStartingActive = isWarmStartingActive; +} + #endif }