forked from mrq/bitsandbytes-rocm
342 lines
9.9 KiB
C++
342 lines
9.9 KiB
C++
#pragma once
|
|
|
|
#include <algorithm>
|
|
#include <limits>
|
|
#include <type_traits>
|
|
#include "AAlloc.h"
|
|
|
|
namespace BinSearch {
|
|
namespace Details {
|
|
|
|
namespace DirectAux {
|
|
|
|
#define SAFETY_MULTI_PASS true
|
|
|
|
template <typename T>
|
|
struct HResults
|
|
{
|
|
HResults(T h, double ratio, size_t n) : H(h), hRatio(ratio), nInc(n) {}
|
|
T H;
|
|
double hRatio;
|
|
size_t nInc;
|
|
};
|
|
|
|
|
|
#ifdef USE_FMA
|
|
template <Algos A> struct IsDirect { static const bool value = (A == Direct) || (A == DirectFMA); };
|
|
template <Algos A> struct IsDirect2 { static const bool value = (A == Direct2) || (A == Direct2FMA); };
|
|
template <Algos A> struct IsDirectCache { static const bool value = (A == DirectCache) || (A == DirectCacheFMA); };
|
|
#else
|
|
template <Algos A> struct IsDirect { static const bool value = (A == Direct); };
|
|
template <Algos A> struct IsDirect2 { static const bool value = (A == Direct2); };
|
|
template <Algos A> struct IsDirectCache { static const bool value = (A == DirectCache); };
|
|
#endif
|
|
|
|
// general definition
|
|
template <Algos A, typename T, typename Enable = void>
|
|
struct BucketElem
|
|
{
|
|
FORCE_INLINE void set( uint32 b, const T *)
|
|
{
|
|
m_b = b;
|
|
}
|
|
|
|
FORCE_INLINE uint32 index() const { return m_b; }
|
|
|
|
private:
|
|
uint32 m_b;
|
|
};
|
|
|
|
// specialization for DirectCache methods
|
|
|
|
template <typename T> struct MatchingIntType;
|
|
template <> struct MatchingIntType<double> { typedef uint64 type; };
|
|
template <> struct MatchingIntType<float> { typedef uint32 type; };
|
|
|
|
template <Algos A, typename T>
|
|
struct BucketElem<A, T, typename std::enable_if< IsDirectCache<A>::value >::type >
|
|
{
|
|
typedef typename MatchingIntType<T>::type I;
|
|
|
|
void set(uint32 b, const T *xi)
|
|
{
|
|
u.u.x = xi[b];
|
|
u.u.b = b;
|
|
}
|
|
|
|
FORCE_INLINE I index() const { return u.u.b; }
|
|
FORCE_INLINE T x() const { return u.u.x; }
|
|
|
|
private:
|
|
union {
|
|
double dummy;
|
|
struct
|
|
{
|
|
T x;
|
|
I b;
|
|
} u;
|
|
} u;
|
|
};
|
|
|
|
|
|
template <bool UseFMA, unsigned char Gap, typename T>
|
|
struct DirectTraits
|
|
{
|
|
static void checkH(T scaler, T x0, T xN)
|
|
{
|
|
T Dn = xN - x0;
|
|
T ifmax = Dn * scaler;
|
|
myassert((ifmax < std::numeric_limits<uint32>::max() - (Gap - 1)),
|
|
"Problem unfeasible: index size exceeds uint32 capacity:"
|
|
<< " D[N] =" << Dn
|
|
<< ", H =" << scaler
|
|
<< ", H D[n] =" << ifmax << "\n"
|
|
);
|
|
}
|
|
|
|
FORCE_INLINE static uint32 f(T scaler, T x0, T z)
|
|
{
|
|
T tmp = scaler * (z - x0);
|
|
#ifdef USE_SSE2
|
|
return ftoi(FVec1<SSE,T>(tmp));
|
|
#else
|
|
return static_cast<uint32>(tmp);
|
|
#endif
|
|
}
|
|
|
|
template <InstrSet I>
|
|
FORCE_INLINE static typename FTOITraits<I, T>::vec_t f(const FVec<I, T>& scaler, const FVec<I, T>& x0, const FVec<I, T>& z)
|
|
{
|
|
return ftoi(scaler*(z-x0));
|
|
}
|
|
|
|
static T cst0(T scaler, T x0)
|
|
{
|
|
return x0;
|
|
}
|
|
};
|
|
|
|
#ifdef USE_FMA
|
|
template <unsigned char Gap, typename T>
|
|
struct DirectTraits<true,Gap,T>
|
|
{
|
|
typedef FVec1<SSE, T> fVec1;
|
|
|
|
static void checkH(T scaler, T H_Times_x0, T xN)
|
|
{
|
|
union {
|
|
typename FVec1<SSE, T>::vec_t v;
|
|
T s;
|
|
} ifmax;
|
|
ifmax.v = mulSub(fVec1(scaler), fVec1(xN), fVec1(H_Times_x0));
|
|
myassert((ifmax.s < std::numeric_limits<uint32>::max() - (Gap - 1)),
|
|
"Problem unfeasible: index size exceeds uint32 capacity:"
|
|
<< " H X[0] =" << H_Times_x0
|
|
<< ", H =" << scaler
|
|
<< ", X[N] =" << xN
|
|
<< ", H X[N] - H X[0] =" << ifmax.s << "\n"
|
|
);
|
|
}
|
|
|
|
FORCE_INLINE static uint32 f(T scaler, T Hx0, T xi)
|
|
{
|
|
return ftoi(mulSub(fVec1(scaler), fVec1(xi), fVec1(Hx0)));
|
|
}
|
|
|
|
template <InstrSet I>
|
|
FORCE_INLINE static typename FTOITraits<I,T>::vec_t f(const FVec<I,T>& scaler, const FVec<I, T>& H_Times_X0, const FVec<I, T>& z)
|
|
{
|
|
return ftoi(mulSub(scaler, z, H_Times_X0));
|
|
}
|
|
|
|
static T cst0(T scaler, T x0)
|
|
{
|
|
return scaler*x0;
|
|
}
|
|
};
|
|
#endif
|
|
|
|
template <unsigned char Gap, typename T, Algos A>
|
|
struct DirectInfo
|
|
{
|
|
static const bool UseFMA = (A == DirectFMA) || (A == Direct2FMA) || (A == DirectCacheFMA);
|
|
typedef DirectTraits<UseFMA, Gap, T> fun_t;
|
|
typedef BucketElem<A,T> bucket_t;
|
|
typedef AlignedVec<bucket_t> bucketvec_t;
|
|
|
|
struct Data {
|
|
Data() : buckets(0), xi(0), scaler(0), cst0(0) {}
|
|
Data( const T *x // for Direct must persist if xws=NULL
|
|
, uint32 n
|
|
, T H
|
|
, bucket_t *bws // assumed to gave size nb, as computed below
|
|
, T *xws = NULL // assumed to have size (n+Gap-1). Optional for Direct, unused for DirectCache, required for DirectGap
|
|
)
|
|
: buckets(bws)
|
|
, scaler(H)
|
|
, cst0(fun_t::cst0(H, x[0]))
|
|
{
|
|
myassert(((bws != NULL) && (isAligned(bws,64))), "bucket pointer not allocated or incorrectly aligned");
|
|
|
|
uint32 nb = 1 + fun_t::f(H, cst0, x[n-1]);
|
|
|
|
const uint32 npad = Gap-1;
|
|
const uint32 n_sz = n + npad; // size of padded vector
|
|
|
|
if (xws) {
|
|
myassert(isAligned(xws,8), "x pointer not allocated or incorrectly aligned");
|
|
std::fill_n(xws, npad, x[0]); // pad in front with x[0]
|
|
std::copy(x, x+n, xws + npad);
|
|
xi = xws;
|
|
}
|
|
else {
|
|
myassert(Gap==1, "if Gap>1 then X workspace must be provided");
|
|
xi = x;
|
|
}
|
|
|
|
populateIndex(bws, nb, xi, n_sz, scaler, cst0);
|
|
}
|
|
|
|
const bucket_t *buckets;
|
|
const T *xi;
|
|
T scaler;
|
|
T cst0; // could be x0 or (scaler*x0), depending if we are using FMA or not
|
|
} data;
|
|
|
|
static T growStep(T H)
|
|
{
|
|
T step;
|
|
T P = next(H);
|
|
while ((step = P - H) == 0)
|
|
P = next(P);
|
|
return step;
|
|
}
|
|
|
|
static HResults<T> computeH(const T *px, uint32 nx)
|
|
{
|
|
myassert((nx > Gap), "Array X too small");
|
|
myassert(((Gap == 1) || (Gap == 2)), "Only tested for these values of Gap");
|
|
|
|
const T x0 = px[0];
|
|
const T xN = px[nx-1];
|
|
|
|
const T range = xN - x0;
|
|
myassert((range < std::numeric_limits<T>::max()), "range too large");
|
|
|
|
// check that D_i are strictly increasing and compute minimum value D_{i+Offset}-D_i
|
|
T deltaDMin = range;
|
|
for (uint32 i = Gap; i < nx; ++i) {
|
|
T Dnew = px[i] - x0;
|
|
T Dold = px[i - Gap] - x0;
|
|
myassert((Dnew > Dold),
|
|
"Problem unfeasible: D_i sequence not strictly increasing"
|
|
<< " X[" << 0 << "]=" << x0
|
|
<< " X[" << i - Gap << "]=" << px[i - Gap]
|
|
<< " X[" << i << "]=" << px[i]
|
|
<< "\n"
|
|
);
|
|
T deltaD = Dnew - Dold;
|
|
if (deltaD < deltaDMin)
|
|
deltaDMin = deltaD;
|
|
}
|
|
|
|
// initial guess for H
|
|
const T H0 = T(1.0) / deltaDMin;
|
|
T H = H0;
|
|
|
|
T cst0 = fun_t::cst0(H, x0);
|
|
fun_t::checkH(H, cst0, xN);
|
|
|
|
// adjust H by trial and error until succeed
|
|
size_t nInc = 0;
|
|
bool modified = false;
|
|
size_t npasses = 0;
|
|
T step = growStep(H);
|
|
uint32 seg_already_checked_from = nx;
|
|
do {
|
|
myassert((npasses++ < 2), "verification failed\n");
|
|
// if there has been an increase, then check only up to that point
|
|
uint32 last_seg_to_be_checked = seg_already_checked_from - 1;
|
|
modified = false;
|
|
uint32 inew = 0;
|
|
for (uint32 i = Gap; i <= last_seg_to_be_checked; ++i) {
|
|
uint32 iold = fun_t::f(H, cst0, px[i-Gap]);
|
|
uint32 inew = fun_t::f(H, cst0, px[i]);
|
|
while (inew == iold) {
|
|
seg_already_checked_from = i;
|
|
last_seg_to_be_checked = nx-1; // everything needs to be checked
|
|
modified = true;
|
|
H = H + step;
|
|
step *= 2;
|
|
// recalculate all constants and indices
|
|
cst0 = fun_t::cst0(H, x0);
|
|
fun_t::checkH(H, cst0, xN);
|
|
iold = fun_t::f(H, cst0, px[i - Gap]);
|
|
inew = fun_t::f(H, cst0, px[i]);
|
|
}
|
|
}
|
|
} while (SAFETY_MULTI_PASS && modified);
|
|
|
|
return HResults<T>(H, (((double)H) / H0) - 1.0, nInc);
|
|
}
|
|
|
|
static void populateIndex(BucketElem<A, T> *buckets, uint32 index_size, const T *px, uint32 x_size, T scaler, T cst0)
|
|
{
|
|
for (uint32 i = x_size-1, b = index_size-1, j=0; ; --i) {
|
|
uint32 idx = fun_t::f(scaler, cst0, px[i]);
|
|
while (b > idx) { // in the 1st iteration it is j=0 but this condition is always false
|
|
buckets[b].set( j, px );
|
|
--b;
|
|
}
|
|
if (Gap==1 || b == idx) { // if Gap==1, which is known at compile time, the check b==idx is redundant
|
|
j = i - (Gap-1); // subtracting (Gap-1) points to the index of the first X-element to check
|
|
buckets[b].set(j, px);
|
|
if (b-- == 0)
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
DirectInfo(const Data& d)
|
|
: data(d)
|
|
{
|
|
}
|
|
|
|
DirectInfo(const T* px, const uint32 n)
|
|
{
|
|
HResults<T> res = computeH(px, n);
|
|
|
|
#ifdef PAPER_TEST
|
|
nInc = res.nInc;
|
|
hRatio = res.hRatio;
|
|
#endif
|
|
const uint32 npad = Gap-1;
|
|
const uint32 n_sz = n + npad; // size of padded vector
|
|
|
|
if (npad)
|
|
xi.resize(n_sz);
|
|
|
|
T H = res.H;
|
|
T cst0 = fun_t::cst0(H, px[0]);
|
|
const uint32 maxIndex = fun_t::f(H, cst0, px[n-1]);
|
|
buckets.resize(maxIndex + 1);
|
|
|
|
data = Data(px, n, H, buckets.begin(), (npad? xi.begin(): NULL));
|
|
}
|
|
|
|
private:
|
|
bucketvec_t buckets;
|
|
AlignedVec<T,8> xi;
|
|
|
|
#ifdef PAPER_TEST
|
|
public:
|
|
double hRatio;
|
|
size_t nInc;
|
|
#endif
|
|
};
|
|
|
|
|
|
} // namespace DirectAux
|
|
} // namespace Details
|
|
} // namespace BinSearch
|