|
|
@@ -2776,10 +2776,6 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
|
|
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
|
|
|
case GGML_OP_FLASH_ATTN_EXT:
|
|
|
{
|
|
|
- if (op->src[4]) {
|
|
|
- return false;
|
|
|
- }
|
|
|
-
|
|
|
const ggml_tensor * q = op->src[0];
|
|
|
const ggml_tensor * k = op->src[1];
|
|
|
const ggml_tensor * v = op->src[2];
|
|
|
@@ -5765,6 +5761,7 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor
|
|
|
static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, const ggml_tensor * k, ggml_tensor * dst) {
|
|
|
const ggml_tensor * v = dst->src[2];
|
|
|
const ggml_tensor * mask = dst->src[3];
|
|
|
+ const ggml_tensor * sinks = dst->src[4];
|
|
|
GGML_ASSERT(q->extra);
|
|
|
GGML_ASSERT(k->extra);
|
|
|
GGML_ASSERT(v->extra);
|
|
|
@@ -5772,6 +5769,9 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co
|
|
|
if (mask) {
|
|
|
GGML_ASSERT(mask->extra);
|
|
|
}
|
|
|
+ if (sinks) {
|
|
|
+ GGML_ASSERT(sinks->extra);
|
|
|
+ }
|
|
|
|
|
|
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
|
|
|
|
|
@@ -5813,6 +5813,7 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co
|
|
|
ggml_tensor_extra_cl * extra_v = (ggml_tensor_extra_cl *)v->extra;
|
|
|
ggml_tensor_extra_cl * extra_o = (ggml_tensor_extra_cl *)dst->extra;
|
|
|
ggml_tensor_extra_cl * extra_mask = mask ? (ggml_tensor_extra_cl *)mask->extra : NULL;
|
|
|
+ ggml_tensor_extra_cl * extra_sinks = sinks ? (ggml_tensor_extra_cl *)sinks->extra : NULL;
|
|
|
|
|
|
cl_ulong offset_q = extra_q->offset + q->view_offs;
|
|
|
cl_ulong offset_k = extra_k->offset + k->view_offs;
|
|
|
@@ -5820,6 +5821,8 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co
|
|
|
cl_ulong offset_o = extra_o->offset + dst->view_offs;
|
|
|
cl_mem mask_buffer = extra_mask ? extra_mask->data_device : NULL;
|
|
|
cl_ulong offset_mask = extra_mask ? extra_mask->offset + mask->view_offs : 0;
|
|
|
+ cl_mem sinks_buffer = extra_sinks ? extra_sinks->data_device : NULL;
|
|
|
+ cl_ulong offset_sinks = extra_sinks ? extra_sinks->offset + sinks->view_offs : 0;
|
|
|
|
|
|
const cl_ulong q_nb1 = q->nb[1], q_nb2 = q->nb[2], q_nb3 = q->nb[3];
|
|
|
const cl_ulong k_nb1 = k->nb[1], k_nb2 = k->nb[2], k_nb3 = k->nb[3];
|
|
|
@@ -5874,6 +5877,8 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co
|
|
|
CL_CHECK(clSetKernelArg(kernel, 35, sizeof(cl_ulong), &mask_nb3));
|
|
|
CL_CHECK(clSetKernelArg(kernel, 36, sizeof(int), &mask_ne2));
|
|
|
CL_CHECK(clSetKernelArg(kernel, 37, sizeof(int), &mask_ne3));
|
|
|
+ CL_CHECK(clSetKernelArg(kernel, 38, sizeof(cl_mem), &sinks_buffer));
|
|
|
+ CL_CHECK(clSetKernelArg(kernel, 39, sizeof(cl_ulong), &offset_sinks));
|
|
|
|
|
|
if (n_q == 1) {
|
|
|
const size_t wg_size = 64;
|