@@ -7909,10 +7909,10 @@ void ggml_compute_forward_argsort(
 
 // ggml_compute_forward_flash_attn_ext
 
-static void ggml_compute_forward_flash_attn_ext_f16(
+static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
         const ggml_compute_params * params,
-        ggml_tensor * dst) {
-
+        ggml_tensor * dst,
+        int ir0, int ir1) {
     const ggml_tensor * q = dst->src[0];
     const ggml_tensor * k = dst->src[1];
     const ggml_tensor * v = dst->src[2];
@@ -7928,9 +7928,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
     GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
 
-    const int ith = params->ith;
-    const int nth = params->nth;
-
     const int64_t DK = nek0;
     const int64_t DV = nev0;
     const int64_t N = neq1;
@@ -7964,16 +7961,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 
     // parallelize by q rows using ggml_vec_dot_f32
 
-    // total rows in q
-    const int nr = neq1*neq2*neq3;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
     float scale = 1.0f;
     float max_bias = 0.0f;
     float logit_softcap = 0.0f;
@@ -8000,6 +7987,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
     GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");
 
+    int ith = params->ith;
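+    // ith is still needed below: it selects this thread's slice of the shared scratch buffer in params->wdata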
+
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices
@@ -8147,6 +8136,91 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     }
 }
 
+static void ggml_compute_forward_flash_attn_ext_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * q = dst->src[0];
+    const ggml_tensor * k = dst->src[1];
+    const ggml_tensor * v = dst->src[2];
+
+    GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
+    GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
+    GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
+    GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
+    GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
+    GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
+    GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
+    GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
+    const int64_t DK = nek0;
+    const int64_t DV = nev0;
+    const int64_t N = neq1;
+
+    GGML_ASSERT(ne0 == DV);
+    GGML_ASSERT(ne2 == N);
+
+    // input tensor rows must be contiguous
+    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
+    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
+    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
+
+    GGML_ASSERT(neq0 == DK);
+    GGML_ASSERT(nek0 == DK);
+    GGML_ASSERT(nev0 == DV);
+
+    GGML_ASSERT(neq1 == N);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    // parallelize by q rows using ggml_vec_dot_f32
+
+    // total rows in q
+    const int64_t nr = neq1*neq2*neq3;
+
+    // thread index and thread count
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // disable for NUMA
+    const bool disable_chunking = ggml_is_numa();
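+    // presumably to keep each thread on memory local to its NUMA node: with chunk stealing,
+    // a thread may pick up rows whose data lives on a remote node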
+
+    // 4x chunks per thread
+    int nth_scaled = nth * 4;
+    int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
+    int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
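+    // e.g. nth = 8 threads and nr = 1000 rows: nth_scaled = 32, chunk_size = (1000 + 31)/32 = 32,
+    // nchunk = (1000 + 31)/32 = 32, i.e. 32 chunks of at most 32 rows each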
+
+    if (nth == 1 || nchunk < nth || disable_chunking) {
+        nchunk = nth;
+    }
+
+    if (ith == 0) {
+        // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
+        ggml_threadpool_chunk_set(params->threadpool, nth);
+    }
+
+    ggml_barrier(params->threadpool);
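+    // the barrier ensures every thread sees the counter initialized to nth before anyone
+    // starts claiming chunks beyond its initial one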
+
+    // The number of rows in each chunk
+    const int64_t dr = (nr + nchunk - 1) / nchunk;
+
+    // The first chunk comes from our thread_id; the rest will get auto-assigned.
+    int current_chunk = ith;
+
+    while (current_chunk < nchunk) {
+        const int64_t ir0 = dr * current_chunk;
+        const int64_t ir1 = MIN(ir0 + dr, nr);
+
+        ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
+
+        current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
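+        // ggml_threadpool_chunk_add(tp, 1) is assumed to atomically fetch-and-add the shared
+        // chunk counter and return its previous value, i.e. the index of the next unclaimed chunk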
+    }
+}
+
 void ggml_compute_forward_flash_attn_ext(
         const ggml_compute_params * params,
         ggml_tensor * dst) {