|
@@ -6,7 +6,7 @@
|
|
|
|
|
|
|
|
template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
|
|
template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
|
|
|
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
|
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
|
|
-__launch_bounds__(nwarps*WARP_SIZE, 1)
|
|
|
|
|
|
|
+__launch_bounds__(nwarps*WARP_SIZE, 2)
|
|
|
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
|
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
|
|
static __global__ void flash_attn_tile_ext_f16(
|
|
static __global__ void flash_attn_tile_ext_f16(
|
|
|
const char * __restrict__ Q,
|
|
const char * __restrict__ Q,
|
|
@@ -30,7 +30,9 @@ static __global__ void flash_attn_tile_ext_f16(
|
|
|
const int ne12,
|
|
const int ne12,
|
|
|
const int ne13,
|
|
const int ne13,
|
|
|
const int ne31,
|
|
const int ne31,
|
|
|
|
|
+ const int ne32,
|
|
|
const int nb31,
|
|
const int nb31,
|
|
|
|
|
+ const int nb32,
|
|
|
const int nb01,
|
|
const int nb01,
|
|
|
const int nb02,
|
|
const int nb02,
|
|
|
const int nb03,
|
|
const int nb03,
|
|
@@ -64,7 +66,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
|
|
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
|
|
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
|
|
|
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
|
|
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
|
|
|
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
|
|
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
|
|
|
- const half * maskh = (const half *) mask + ne11*ic0;
|
|
|
|
|
|
|
+ const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
|
|
|
|
|
|
|
|
const int stride_KV2 = nb11 / sizeof(half2);
|
|
const int stride_KV2 = nb11 / sizeof(half2);
|
|
|
|
|
|
|
@@ -288,8 +290,8 @@ static __global__ void flash_attn_tile_ext_f16(
|
|
|
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
|
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
|
|
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
|
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
|
|
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
|
|
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
|
|
|
- GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
|
|
|
|
|
- GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
|
|
|
|
|
|
+ GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
|
|
|
|
+ GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
|
|
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
|
|
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
|
|
|
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
|
|
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
|
|
|
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
|
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|