@@ -10,6 +10,12 @@
 #define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
 #define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
 
+// log(2) = 0.6931. Adding this constant to the KQ maximum used for the softmax effectively shifts the
+// numerical range representable by the VKQ accumulators up by a factor of 2.
+// This reduces issues with numerical overflow, but it also causes comparatively larger values to be flushed to zero.
+// However, as the output of FlashAttention is usually used as an input for a matrix multiplication, this should be negligible.
+#define FATTN_KQ_MAX_OFFSET 0.6931f
+
 typedef void (* fattn_kernel_t)(
         const char * __restrict__ Q,
         const char * __restrict__ K,
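
For intuition, here is a minimal host-side C++ sketch of how the three constants interact in the shifted softmax. The KQ values, the float accumulator, and the scalar loops are made-up simplifications standing in for the kernel's FP16 data path, not the actual implementation:

#include <cmath>
#include <cstdio>

#define SOFTMAX_FTZ_THRESHOLD -20.0f   // exp() of smaller exponents is flushed to zero
#define FATTN_KQ_MAX_OFFSET    0.6931f // log(2): halves every exp() term

int main() {
    const float kq[4] = {8.0f, 11.0f, 10.5f, -15.0f}; // made-up KQ values for one row

    float kq_max = -65504.0f/2.0f; // like -HALF_MAX_HALF: finite, so kq_max - kq_max is 0, not NaN
    for (float v : kq) {
        kq_max = v > kq_max ? v : kq_max;
    }

    float sum = 0.0f; // stands in for the (FP16) KQ row sum / VKQ accumulators
    for (float v : kq) {
        const float diff = v - (kq_max + FATTN_KQ_MAX_OFFSET); // <= -FATTN_KQ_MAX_OFFSET for all v
        sum += diff <= SOFTMAX_FTZ_THRESHOLD ? 0.0f : std::exp(diff);
    }

    // Each term is at most exp(-log(2)) = 0.5 instead of 1.0, which doubles the
    // accumulator's headroom; the -15.0f entry falls below the FTZ threshold and
    // contributes exactly 0 rather than a denormal/NaN-prone tiny value.
    std::printf("sum = %f\n", sum);
    return 0;
}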