From 64bb1ae8d176ca8661bd3f76518e97ec4f863506 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Thu, 9 Mar 2023 11:10:28 -0800
Subject: [PATCH] add a sign function, for lion

---
 csrc/kernels.cu | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index a871a55..76a8c73 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -217,6 +217,14 @@ __device__ __forceinline__ unsigned char quantize_quadrant(int QUADRANT, float *
     }
 }
 
+// sign function for lion
+// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA
+
+template <typename T>
+__device__ int sgn(T val) {
+    return (T(0) < val) - (val < T(0));
+}
+
 __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n)
 {
   const int tid = threadIdx.x + (blockDim.x*blockIdx.x);