fattn-common.cuh

#define FATTN_KQ_STRIDE 256
#define HALF_MAX_HALF         __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
#define SOFTMAX_FTZ_THRESHOLD -20.0f                   // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
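
// Combines the partial attention results of parallel_blocks independent KV-slice
// passes into the final output via a numerically stable softmax merge:
// each partial numerator and denominator is rescaled by exp(local max - global max)
// before being summed, so all partials share one common softmax scale.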
template<int D, int parallel_blocks> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static __global__ void flash_attn_combine_results(
        const float  * __restrict__ VKQ_parts,
        const float2 * __restrict__ VKQ_meta,
        float        * __restrict__ dst) {
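    // Offset the pointers to this blockIdx.x slice: each x block covers gridDim.y
    // output vectors of size D, with parallel_blocks partial results (and one
    // float2 meta pair per partial) for each vector.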
    VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
    VKQ_meta  += parallel_blocks   * gridDim.y*blockIdx.x;
    dst       +=                 D * gridDim.y*blockIdx.x;

    const int tid = threadIdx.x;
    __builtin_assume(tid < D);
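
    // Stage the (KQ max, KQ sum) float2 pairs of all partial results for this
    // blockIdx.y vector in shared memory; the first 2*parallel_blocks threads
    // each load one float (this requires D >= 2*parallel_blocks).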
    __shared__ float2 meta[parallel_blocks];
    if (tid < 2*parallel_blocks) {
        ((float *) meta)[tid] = ((const float *) VKQ_meta)[blockIdx.y*(2*parallel_blocks) + tid];
    }

    __syncthreads();
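
    // Reduce over the partial results to find the global KQ maximum.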
    float kqmax = meta[0].x;
#pragma unroll
    for (int l = 1; l < parallel_blocks; ++l) {
        kqmax = max(kqmax, meta[l].x);
    }
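
    // Merge the partials: rescale each numerator and denominator by
    // exp(local max - global max), then accumulate.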
    float VKQ_numerator   = 0.0f;
    float VKQ_denominator = 0.0f;
#pragma unroll
    for (int l = 0; l < parallel_blocks; ++l) {
        const float diff = meta[l].x - kqmax;
        float KQ_max_scale = expf(diff); // non-const: mutated through the bit mask below

        // Branchless flush-to-zero: ftz_mask is all ones if diff > SOFTMAX_FTZ_THRESHOLD,
        // otherwise all zeros, zeroing out scales too small to matter.
        const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
        *((uint32_t *) &KQ_max_scale) &= ftz_mask;

        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
        VKQ_denominator += KQ_max_scale * meta[l].y;
    }

    dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
}
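
// Illustrative launch sketch (an assumption, not code from this file): the kernel
// expects blockDim.x == D and one (blockIdx.x, blockIdx.y) pair per output vector;
// n_seq and n_heads below are hypothetical names for those grid extents.
//
//   constexpr int D = 128, parallel_blocks = 4;
//   flash_attn_combine_results<D, parallel_blocks>
//       <<<dim3(n_seq, n_heads), D>>>(VKQ_parts, VKQ_meta, dst);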