@@ -45,6 +45,18 @@ static __global__ void flash_attn_vec_ext_f32(

     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
+        GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+        GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
+        GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+        GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+        GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
+        GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
+        GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
+        GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+        GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
+        GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
+        GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
+        GGML_UNUSED(ne2); GGML_UNUSED(ne3);
         NO_DEVICE_CODE;
         return;
     }
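The twelve `+` lines above exist only to silence unused-parameter diagnostics: on builds where this kernel variant is compiled out, `NO_DEVICE_CODE` is all that remains of the body and every kernel argument would otherwise go unreferenced. A minimal sketch of the pattern, assuming ggml's usual `GGML_UNUSED` definition (a plain void-cast) and a hypothetical toy kernel:

```cpp
// Sketch only: toy_variant is hypothetical, not the real flash-attention kernel.
// GGML_UNUSED is assumed to match ggml.h's definition: a void-cast that
// references its argument without emitting any code.
#define GGML_UNUSED(x) (void)(x)

__global__ void toy_variant(const char * Q, const char * K, const float scale) {
    // Stubbed-out variant: touch every parameter so host and device
    // compilers do not warn about unused parameters, then do nothing.
    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(scale);
}
```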
@@ -114,7 +126,7 @@ static __global__ void flash_attn_vec_ext_f32(
             // Set memory to zero if out of bounds:
             if (ncols > 2 && ic0 + j >= ne01) {
 #pragma unroll
-                for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+                for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
                     const int i = i0 + threadIdx.x;

                     tmp_q_i32[i] = 0;
@@ -127,7 +139,7 @@ static __global__ void flash_attn_vec_ext_f32(

             const float * Q_f = (const float *) (Q + j*nb01);
 #pragma unroll
-            for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
                 quantize_q8_1_to_shared<float2>(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds);
             }
         }
@@ -140,7 +152,7 @@ static __global__ void flash_attn_vec_ext_f32(
             float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));

 #pragma unroll
-            for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
                 const int i = i0 + threadIdx.x;

                 Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i];
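The last three hunks are the same mechanical fix: `D` is an `int` template parameter, but `sizeof(int)` has type `size_t`, so `D/sizeof(int)` is unsigned and the condition `i0 < D/sizeof(int)` compares a signed `int` against an unsigned value, tripping `-Wsign-compare`. Casting the bound with `int(...)` makes both sides signed without changing the value, since `D` is a small positive power of two. A standalone host-side sketch of the warning and the fix (`loop_bound` and the toy size are illustrative, not from the source):

```cpp
#include <cstdio>

// D stands in for the head-size template parameter, WARP_SIZE for the
// CUDA warp width; the names mirror the kernel but this runs on the host.
template <int D>
void loop_bound() {
    constexpr int WARP_SIZE = 32;
#if 0
    // Warns under -Wsign-compare: D/sizeof(int) is size_t (unsigned),
    // while i0 is int (signed).
    for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { }
#endif
    // Fixed: cast the bound back to int so the comparison is signed/signed.
    for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
        std::printf("i0 = %d\n", i0);
    }
}

int main() {
    loop_bound<128>(); // 128/4 = 32 ints per row -> prints only i0 = 0
    return 0;
}
```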